aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--immutables/_map.c66
-rw-r--r--immutables/map.py61
-rw-r--r--tests/test_map.py46
3 files changed, 160 insertions, 13 deletions
diff --git a/immutables/_map.c b/immutables/_map.c
index e39e85f..9a81556 100644
--- a/immutables/_map.c
+++ b/immutables/_map.c
@@ -387,7 +387,7 @@ map_node_update(uint64_t mutid,
static int
-map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src);
+map_update_inplace(uint64_t mutid, BaseMapObject *o, PyObject *src);
static MapObject *
map_update(uint64_t mutid, MapObject *o, PyObject *src);
@@ -2153,6 +2153,8 @@ map_node_assoc(MapNode *node,
map_node_{nodetype}_assoc method.
*/
+ *added_leaf = 0;
+
if (IS_BITMAP_NODE(node)) {
return map_node_bitmap_assoc(
(MapNode_Bitmap *)node,
@@ -2892,10 +2894,27 @@ map_tp_init(MapObject *self, PyObject *args, PyObject *kwds)
}
if (arg != NULL) {
- mutid = mutid_counter++;
- if (map_update_inplace(mutid, self, arg)) {
+ if (Map_Check(arg)) {
+ MapObject *other = (MapObject *)arg;
+
+ Py_INCREF(other->h_root);
+ Py_SETREF(self->h_root, other->h_root);
+
+ self->h_count = other->h_count;
+ self->h_hash = other->h_hash;
+ }
+ else if (MapMutation_Check(arg)) {
+ PyErr_Format(
+ PyExc_TypeError,
+ "cannot create Maps from MapMutations");
return -1;
}
+ else {
+ mutid = mutid_counter++;
+ if (map_update_inplace(mutid, (BaseMapObject *)self, arg)) {
+ return -1;
+ }
+ }
}
if (kwds != NULL) {
@@ -2907,7 +2926,7 @@ map_tp_init(MapObject *self, PyObject *args, PyObject *kwds)
mutid = mutid_counter++;
}
- if (map_update_inplace(mutid, self, kwds)) {
+ if (map_update_inplace(mutid, (BaseMapObject *)self, kwds)) {
return -1;
}
}
@@ -3665,14 +3684,14 @@ map_node_update(uint64_t mutid,
static int
-map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src)
+map_update_inplace(uint64_t mutid, BaseMapObject *o, PyObject *src)
{
MapNode *new_root = NULL;
Py_ssize_t new_count;
int ret = map_node_update(
mutid, src,
- o->h_root, o->h_count,
+ o->b_root, o->b_count,
&new_root, &new_count);
if (ret) {
@@ -3681,8 +3700,8 @@ map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src)
assert(new_root);
- Py_SETREF(o->h_root, new_root);
- o->h_count = new_count;
+ Py_SETREF(o->b_root, new_root);
+ o->b_count = new_count;
return 0;
}
@@ -3853,6 +3872,35 @@ mapmut_tp_richcompare(PyObject *v, PyObject *w, int op)
}
static PyObject *
+mapmut_py_update(MapMutationObject *self, PyObject *args, PyObject *kwds)
+{
+ PyObject *arg = NULL;
+
+ if (!PyArg_UnpackTuple(args, "update", 0, 1, &arg)) {
+ return NULL;
+ }
+
+ if (arg != NULL) {
+ if (map_update_inplace(self->m_mutid, (BaseMapObject *)self, arg)) {
+ return NULL;
+ }
+ }
+
+ if (kwds != NULL) {
+ if (!PyArg_ValidateKeywordArguments(kwds)) {
+ return NULL;
+ }
+
+ if (map_update_inplace(self->m_mutid, (BaseMapObject *)self, kwds)) {
+ return NULL;
+ }
+ }
+
+ Py_RETURN_NONE;
+}
+
+
+static PyObject *
mapmut_py_finalize(MapMutationObject *self, PyObject *args)
{
self->m_mutid = 0;
@@ -3970,6 +4018,8 @@ static PyMethodDef MapMutation_methods[] = {
{"get", (PyCFunction)map_py_get, METH_VARARGS, NULL},
{"pop", (PyCFunction)mapmut_py_pop, METH_VARARGS, NULL},
{"finish", (PyCFunction)mapmut_py_finalize, METH_NOARGS, NULL},
+ {"update", (PyCFunction)mapmut_py_update,
+ METH_VARARGS | METH_KEYWORDS, NULL},
{"__enter__", (PyCFunction)mapmut_py_enter, METH_NOARGS, NULL},
{"__exit__", (PyCFunction)mapmut_py_exit, METH_VARARGS, NULL},
{NULL, NULL}
diff --git a/immutables/map.py b/immutables/map.py
index fc40ac8..7b230dd 100644
--- a/immutables/map.py
+++ b/immutables/map.py
@@ -438,6 +438,14 @@ class Map:
self.__root = BitmapNode(0, 0, [], 0)
self.__hash = -1
+ if isinstance(col, Map):
+ self.__count = col.__count
+ self.__root = col.__root
+ self.__hash = col.__hash
+ col = None
+ elif isinstance(col, MapMutation):
+ raise TypeError('cannot create Maps from MapMutations')
+
if col or kw:
init = self.update(col, **kw)
self.__count = init.__count
@@ -640,6 +648,9 @@ class MapMutation:
self.finish()
return False
+ def __iter__(self):
+ raise TypeError('{} is not iterable'.format(type(self)))
+
def __delitem__(self, key):
if self.__mutid == 0:
raise ValueError('mutation {!r} has been finished'.format(self))
@@ -707,6 +718,56 @@ class MapMutation:
else:
return True
+ def update(self, col=None, **kw):
+ it = None
+ if col is not None:
+ if hasattr(col, 'items'):
+ it = iter(col.items())
+ else:
+ it = iter(col)
+
+ if it is not None:
+ if kw:
+ it = iter(itertools.chain(it, kw.items()))
+ else:
+ if kw:
+ it = iter(kw.items())
+
+ if it is None:
+
+ return self
+
+ root = self.__root
+ count = self.__count
+
+ i = 0
+ while True:
+ try:
+ tup = next(it)
+ except StopIteration:
+ break
+
+ try:
+ tup = tuple(tup)
+ except TypeError:
+ raise TypeError(
+ 'cannot convert map update '
+ 'sequence element #{} to a sequence'.format(i)) from None
+ key, val, *r = tup
+ if r:
+ raise ValueError(
+ 'map update sequence element #{} has length '
+ '{}; 2 is required'.format(i, len(r) + 2))
+
+ root, added = root.assoc(0, map_hash(key), key, val, self.__mutid)
+ if added:
+ count += 1
+
+ i += 1
+
+ self.__root = root
+ self.__count = count
+
def finish(self):
self.__mutid = 0
return Map._new(self.__count, self.__root)
diff --git a/tests/test_map.py b/tests/test_map.py
index 959e00c..66d07e7 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -1089,11 +1089,6 @@ class BaseMapTest:
with self.assertRaises(HashingError):
self.Map(src)
- src = self.Map({key1: 123})
- with HashKeyCrasher(error_on_hash=True):
- with self.assertRaises(HashingError):
- self.Map(src)
-
src = [(1, 2), (key1, 123)]
with HashKeyCrasher(error_on_hash=True):
with self.assertRaises(HashingError):
@@ -1195,6 +1190,47 @@ class BaseMapTest:
self.assertEqual(mm.finish(), self.Map(z=100, b=2))
self.assertEqual(m, self.Map(a=1, b=2))
+ def test_map_mut_16(self):
+ m = self.Map(a=1, b=2)
+ hash(m)
+
+ m2 = self.Map(m)
+ m3 = self.Map(m, c=3)
+
+ self.assertEqual(m, m2)
+ self.assertEqual(len(m), len(m2))
+ self.assertEqual(hash(m), hash(m2))
+
+ self.assertIsNot(m, m2)
+ self.assertEqual(m3, self.Map(a=1, b=2, c=3))
+
+ def test_map_mut_17(self):
+ m = self.Map(a=1)
+ with m.mutate() as mm:
+ with self.assertRaisesRegex(
+ TypeError, 'cannot create Maps from MapMutations'):
+ self.Map(mm)
+
+ def test_map_mut_18(self):
+ m = self.Map(a=1, b=2)
+ with m.mutate() as mm:
+ mm.update(self.Map(x=1), z=2)
+ mm.update(c=3)
+ mm.update({'n': 100, 'a': 20})
+ m2 = mm.finish()
+
+ expected = self.Map(
+ {'b': 2, 'c': 3, 'n': 100, 'z': 2, 'x': 1, 'a': 20})
+
+ self.assertEqual(len(m2), 6)
+ self.assertEqual(m2, expected)
+ self.assertEqual(m, self.Map(a=1, b=2))
+
+ def test_map_mut_19(self):
+ m = self.Map(a=1, b=2)
+ m2 = m.update({'a': 20})
+ self.assertEqual(len(m2), 2)
+
def test_map_mut_stress(self):
COLLECTION_SIZE = 7000
TEST_ITERS_EVERY = 647