aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--immutables/_map.c37
-rw-r--r--immutables/map.py3
-rw-r--r--tests/test_map.py10
3 files changed, 50 insertions, 0 deletions
diff --git a/immutables/_map.c b/immutables/_map.c
index 8064754..bf6b640 100644
--- a/immutables/_map.c
+++ b/immutables/_map.c
@@ -3306,6 +3306,42 @@ map_py_hash(MapObject *self)
return self->h_hash;
}
+static PyObject *
+map_reduce(MapObject *self)
+{
+ MapIteratorState iter;
+ map_iter_t iter_res;
+
+ PyObject *dict = PyDict_New();
+ if (dict == NULL) {
+ return NULL;
+ }
+
+ map_iterator_init(&iter, self->h_root);
+ do {
+ PyObject *key;
+ PyObject *val;
+
+ iter_res = map_iterator_next(&iter, &key, &val);
+ if (iter_res == I_ITEM) {
+ if (PyDict_SetItem(dict, key, val) < 0) {
+ Py_DECREF(dict);
+ return NULL;
+ }
+ }
+ } while (iter_res != I_END);
+
+ PyObject *args = PyTuple_Pack(1, dict);
+ Py_DECREF(dict);
+ if (args == NULL) {
+ return NULL;
+ }
+
+ PyObject *tup = PyTuple_Pack(2, Py_TYPE(self), args);
+ Py_DECREF(args);
+ return tup;
+}
+
static PyMethodDef Map_methods[] = {
{"set", (PyCFunction)map_py_set, METH_VARARGS, NULL},
@@ -3316,6 +3352,7 @@ static PyMethodDef Map_methods[] = {
{"keys", (PyCFunction)map_py_keys, METH_NOARGS, NULL},
{"values", (PyCFunction)map_py_values, METH_NOARGS, NULL},
{"update", (PyCFunction)map_py_update, METH_VARARGS | METH_KEYWORDS, NULL},
+ {"__reduce__", (PyCFunction)map_reduce, METH_NOARGS, NULL},
{"__dump__", (PyCFunction)map_py_dump, METH_NOARGS, NULL},
{NULL, NULL}
};
diff --git a/immutables/map.py b/immutables/map.py
index f498d38..c6dd15d 100644
--- a/immutables/map.py
+++ b/immutables/map.py
@@ -450,6 +450,9 @@ class Map:
m.__hash = -1
return m
+ def __reduce__(self):
+ return (type(self), (dict(self.items()),))
+
def __len__(self):
return self.__count
diff --git a/tests/test_map.py b/tests/test_map.py
index 13c6cb1..bbf0a52 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -1,5 +1,6 @@
import collections.abc
import gc
+import pickle
import random
import unittest
import weakref
@@ -1148,6 +1149,15 @@ class BaseMapTest:
self.assertEqual(dict(h.items()), d)
self.assertEqual(len(h), len(d))
+ def test_map_pickle(self):
+ h = self.Map(a=1, b=2)
+ for proto in range(pickle.HIGHEST_PROTOCOL):
+ p = pickle.dumps(h, proto)
+ uh = pickle.loads(p)
+
+ self.assertTrue(isinstance(uh, self.Map))
+ self.assertEqual(h, uh)
+
class PyMapTest(BaseMapTest, unittest.TestCase):