diff options
author | Yury Selivanov <yury@magic.io> | 2018-11-20 13:15:08 -0500 |
---|---|---|
committer | Yury Selivanov <yury@magic.io> | 2018-11-20 14:25:07 -0500 |
commit | 2a14bc7bb3e523d74c9abe5246303976c013d91a (patch) | |
tree | 050b86c11521580d44af1d8d09132467001f2e85 | |
parent | 24c575b6eec2abe396ece52460672d26e01ad284 (diff) | |
download | immutables-2a14bc7bb3e523d74c9abe5246303976c013d91a.tar.gz immutables-2a14bc7bb3e523d74c9abe5246303976c013d91a.zip |
Make MapMutation a context manager
-rw-r--r-- | immutables/_map.c | 21 | ||||
-rw-r--r-- | immutables/map.py | 7 | ||||
-rw-r--r-- | tests/test_map.py | 21 |
3 files changed, 48 insertions, 1 deletions
diff --git a/immutables/_map.c b/immutables/_map.c index 1abf730..9f57582 100644 --- a/immutables/_map.c +++ b/immutables/_map.c @@ -3850,7 +3850,6 @@ mapmut_tp_richcompare(PyObject *v, PyObject *w, int op) } } - static PyObject * mapmut_py_finalize(MapMutationObject *self, PyObject *args) { @@ -3868,6 +3867,24 @@ mapmut_py_finalize(MapMutationObject *self, PyObject *args) return (PyObject *)o; } +static PyObject * +mapmut_py_enter(MapMutationObject *self, PyObject *args) +{ + Py_INCREF(self); + return (PyObject *)self; +} + +static PyObject * +mapmut_py_exit(MapMutationObject *self, PyObject *args) +{ + PyObject *ret = mapmut_py_finalize(self, NULL); + if (ret == NULL) { + return NULL; + } + Py_DECREF(ret); + Py_RETURN_FALSE; +} + static int mapmut_tp_ass_sub(MapMutationObject *self, PyObject *key, PyObject *val) { @@ -3951,6 +3968,8 @@ static PyMethodDef MapMutation_methods[] = { {"get", (PyCFunction)map_py_get, METH_VARARGS, NULL}, {"pop", (PyCFunction)mapmut_py_pop, METH_VARARGS, NULL}, {"finalize", (PyCFunction)mapmut_py_finalize, METH_NOARGS, NULL}, + {"__enter__", (PyCFunction)mapmut_py_enter, METH_NOARGS, NULL}, + {"__exit__", (PyCFunction)mapmut_py_exit, METH_VARARGS, NULL}, {NULL, NULL} }; diff --git a/immutables/map.py b/immutables/map.py index abfe9ed..bea7ad9 100644 --- a/immutables/map.py +++ b/immutables/map.py @@ -633,6 +633,13 @@ class MapMutation: def set(self, key, val): self[key] = val + def __enter__(self): + return self + + def __exit__(self, *exc): + self.finalize() + return False + def __delitem__(self, key): if self.__mutid == 0: raise ValueError(f'mutation {self!r} has been finalized') diff --git a/tests/test_map.py b/tests/test_map.py index 4fc87e3..d3b11fb 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -1174,6 +1174,27 @@ class BaseMapTest: with self.assertRaises(EqError): mm.set(key2, 123) + def test_map_mut_14(self): + m = self.Map(a=1, b=2) + + with m.mutate() as mm: + mm['z'] = 100 + del mm['a'] + + self.assertEqual(mm.finalize(), self.Map(z=100, b=2)) + + def test_map_mut_15(self): + m = self.Map(a=1, b=2) + + with self.assertRaises(ZeroDivisionError): + with m.mutate() as mm: + mm['z'] = 100 + del mm['a'] + 1 / 0 + + self.assertEqual(mm.finalize(), self.Map(z=100, b=2)) + self.assertEqual(m, self.Map(a=1, b=2)) + def test_map_mut_stress(self): COLLECTION_SIZE = 7000 TEST_ITERS_EVERY = 647 |