diff options
-rw-r--r-- | immutables/_map.c | 66 | ||||
-rw-r--r-- | immutables/map.py | 61 | ||||
-rw-r--r-- | tests/test_map.py | 46 |
3 files changed, 160 insertions, 13 deletions
diff --git a/immutables/_map.c b/immutables/_map.c index e39e85f..9a81556 100644 --- a/immutables/_map.c +++ b/immutables/_map.c @@ -387,7 +387,7 @@ map_node_update(uint64_t mutid, static int -map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src); +map_update_inplace(uint64_t mutid, BaseMapObject *o, PyObject *src); static MapObject * map_update(uint64_t mutid, MapObject *o, PyObject *src); @@ -2153,6 +2153,8 @@ map_node_assoc(MapNode *node, map_node_{nodetype}_assoc method. */ + *added_leaf = 0; + if (IS_BITMAP_NODE(node)) { return map_node_bitmap_assoc( (MapNode_Bitmap *)node, @@ -2892,10 +2894,27 @@ map_tp_init(MapObject *self, PyObject *args, PyObject *kwds) } if (arg != NULL) { - mutid = mutid_counter++; - if (map_update_inplace(mutid, self, arg)) { + if (Map_Check(arg)) { + MapObject *other = (MapObject *)arg; + + Py_INCREF(other->h_root); + Py_SETREF(self->h_root, other->h_root); + + self->h_count = other->h_count; + self->h_hash = other->h_hash; + } + else if (MapMutation_Check(arg)) { + PyErr_Format( + PyExc_TypeError, + "cannot create Maps from MapMutations"); return -1; } + else { + mutid = mutid_counter++; + if (map_update_inplace(mutid, (BaseMapObject *)self, arg)) { + return -1; + } + } } if (kwds != NULL) { @@ -2907,7 +2926,7 @@ map_tp_init(MapObject *self, PyObject *args, PyObject *kwds) mutid = mutid_counter++; } - if (map_update_inplace(mutid, self, kwds)) { + if (map_update_inplace(mutid, (BaseMapObject *)self, kwds)) { return -1; } } @@ -3665,14 +3684,14 @@ map_node_update(uint64_t mutid, static int -map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src) +map_update_inplace(uint64_t mutid, BaseMapObject *o, PyObject *src) { MapNode *new_root = NULL; Py_ssize_t new_count; int ret = map_node_update( mutid, src, - o->h_root, o->h_count, + o->b_root, o->b_count, &new_root, &new_count); if (ret) { @@ -3681,8 +3700,8 @@ map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src) assert(new_root); - Py_SETREF(o->h_root, new_root); - o->h_count = new_count; + Py_SETREF(o->b_root, new_root); + o->b_count = new_count; return 0; } @@ -3853,6 +3872,35 @@ mapmut_tp_richcompare(PyObject *v, PyObject *w, int op) } static PyObject * +mapmut_py_update(MapMutationObject *self, PyObject *args, PyObject *kwds) +{ + PyObject *arg = NULL; + + if (!PyArg_UnpackTuple(args, "update", 0, 1, &arg)) { + return NULL; + } + + if (arg != NULL) { + if (map_update_inplace(self->m_mutid, (BaseMapObject *)self, arg)) { + return NULL; + } + } + + if (kwds != NULL) { + if (!PyArg_ValidateKeywordArguments(kwds)) { + return NULL; + } + + if (map_update_inplace(self->m_mutid, (BaseMapObject *)self, kwds)) { + return NULL; + } + } + + Py_RETURN_NONE; +} + + +static PyObject * mapmut_py_finalize(MapMutationObject *self, PyObject *args) { self->m_mutid = 0; @@ -3970,6 +4018,8 @@ static PyMethodDef MapMutation_methods[] = { {"get", (PyCFunction)map_py_get, METH_VARARGS, NULL}, {"pop", (PyCFunction)mapmut_py_pop, METH_VARARGS, NULL}, {"finish", (PyCFunction)mapmut_py_finalize, METH_NOARGS, NULL}, + {"update", (PyCFunction)mapmut_py_update, + METH_VARARGS | METH_KEYWORDS, NULL}, {"__enter__", (PyCFunction)mapmut_py_enter, METH_NOARGS, NULL}, {"__exit__", (PyCFunction)mapmut_py_exit, METH_VARARGS, NULL}, {NULL, NULL} diff --git a/immutables/map.py b/immutables/map.py index fc40ac8..7b230dd 100644 --- a/immutables/map.py +++ b/immutables/map.py @@ -438,6 +438,14 @@ class Map: self.__root = BitmapNode(0, 0, [], 0) self.__hash = -1 + if isinstance(col, Map): + self.__count = col.__count + self.__root = col.__root + self.__hash = col.__hash + col = None + elif isinstance(col, MapMutation): + raise TypeError('cannot create Maps from MapMutations') + if col or kw: init = self.update(col, **kw) self.__count = init.__count @@ -640,6 +648,9 @@ class MapMutation: self.finish() return False + def __iter__(self): + raise TypeError('{} is not iterable'.format(type(self))) + def __delitem__(self, key): if self.__mutid == 0: raise ValueError('mutation {!r} has been finished'.format(self)) @@ -707,6 +718,56 @@ class MapMutation: else: return True + def update(self, col=None, **kw): + it = None + if col is not None: + if hasattr(col, 'items'): + it = iter(col.items()) + else: + it = iter(col) + + if it is not None: + if kw: + it = iter(itertools.chain(it, kw.items())) + else: + if kw: + it = iter(kw.items()) + + if it is None: + + return self + + root = self.__root + count = self.__count + + i = 0 + while True: + try: + tup = next(it) + except StopIteration: + break + + try: + tup = tuple(tup) + except TypeError: + raise TypeError( + 'cannot convert map update ' + 'sequence element #{} to a sequence'.format(i)) from None + key, val, *r = tup + if r: + raise ValueError( + 'map update sequence element #{} has length ' + '{}; 2 is required'.format(i, len(r) + 2)) + + root, added = root.assoc(0, map_hash(key), key, val, self.__mutid) + if added: + count += 1 + + i += 1 + + self.__root = root + self.__count = count + def finish(self): self.__mutid = 0 return Map._new(self.__count, self.__root) diff --git a/tests/test_map.py b/tests/test_map.py index 959e00c..66d07e7 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -1089,11 +1089,6 @@ class BaseMapTest: with self.assertRaises(HashingError): self.Map(src) - src = self.Map({key1: 123}) - with HashKeyCrasher(error_on_hash=True): - with self.assertRaises(HashingError): - self.Map(src) - src = [(1, 2), (key1, 123)] with HashKeyCrasher(error_on_hash=True): with self.assertRaises(HashingError): @@ -1195,6 +1190,47 @@ class BaseMapTest: self.assertEqual(mm.finish(), self.Map(z=100, b=2)) self.assertEqual(m, self.Map(a=1, b=2)) + def test_map_mut_16(self): + m = self.Map(a=1, b=2) + hash(m) + + m2 = self.Map(m) + m3 = self.Map(m, c=3) + + self.assertEqual(m, m2) + self.assertEqual(len(m), len(m2)) + self.assertEqual(hash(m), hash(m2)) + + self.assertIsNot(m, m2) + self.assertEqual(m3, self.Map(a=1, b=2, c=3)) + + def test_map_mut_17(self): + m = self.Map(a=1) + with m.mutate() as mm: + with self.assertRaisesRegex( + TypeError, 'cannot create Maps from MapMutations'): + self.Map(mm) + + def test_map_mut_18(self): + m = self.Map(a=1, b=2) + with m.mutate() as mm: + mm.update(self.Map(x=1), z=2) + mm.update(c=3) + mm.update({'n': 100, 'a': 20}) + m2 = mm.finish() + + expected = self.Map( + {'b': 2, 'c': 3, 'n': 100, 'z': 2, 'x': 1, 'a': 20}) + + self.assertEqual(len(m2), 6) + self.assertEqual(m2, expected) + self.assertEqual(m, self.Map(a=1, b=2)) + + def test_map_mut_19(self): + m = self.Map(a=1, b=2) + m2 = m.update({'a': 20}) + self.assertEqual(len(m2), 2) + def test_map_mut_stress(self): COLLECTION_SIZE = 7000 TEST_ITERS_EVERY = 647 |