From 4276e0c82cb0c2f7f7858063a2fe8bdd8b4240cf Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 19 Nov 2018 22:29:06 -0500 Subject: Implement pickle support --- immutables/_map.c | 37 +++++++++++++++++++++++++++++++++++++ immutables/map.py | 3 +++ tests/test_map.py | 10 ++++++++++ 3 files changed, 50 insertions(+) diff --git a/immutables/_map.c b/immutables/_map.c index 8064754..bf6b640 100644 --- a/immutables/_map.c +++ b/immutables/_map.c @@ -3306,6 +3306,42 @@ map_py_hash(MapObject *self) return self->h_hash; } +static PyObject * +map_reduce(MapObject *self) +{ + MapIteratorState iter; + map_iter_t iter_res; + + PyObject *dict = PyDict_New(); + if (dict == NULL) { + return NULL; + } + + map_iterator_init(&iter, self->h_root); + do { + PyObject *key; + PyObject *val; + + iter_res = map_iterator_next(&iter, &key, &val); + if (iter_res == I_ITEM) { + if (PyDict_SetItem(dict, key, val) < 0) { + Py_DECREF(dict); + return NULL; + } + } + } while (iter_res != I_END); + + PyObject *args = PyTuple_Pack(1, dict); + Py_DECREF(dict); + if (args == NULL) { + return NULL; + } + + PyObject *tup = PyTuple_Pack(2, Py_TYPE(self), args); + Py_DECREF(args); + return tup; +} + static PyMethodDef Map_methods[] = { {"set", (PyCFunction)map_py_set, METH_VARARGS, NULL}, @@ -3316,6 +3352,7 @@ static PyMethodDef Map_methods[] = { {"keys", (PyCFunction)map_py_keys, METH_NOARGS, NULL}, {"values", (PyCFunction)map_py_values, METH_NOARGS, NULL}, {"update", (PyCFunction)map_py_update, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__reduce__", (PyCFunction)map_reduce, METH_NOARGS, NULL}, {"__dump__", (PyCFunction)map_py_dump, METH_NOARGS, NULL}, {NULL, NULL} }; diff --git a/immutables/map.py b/immutables/map.py index f498d38..c6dd15d 100644 --- a/immutables/map.py +++ b/immutables/map.py @@ -450,6 +450,9 @@ class Map: m.__hash = -1 return m + def __reduce__(self): + return (type(self), (dict(self.items()),)) + def __len__(self): return self.__count diff --git a/tests/test_map.py b/tests/test_map.py index 13c6cb1..bbf0a52 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -1,5 +1,6 @@ import collections.abc import gc +import pickle import random import unittest import weakref @@ -1148,6 +1149,15 @@ class BaseMapTest: self.assertEqual(dict(h.items()), d) self.assertEqual(len(h), len(d)) + def test_map_pickle(self): + h = self.Map(a=1, b=2) + for proto in range(pickle.HIGHEST_PROTOCOL): + p = pickle.dumps(h, proto) + uh = pickle.loads(p) + + self.assertTrue(isinstance(uh, self.Map)) + self.assertEqual(h, uh) + class PyMapTest(BaseMapTest, unittest.TestCase): -- cgit v1.2.3