From 2a14bc7bb3e523d74c9abe5246303976c013d91a Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 20 Nov 2018 13:15:08 -0500 Subject: Make MapMutation a context manager --- immutables/_map.c | 21 ++++++++++++++++++++- immutables/map.py | 7 +++++++ tests/test_map.py | 21 +++++++++++++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/immutables/_map.c b/immutables/_map.c index 1abf730..9f57582 100644 --- a/immutables/_map.c +++ b/immutables/_map.c @@ -3850,7 +3850,6 @@ mapmut_tp_richcompare(PyObject *v, PyObject *w, int op) } } - static PyObject * mapmut_py_finalize(MapMutationObject *self, PyObject *args) { @@ -3868,6 +3867,24 @@ mapmut_py_finalize(MapMutationObject *self, PyObject *args) return (PyObject *)o; } +static PyObject * +mapmut_py_enter(MapMutationObject *self, PyObject *args) +{ + Py_INCREF(self); + return (PyObject *)self; +} + +static PyObject * +mapmut_py_exit(MapMutationObject *self, PyObject *args) +{ + PyObject *ret = mapmut_py_finalize(self, NULL); + if (ret == NULL) { + return NULL; + } + Py_DECREF(ret); + Py_RETURN_FALSE; +} + static int mapmut_tp_ass_sub(MapMutationObject *self, PyObject *key, PyObject *val) { @@ -3951,6 +3968,8 @@ static PyMethodDef MapMutation_methods[] = { {"get", (PyCFunction)map_py_get, METH_VARARGS, NULL}, {"pop", (PyCFunction)mapmut_py_pop, METH_VARARGS, NULL}, {"finalize", (PyCFunction)mapmut_py_finalize, METH_NOARGS, NULL}, + {"__enter__", (PyCFunction)mapmut_py_enter, METH_NOARGS, NULL}, + {"__exit__", (PyCFunction)mapmut_py_exit, METH_VARARGS, NULL}, {NULL, NULL} }; diff --git a/immutables/map.py b/immutables/map.py index abfe9ed..bea7ad9 100644 --- a/immutables/map.py +++ b/immutables/map.py @@ -633,6 +633,13 @@ class MapMutation: def set(self, key, val): self[key] = val + def __enter__(self): + return self + + def __exit__(self, *exc): + self.finalize() + return False + def __delitem__(self, key): if self.__mutid == 0: raise ValueError(f'mutation {self!r} has been finalized') diff --git a/tests/test_map.py b/tests/test_map.py index 4fc87e3..d3b11fb 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -1174,6 +1174,27 @@ class BaseMapTest: with self.assertRaises(EqError): mm.set(key2, 123) + def test_map_mut_14(self): + m = self.Map(a=1, b=2) + + with m.mutate() as mm: + mm['z'] = 100 + del mm['a'] + + self.assertEqual(mm.finalize(), self.Map(z=100, b=2)) + + def test_map_mut_15(self): + m = self.Map(a=1, b=2) + + with self.assertRaises(ZeroDivisionError): + with m.mutate() as mm: + mm['z'] = 100 + del mm['a'] + 1 / 0 + + self.assertEqual(mm.finalize(), self.Map(z=100, b=2)) + self.assertEqual(m, self.Map(a=1, b=2)) + def test_map_mut_stress(self): COLLECTION_SIZE = 7000 TEST_ITERS_EVERY = 647 -- cgit v1.2.3