aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--immutables/_map.c21
-rw-r--r--immutables/map.py7
-rw-r--r--tests/test_map.py21
3 files changed, 48 insertions, 1 deletions
diff --git a/immutables/_map.c b/immutables/_map.c
index 1abf730..9f57582 100644
--- a/immutables/_map.c
+++ b/immutables/_map.c
@@ -3850,7 +3850,6 @@ mapmut_tp_richcompare(PyObject *v, PyObject *w, int op)
}
}
-
static PyObject *
mapmut_py_finalize(MapMutationObject *self, PyObject *args)
{
@@ -3868,6 +3867,24 @@ mapmut_py_finalize(MapMutationObject *self, PyObject *args)
return (PyObject *)o;
}
+static PyObject *
+mapmut_py_enter(MapMutationObject *self, PyObject *args)
+{
+ Py_INCREF(self);
+ return (PyObject *)self;
+}
+
+static PyObject *
+mapmut_py_exit(MapMutationObject *self, PyObject *args)
+{
+ PyObject *ret = mapmut_py_finalize(self, NULL);
+ if (ret == NULL) {
+ return NULL;
+ }
+ Py_DECREF(ret);
+ Py_RETURN_FALSE;
+}
+
static int
mapmut_tp_ass_sub(MapMutationObject *self, PyObject *key, PyObject *val)
{
@@ -3951,6 +3968,8 @@ static PyMethodDef MapMutation_methods[] = {
{"get", (PyCFunction)map_py_get, METH_VARARGS, NULL},
{"pop", (PyCFunction)mapmut_py_pop, METH_VARARGS, NULL},
{"finalize", (PyCFunction)mapmut_py_finalize, METH_NOARGS, NULL},
+ {"__enter__", (PyCFunction)mapmut_py_enter, METH_NOARGS, NULL},
+ {"__exit__", (PyCFunction)mapmut_py_exit, METH_VARARGS, NULL},
{NULL, NULL}
};
diff --git a/immutables/map.py b/immutables/map.py
index abfe9ed..bea7ad9 100644
--- a/immutables/map.py
+++ b/immutables/map.py
@@ -633,6 +633,13 @@ class MapMutation:
def set(self, key, val):
self[key] = val
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exc):
+ self.finalize()
+ return False
+
def __delitem__(self, key):
if self.__mutid == 0:
raise ValueError(f'mutation {self!r} has been finalized')
diff --git a/tests/test_map.py b/tests/test_map.py
index 4fc87e3..d3b11fb 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -1174,6 +1174,27 @@ class BaseMapTest:
with self.assertRaises(EqError):
mm.set(key2, 123)
+ def test_map_mut_14(self):
+ m = self.Map(a=1, b=2)
+
+ with m.mutate() as mm:
+ mm['z'] = 100
+ del mm['a']
+
+ self.assertEqual(mm.finalize(), self.Map(z=100, b=2))
+
+ def test_map_mut_15(self):
+ m = self.Map(a=1, b=2)
+
+ with self.assertRaises(ZeroDivisionError):
+ with m.mutate() as mm:
+ mm['z'] = 100
+ del mm['a']
+ 1 / 0
+
+ self.assertEqual(mm.finalize(), self.Map(z=100, b=2))
+ self.assertEqual(m, self.Map(a=1, b=2))
+
def test_map_mut_stress(self):
COLLECTION_SIZE = 7000
TEST_ITERS_EVERY = 647