summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2018-11-20 13:02:14 -0500
committerYury Selivanov <yury@magic.io>2018-11-20 14:25:07 -0500
commit24c575b6eec2abe396ece52460672d26e01ad284 (patch)
treea9f708d3d104f79b8b65648123a1b672ce7d292b
parent4276e0c82cb0c2f7f7858063a2fe8bdd8b4240cf (diff)
downloadimmutables-24c575b6eec2abe396ece52460672d26e01ad284.tar.gz
immutables-24c575b6eec2abe396ece52460672d26e01ad284.zip
Implement mutable mapping API for MapMutation; add after-finalize checks
-rw-r--r--immutables/_map.c214
-rw-r--r--immutables/map.py47
-rw-r--r--tests/test_map.py79
3 files changed, 271 insertions, 69 deletions
diff --git a/immutables/_map.c b/immutables/_map.c
index bf6b640..1abf730 100644
--- a/immutables/_map.c
+++ b/immutables/_map.c
@@ -37,7 +37,7 @@ Now let's partition this bit representation of the hash into blocks of
0b00_00000_10010_11101_00101_01011_10000 = 19830128
(6) (5) (4) (3) (2) (1)
-Each block of 5 bits represents a number betwen 0 and 31. So if we have
+Each block of 5 bits represents a number between 0 and 31. So if we have
a tree that consists of nodes, each of which is an array of 32 pointers,
those 5-bit blocks will encode a position on a single tree level.
@@ -885,7 +885,7 @@ map_node_bitmap_assoc(MapNode_Bitmap *self,
pairs.
Small map objects (<30 keys) usually don't have any
- Array nodes at all. Betwen ~30 and ~400 keys map
+ Array nodes at all. Between ~30 and ~400 keys map
objects usually have one Array node, and usually it's
a root node.
*/
@@ -2460,7 +2460,7 @@ map_without(MapObject *o, PyObject *key)
return NULL;
}
- MapNode *new_root;
+ MapNode *new_root = NULL;
map_without_t res = map_node_without(
(MapNode *)(o->h_root),
@@ -3715,31 +3715,74 @@ map_update(uint64_t mutid, MapObject *o, PyObject *src)
return new;
}
+static int
+mapmut_check_finalized(MapMutationObject *o)
+{
+ if (o->m_mutid == 0) {
+ PyErr_Format(
+ PyExc_ValueError,
+ "mutation %R has been finalized",
+ o, NULL);
+ return -1;
+ }
-static PyObject *
-mapmut_py_set(MapMutationObject *o, PyObject *args)
+ return 0;
+}
+
+static int
+mapmut_delete(MapMutationObject *o, PyObject *key, int32_t key_hash)
{
- PyObject *key;
- PyObject *val;
+ MapNode *new_root = NULL;
- if (!PyArg_UnpackTuple(args, "set", 2, 2, &key, &val)) {
- return NULL;
- }
+ assert(key_hash != -1);
+ map_without_t res = map_node_without(
+ (MapNode *)(o->m_root),
+ 0, key_hash, key,
+ &new_root,
+ o->m_mutid);
- int32_t key_hash;
- int added_leaf = 0;
+ switch (res) {
+ case W_ERROR:
+ return -1;
- key_hash = map_hash(key);
- if (key_hash == -1) {
- return NULL;
+ case W_EMPTY:
+ new_root = map_node_bitmap_new(0, o->m_mutid);
+ if (new_root == NULL) {
+ return -1;
+ }
+ Py_SETREF(o->m_root, new_root);
+ o->m_count = 0;
+ return 0;
+
+ case W_NOT_FOUND:
+ PyErr_SetObject(PyExc_KeyError, key);
+ return -1;
+
+ case W_NEWNODE: {
+ assert(new_root != NULL);
+ Py_SETREF(o->m_root, new_root);
+ o->m_count--;
+ return 0;
+ }
+
+ default:
+ abort();
}
+}
+
+static int
+mapmut_set(MapMutationObject *o, PyObject *key, int32_t key_hash,
+ PyObject *val)
+{
+ int added_leaf = 0;
+ assert(key_hash != -1);
MapNode *new_root = map_node_assoc(
(MapNode *)(o->m_root),
0, key_hash, key, val, &added_leaf,
o->m_mutid);
if (new_root == NULL) {
- return NULL;
+ return -1;
}
if (added_leaf) {
@@ -3748,62 +3791,39 @@ mapmut_py_set(MapMutationObject *o, PyObject *args)
if (new_root == o->m_root) {
Py_DECREF(new_root);
- goto done;
+ return 0;
}
Py_SETREF(o->m_root, new_root);
-
-done:
- Py_RETURN_NONE;
+ return 0;
}
-
static PyObject *
-mapmut_py_delete(MapMutationObject *o, PyObject *key)
+mapmut_py_set(MapMutationObject *o, PyObject *args)
{
- int32_t key_hash = map_hash(key);
- if (key_hash == -1) {
+ PyObject *key;
+ PyObject *val;
+
+ if (!PyArg_UnpackTuple(args, "set", 2, 2, &key, &val)) {
return NULL;
}
- MapNode *new_root;
-
- map_without_t res = map_node_without(
- (MapNode *)(o->m_root),
- 0, key_hash, key,
- &new_root,
- o->m_mutid);
+ if (mapmut_check_finalized(o)) {
+ return NULL;
+ }
- switch (res) {
- case W_ERROR:
- return NULL;
- case W_EMPTY:
- new_root = map_node_bitmap_new(0, o->m_mutid);
- if (new_root == NULL) {
- return NULL;
- }
- Py_SETREF(o->m_root, new_root);
- o->m_count = 0;
- goto done;
+ int32_t key_hash = map_hash(key);
+ if (key_hash == -1) {
+ return NULL;
+ }
- case W_NOT_FOUND:
- PyErr_SetObject(PyExc_KeyError, key);
- return NULL;
- case W_NEWNODE: {
- assert(new_root != NULL);
- Py_SETREF(o->m_root, new_root);
- o->m_count--;
- goto done;
- }
- default:
- abort();
+ if (mapmut_set(o, key, key_hash, val)) {
+ return NULL;
}
-done:
Py_RETURN_NONE;
}
-
static PyObject *
mapmut_tp_richcompare(PyObject *v, PyObject *w, int op)
{
@@ -3848,11 +3868,88 @@ mapmut_py_finalize(MapMutationObject *self, PyObject *args)
return (PyObject *)o;
}
+static int
+mapmut_tp_ass_sub(MapMutationObject *self, PyObject *key, PyObject *val)
+{
+ if (mapmut_check_finalized(self)) {
+ return -1;
+ }
+
+ int32_t key_hash = map_hash(key);
+ if (key_hash == -1) {
+ return -1;
+ }
+
+ if (val == NULL) {
+ return mapmut_delete(self, key, key_hash);
+ }
+ else {
+ return mapmut_set(self, key, key_hash, val);
+ }
+}
+
+static PyObject *
+mapmut_py_pop(MapMutationObject *self, PyObject *args)
+{
+ PyObject *key, *deflt = NULL, *val = NULL;
+
+ if(!PyArg_UnpackTuple(args, "pop", 1, 2, &key, &deflt)) {
+ return NULL;
+ }
+
+ if (mapmut_check_finalized(self)) {
+ return NULL;
+ }
+
+ if (!self->m_count) {
+ goto not_found;
+ }
+
+ int32_t key_hash = map_hash(key);
+ if (key_hash == -1) {
+ return NULL;
+ }
+
+ map_find_t find_res = map_node_find(self->m_root, 0, key_hash, key, &val);
+
+ switch (find_res) {
+ case F_ERROR:
+ return NULL;
+
+ case F_NOT_FOUND:
+ goto not_found;
+
+ case F_FOUND:
+ break;
+
+ default:
+ abort();
+ }
+
+ Py_INCREF(val);
+
+ if (mapmut_delete(self, key, key_hash)) {
+ Py_DECREF(val);
+ return NULL;
+ }
+
+ return val;
+
+not_found:
+ if (deflt) {
+ Py_INCREF(deflt);
+ return deflt;
+ }
+
+ PyErr_SetObject(PyExc_KeyError, key);
+ return NULL;
+}
+
static PyMethodDef MapMutation_methods[] = {
{"set", (PyCFunction)mapmut_py_set, METH_VARARGS, NULL},
{"get", (PyCFunction)map_py_get, METH_VARARGS, NULL},
- {"delete", (PyCFunction)mapmut_py_delete, METH_O, NULL},
+ {"pop", (PyCFunction)mapmut_py_pop, METH_VARARGS, NULL},
{"finalize", (PyCFunction)mapmut_py_finalize, METH_NOARGS, NULL},
{NULL, NULL}
};
@@ -3871,8 +3968,9 @@ static PySequenceMethods MapMutation_as_sequence = {
};
static PyMappingMethods MapMutation_as_mapping = {
- (lenfunc)map_tp_len, /* mp_length */
- (binaryfunc)map_tp_subscript, /* mp_subscript */
+ (lenfunc)map_tp_len, /* mp_length */
+ (binaryfunc)map_tp_subscript, /* mp_subscript */
+ (objobjargproc)mapmut_tp_ass_sub, /* mp_subscript */
};
PyTypeObject _MapMutation_Type = {
diff --git a/immutables/map.py b/immutables/map.py
index c6dd15d..abfe9ed 100644
--- a/immutables/map.py
+++ b/immutables/map.py
@@ -44,6 +44,7 @@ def map_bitindex(bitmap, bit):
W_EMPTY, W_NEWNODE, W_NOT_FOUND = range(3)
+void = object()
class BitmapNode:
@@ -630,16 +631,9 @@ class MapMutation:
self.__mutid = _mut_id()
def set(self, key, val):
- if self.__mutid == 0:
- raise ValueError(f'mutation {self!r} has been finalized')
-
- self.__root, added = self.__root.assoc(
- 0, map_hash(key), key, val, self.__mutid)
-
- if added:
- self.__count += 1
+ self[key] = val
- def delete(self, key):
+ def __delitem__(self, key):
if self.__mutid == 0:
raise ValueError(f'mutation {self!r} has been finalized')
@@ -654,6 +648,41 @@ class MapMutation:
self.__root = new_root
self.__count -= 1
+ def __setitem__(self, key, val):
+ if self.__mutid == 0:
+ raise ValueError(f'mutation {self!r} has been finalized')
+
+ self.__root, added = self.__root.assoc(
+ 0, map_hash(key), key, val, self.__mutid)
+
+ if added:
+ self.__count += 1
+
+ def pop(self, key, *args):
+ if self.__mutid == 0:
+ raise ValueError(f'mutation {self!r} has been finalized')
+
+ if len(args) > 1:
+ raise TypeError(
+ 'pop() accepts 1 to 2 positional arguments, '
+ 'got {}'.format(len(args) + 1))
+ elif len(args) == 1:
+ default = args[0]
+ else:
+ default = void
+
+ val = self.get(key, default)
+
+ try:
+ del self[key]
+ except KeyError:
+ if val is void:
+ raise
+ return val
+ else:
+ assert val is not void
+ return val
+
def get(self, key, default=None):
try:
return self.__root.find(0, map_hash(key), key)
diff --git a/tests/test_map.py b/tests/test_map.py
index bbf0a52..4fc87e3 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -992,7 +992,7 @@ class BaseMapTest:
hm2.set('a', 10)
self.assertEqual(hm1, hm2)
- hm2.delete('a')
+ self.assertEqual(hm2.pop('a'), 10)
self.assertNotEqual(hm1, hm2)
def test_map_mut_5(self):
@@ -1099,6 +1099,81 @@ class BaseMapTest:
with self.assertRaises(HashingError):
self.Map(src)
+ def test_map_mut_10(self):
+ key1 = HashKey(123, 'aaa')
+
+ m = self.Map({key1: 123})
+
+ mm = m.mutate()
+ with HashKeyCrasher(error_on_hash=True):
+ with self.assertRaises(HashingError):
+ del mm[key1]
+
+ mm = m.mutate()
+ with HashKeyCrasher(error_on_hash=True):
+ with self.assertRaises(HashingError):
+ mm.pop(key1, None)
+
+ mm = m.mutate()
+ with HashKeyCrasher(error_on_hash=True):
+ with self.assertRaises(HashingError):
+ mm.set(key1, 123)
+
+ def test_map_mut_11(self):
+ m = self.Map({'a': 1, 'b': 2})
+
+ mm = m.mutate()
+ self.assertEqual(mm.pop('a', 1), 1)
+ self.assertEqual(mm.finalize(), self.Map({'b': 2}))
+
+ mm = m.mutate()
+ self.assertEqual(mm.pop('b', 1), 2)
+ self.assertEqual(mm.finalize(), self.Map({'a': 1}))
+
+ mm = m.mutate()
+ self.assertEqual(mm.pop('b', 1), 2)
+ del mm['a']
+ self.assertEqual(mm.finalize(), self.Map())
+
+ def test_map_mut_12(self):
+ m = self.Map({'a': 1, 'b': 2})
+
+ mm = m.mutate()
+ mm.finalize()
+
+ with self.assertRaisesRegex(ValueError, 'has been finalized'):
+ mm.pop('a')
+
+ with self.assertRaisesRegex(ValueError, 'has been finalized'):
+ del mm['a']
+
+ with self.assertRaisesRegex(ValueError, 'has been finalized'):
+ mm.set('a', 'b')
+
+ with self.assertRaisesRegex(ValueError, 'has been finalized'):
+ mm['a'] = 'b'
+
+ def test_map_mut_13(self):
+ key1 = HashKey(123, 'aaa')
+ key2 = HashKey(123, 'aaa')
+
+ m = self.Map({key1: 123})
+
+ mm = m.mutate()
+ with HashKeyCrasher(error_on_eq=True):
+ with self.assertRaises(EqError):
+ del mm[key2]
+
+ mm = m.mutate()
+ with HashKeyCrasher(error_on_eq=True):
+ with self.assertRaises(EqError):
+ mm.pop(key2, None)
+
+ mm = m.mutate()
+ with HashKeyCrasher(error_on_eq=True):
+ with self.assertRaises(EqError):
+ mm.set(key2, 123)
+
def test_map_mut_stress(self):
COLLECTION_SIZE = 7000
TEST_ITERS_EVERY = 647
@@ -1138,7 +1213,7 @@ class BaseMapTest:
break
del d[key]
- hm.delete(key)
+ del hm[key]
self.assertEqual(len(hm), len(d))