From 666613d4a5400fd1ce0267a0bcf9cd70afae7392 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 19 Nov 2018 21:00:42 -0500 Subject: Implement Map.update(); support initializing Map from dict/iter/map --- immutables/_map.c | 431 +++++++++++++++++++++++++++++++++++++++++++++++++++--- immutables/map.py | 57 +++++++- tests/test_map.py | 137 ++++++++++++++--- 3 files changed, 583 insertions(+), 42 deletions(-) diff --git a/immutables/_map.c b/immutables/_map.c index b276d8a..dbe5e01 100644 --- a/immutables/_map.c +++ b/immutables/_map.c @@ -312,7 +312,6 @@ typedef struct { static volatile uint64_t mutid_counter = 1; static MapNode_Bitmap *_empty_bitmap_node; -static MapObject *_empty_map; /* Create a new HAMT immutable mapping. */ @@ -380,6 +379,19 @@ map_node_collision_new(int32_t hash, Py_ssize_t size, uint64_t mutid); static inline Py_ssize_t map_node_collision_count(MapNode_Collision *node); +static int +map_node_update(uint64_t mutid, + PyObject *seq, + MapNode *root, Py_ssize_t count, + MapNode **new_root, Py_ssize_t *new_count); + + +static int +map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src); + +static MapObject * +map_update(uint64_t mutid, MapObject *o, PyObject *src); + #ifdef NDEBUG static void @@ -2570,13 +2582,6 @@ map_alloc(void) static MapObject * map_new(void) { - if (_empty_map != NULL) { - /* HAMT is an immutable object so we can easily cache an - empty instance. 
*/ - Py_INCREF(_empty_map); - return _empty_map; - } - MapObject *o = map_alloc(); if (o == NULL) { return NULL; @@ -2588,11 +2593,6 @@ map_new(void) return NULL; } - if (_empty_map == NULL) { - Py_INCREF(o); - _empty_map = o; - } - return o; } @@ -2787,17 +2787,45 @@ map_dump(MapObject *self); static PyObject * map_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { - if (kwds != NULL && PyDict_Size(kwds) != 0) { - PyErr_SetString(PyExc_TypeError, "__new__ takes no keyword arguments"); - return NULL; + return (PyObject*)map_new(); +} + + +static int +map_tp_init(MapObject *self, PyObject *args, PyObject *kwds) +{ + PyObject *arg = NULL; + uint64_t mutid = 0; + + if (!PyArg_UnpackTuple(args, "immutables.Map", 0, 1, &arg)) { + return -1; } - if (args != NULL && PyTuple_Size(args) != 0) { - PyErr_SetString(PyExc_TypeError, "__new__ takes no positional arguments"); - return NULL; + + if (arg != NULL) { + mutid = mutid_counter++; + if (map_update_inplace(mutid, self, arg)) { + return -1; + } } - return (PyObject*)map_new(); + + if (kwds != NULL) { + if (!PyArg_ValidateKeywordArguments(kwds)) { + return -1; + } + + if (!mutid) { + mutid = mutid_counter++; + } + + if (map_update_inplace(mutid, self, kwds)) { + return -1; + } + } + + return 0; } + static int map_tp_clear(BaseMapObject *self) { @@ -2966,6 +2994,50 @@ map_py_mutate(MapObject *self, PyObject *args) return (PyObject *)o; } +static PyObject * +map_py_update(MapObject *self, PyObject *args, PyObject *kwds) +{ + PyObject *arg = NULL; + MapObject *new = NULL; + uint64_t mutid = 0; + + if (!PyArg_UnpackTuple(args, "update", 0, 1, &arg)) { + return NULL; + } + + if (arg != NULL) { + mutid = mutid_counter++; + new = map_update(mutid, self, arg); + if (new == NULL) { + return NULL; + } + } + else { + Py_INCREF(self); + new = self; + } + + if (kwds != NULL) { + if (!PyArg_ValidateKeywordArguments(kwds)) { + Py_DECREF(new); + return NULL; + } + + if (!mutid) { + mutid = mutid_counter++; + } + + MapObject *new2 
= map_update(mutid, new, kwds); + Py_DECREF(new); + if (new2 == NULL) { + return NULL; + } + new = new2; + } + + return (PyObject *)new; +} + static PyObject * map_py_items(MapObject *self, PyObject *args) { @@ -3155,6 +3227,7 @@ static PyMethodDef Map_methods[] = { {"items", (PyCFunction)map_py_items, METH_NOARGS, NULL}, {"keys", (PyCFunction)map_py_keys, METH_NOARGS, NULL}, {"values", (PyCFunction)map_py_values, METH_NOARGS, NULL}, + {"update", (PyCFunction)map_py_update, METH_VARARGS | METH_KEYWORDS, NULL}, {"__dump__", (PyCFunction)map_py_dump, METH_NOARGS, NULL}, {NULL, NULL} }; @@ -3192,6 +3265,7 @@ PyTypeObject _Map_Type = { .tp_traverse = (traverseproc)map_tp_traverse, .tp_clear = (inquiry)map_tp_clear, .tp_new = map_tp_new, + .tp_init = (initproc)map_tp_init, .tp_weaklistoffset = offsetof(MapObject, h_weakreflist), .tp_hash = (hashfunc)map_py_hash, .tp_repr = (reprfunc)map_py_repr, @@ -3201,6 +3275,322 @@ PyTypeObject _Map_Type = { /////////////////////////////////// MapMutation +static int +map_node_update_from_map(uint64_t mutid, + MapObject *map, + MapNode *root, Py_ssize_t count, + MapNode **new_root, Py_ssize_t *new_count) +{ + assert(Map_Check(map)); + + MapIteratorState iter; + map_iter_t iter_res; + + MapNode *last_root; + Py_ssize_t last_count; + + Py_INCREF(root); + last_root = root; + last_count = count; + + map_iterator_init(&iter, map->h_root); + do { + PyObject *key; + PyObject *val; + int32_t key_hash; + int added_leaf; + + iter_res = map_iterator_next(&iter, &key, &val); + if (iter_res == I_ITEM) { + key_hash = map_hash(key); + if (key_hash == -1) { + goto err; + } + + MapNode *iter_root = map_node_assoc( + last_root, + 0, key_hash, key, val, &added_leaf, + mutid); + + if (iter_root == NULL) { + goto err; + } + + if (added_leaf) { + last_count++; + } + + Py_SETREF(last_root, iter_root); + } + } while (iter_res != I_END); + + *new_root = last_root; + *new_count = last_count; + + return 0; + +err: + Py_DECREF(last_root); + return -1; +} + + 
+static int +map_node_update_from_dict(uint64_t mutid, + PyObject *dct, + MapNode *root, Py_ssize_t count, + MapNode **new_root, Py_ssize_t *new_count) +{ + assert(PyDict_Check(dct)); + + PyObject *it = PyObject_GetIter(dct); + if (it == NULL) { + return -1; + } + + MapNode *last_root; + Py_ssize_t last_count; + + Py_INCREF(root); + last_root = root; + last_count = count; + + PyObject *key; + + while ((key = PyIter_Next(it))) { + PyObject *val; + int added_leaf; + int32_t key_hash; + + key_hash = map_hash(key); + if (key_hash == -1) { + Py_DECREF(key); + goto err; + } + + val = PyDict_GetItemWithError(dct, key); + if (val == NULL) { + Py_DECREF(key); + goto err; + } + + MapNode *iter_root = map_node_assoc( + last_root, + 0, key_hash, key, val, &added_leaf, + mutid); + + Py_DECREF(key); + + if (iter_root == NULL) { + goto err; + } + + if (added_leaf) { + last_count++; + } + + Py_SETREF(last_root, iter_root); + } + + if (key == NULL && PyErr_Occurred()) { + goto err; + } + + Py_DECREF(it); + + *new_root = last_root; + *new_count = last_count; + + return 0; + +err: + Py_DECREF(it); + Py_DECREF(last_root); + return -1; +} + + +static int +map_node_update_from_seq(uint64_t mutid, + PyObject *seq, + MapNode *root, Py_ssize_t count, + MapNode **new_root, Py_ssize_t *new_count) +{ + PyObject *it; + Py_ssize_t i; + PyObject *item = NULL; + PyObject *fast = NULL; + + MapNode *last_root; + Py_ssize_t last_count; + + it = PyObject_GetIter(seq); + if (it == NULL) { + return -1; + } + + Py_INCREF(root); + last_root = root; + last_count = count; + + for (i = 0; ; i++) { + PyObject *key, *val; + Py_ssize_t n; + int32_t key_hash; + int added_leaf; + + item = PyIter_Next(it); + if (item == NULL) { + if (PyErr_Occurred()) { + goto err; + } + break; + } + + fast = PySequence_Fast(item, ""); + if (fast == NULL) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) + PyErr_Format(PyExc_TypeError, + "cannot convert map update " + "sequence element #%zd to a sequence", + i); + goto err; + } + 
+ n = PySequence_Fast_GET_SIZE(fast); + if (n != 2) { + PyErr_Format(PyExc_ValueError, + "map update sequence element #%zd " + "has length %zd; 2 is required", + i, n); + goto err; + } + + key = PySequence_Fast_GET_ITEM(fast, 0); + val = PySequence_Fast_GET_ITEM(fast, 1); + Py_INCREF(key); + Py_INCREF(val); + + key_hash = map_hash(key); + if (key_hash == -1) { + Py_DECREF(key); + Py_DECREF(val); + goto err; + } + + MapNode *iter_root = map_node_assoc( + last_root, + 0, key_hash, key, val, &added_leaf, + mutid); + + Py_DECREF(key); + Py_DECREF(val); + + if (iter_root == NULL) { + goto err; + } + + if (added_leaf) { + last_count++; + } + + Py_SETREF(last_root, iter_root); + + Py_CLEAR(fast); /* err: XDECREFs fast; never leave it stale */ + Py_DECREF(item); + } + + Py_DECREF(it); + + *new_root = last_root; + *new_count = last_count; + + return 0; + +err: + Py_DECREF(last_root); + Py_XDECREF(item); + Py_XDECREF(fast); + Py_DECREF(it); + return -1; +} + + +static int +map_node_update(uint64_t mutid, + PyObject *src, + MapNode *root, Py_ssize_t count, + MapNode **new_root, Py_ssize_t *new_count) +{ + if (Map_Check(src)) { + return map_node_update_from_map( + mutid, (MapObject *)src, root, count, new_root, new_count); + } + else if (PyDict_Check(src)) { + return map_node_update_from_dict( + mutid, src, root, count, new_root, new_count); + } + else { + return map_node_update_from_seq( + mutid, src, root, count, new_root, new_count); + } +} + + +static int +map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src) +{ + MapNode *new_root = NULL; + Py_ssize_t new_count; + + int ret = map_node_update( + mutid, src, + o->h_root, o->h_count, + &new_root, &new_count); + + if (ret) { + return -1; + } + + assert(new_root); + + Py_SETREF(o->h_root, new_root); + o->h_count = new_count; + + return 0; +} + + +static MapObject * +map_update(uint64_t mutid, MapObject *o, PyObject *src) +{ + MapNode *new_root = NULL; + Py_ssize_t new_count; + + int ret = map_node_update( + mutid, src, + o->h_root, o->h_count, + &new_root, 
&new_count); + + if (ret) { + return NULL; + } + + assert(new_root); + + MapObject *new = map_alloc(); + if (new == NULL) { + Py_DECREF(new_root); + return NULL; + } + + Py_XSETREF(new->h_root, new_root); + new->h_count = new_count; + + return new; +} + + static PyObject * mapmut_py_set(MapMutationObject *o, PyObject *args) { @@ -3425,7 +3815,6 @@ PyTypeObject _Map_CollisionNode_Type = { static void module_free(void *m) { - Py_CLEAR(_empty_map); Py_CLEAR(_empty_bitmap_node); } diff --git a/immutables/map.py b/immutables/map.py index 8c7d862..9884956 100644 --- a/immutables/map.py +++ b/immutables/map.py @@ -409,11 +409,16 @@ class GenWrapper: class Map: - def __init__(self): + def __init__(self, col=None, **kw): self.__count = 0 self.__root = BitmapNode(0, 0, [], 0) self.__hash = -1 + if col or kw: + init = self.update(col, **kw) + self.__count = init.__count + self.__root = init.__root + @classmethod def _new(cls, count, root): m = Map.__new__(Map) @@ -443,6 +448,56 @@ class Map: return True + def update(self, col=None, **kw): + it = None + if col is not None: + if hasattr(col, 'items'): + it = iter(col.items()) + else: + it = iter(col) + + if it is not None: + if kw: + it = iter(itertools.chain(it, kw.items())) + else: + if kw: + it = iter(kw.items()) + + if it is None: + + return self + + mutid = _mut_id() + root = self.__root + count = self.__count + + i = 0 + while True: + try: + tup = next(it) + except StopIteration: + break + + try: + tup = tuple(tup) + except TypeError: + raise TypeError( + f'cannot convert map update ' + f'sequence element #{i} to a sequence') from None + key, val, *r = tup + if r: + raise ValueError( + f'map update sequence element #{i} has length ' + f'{len(r) + 2}; 2 is required') + + root, added = root.assoc(0, map_hash(key), key, val, mutid) + if added: + count += 1 + + i += 1 + + return Map._new(count, root) + def mutate(self): return MapMutation(self.__count, self.__root) diff --git a/tests/test_map.py b/tests/test_map.py index 
e2ba2b4..afa2e2e 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -60,7 +60,7 @@ class KeyStr(str): return super().__eq__(other) -class HaskKeyCrasher: +class HashKeyCrasher: def __init__(self, *, error_on_hash=False, error_on_eq=False, error_on_repr=False): @@ -93,13 +93,6 @@ class BaseMapTest: Map = None - def test_init_no_args(self): - with self.assertRaisesRegex(TypeError, 'positional argument'): - self.Map(dict(a=1)) - - with self.assertRaisesRegex(TypeError, 'keyword argument'): - self.Map(a=1) - def test_hashkey_helper_1(self): k1 = HashKey(10, 'aaa') k2 = HashKey(10, 'bbb') @@ -260,14 +253,14 @@ class BaseMapTest: key = KeyStr(i) if not (i % CRASH_HASH_EVERY): - with HaskKeyCrasher(error_on_hash=True): + with HashKeyCrasher(error_on_hash=True): with self.assertRaises(HashingError): h.set(key, i) h = h.set(key, i) if not (i % CRASH_EQ_EVERY): - with HaskKeyCrasher(error_on_eq=True): + with HashKeyCrasher(error_on_eq=True): with self.assertRaises(EqError): h.get(KeyStr(i)) # really trigger __eq__ @@ -289,12 +282,12 @@ class BaseMapTest: key = KeyStr(i) if not (iter_i % CRASH_HASH_EVERY): - with HaskKeyCrasher(error_on_hash=True): + with HashKeyCrasher(error_on_hash=True): with self.assertRaises(HashingError): h.delete(key) if not (iter_i % CRASH_EQ_EVERY): - with HaskKeyCrasher(error_on_eq=True): + with HashKeyCrasher(error_on_eq=True): with self.assertRaises(EqError): h.delete(KeyStr(i)) @@ -807,11 +800,11 @@ class BaseMapTest: self.assertFalse(B in h) with self.assertRaises(EqError): - with HaskKeyCrasher(error_on_eq=True): + with HashKeyCrasher(error_on_eq=True): AA in h with self.assertRaises(HashingError): - with HaskKeyCrasher(error_on_hash=True): + with HashKeyCrasher(error_on_hash=True): AA in h def test_map_getitem_1(self): @@ -830,11 +823,11 @@ class BaseMapTest: h[B] with self.assertRaises(EqError): - with HaskKeyCrasher(error_on_eq=True): + with HashKeyCrasher(error_on_eq=True): h[AA] with self.assertRaises(HashingError): - with 
HaskKeyCrasher(error_on_hash=True): + with HashKeyCrasher(error_on_hash=True): h[AA] def test_repr_1(self): @@ -850,11 +843,11 @@ class BaseMapTest: A = HashKey(100, 'A') with self.assertRaises(ReprError): - with HaskKeyCrasher(error_on_repr=True): + with HashKeyCrasher(error_on_repr=True): repr(h.set(1, 2).set(A, 3).set(3, 4)) with self.assertRaises(ReprError): - with HaskKeyCrasher(error_on_repr=True): + with HashKeyCrasher(error_on_repr=True): repr(h.set(1, 2).set(2, A).set(3, 4)) def test_repr_3(self): @@ -895,12 +888,12 @@ class BaseMapTest: m = h.set(1, 2).set(A, 3).set(3, 4) with self.assertRaises(HashingError): - with HaskKeyCrasher(error_on_hash=True): + with HashKeyCrasher(error_on_hash=True): hash(m) m = h.set(1, 2).set(2, A).set(3, 4) with self.assertRaises(HashingError): - with HaskKeyCrasher(error_on_hash=True): + with HashKeyCrasher(error_on_hash=True): hash(m) def test_abc_1(self): @@ -983,6 +976,110 @@ class BaseMapTest: hm2.delete('a') self.assertNotEqual(hm1, hm2) + def test_map_mut_5(self): + h = self.Map({'a': 1, 'b': 2}, z=100) + self.assertTrue(isinstance(h, self.Map)) + self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100}) + + h2 = h.update(z=200, y=-1) + self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100}) + self.assertEqual(dict(h2.items()), {'a': 1, 'b': 2, 'z': 200, 'y': -1}) + + h3 = h2.update([(1, 2), (3, 4)]) + self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100}) + self.assertEqual(dict(h2.items()), {'a': 1, 'b': 2, 'z': 200, 'y': -1}) + self.assertEqual(dict(h3.items()), + {'a': 1, 'b': 2, 'z': 200, 'y': -1, 1: 2, 3: 4}) + + h4 = h3.update() + self.assertIs(h4, h3) + + h5 = h4.update(self.Map({'zzz': 'yyz'})) + + self.assertEqual(dict(h5.items()), + {'a': 1, 'b': 2, 'z': 200, 'y': -1, 1: 2, 3: 4, + 'zzz': 'yyz'}) + + def test_map_mut_6(self): + h = self.Map({'a': 1, 'b': 2}, z=100) + self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100}) + + with self.assertRaisesRegex(TypeError, 'not iterable'): + 
h.update(1) + + with self.assertRaisesRegex(ValueError, 'map update sequence element'): + h.update([(1, 2), (3, 4, 5)]) + + with self.assertRaisesRegex(TypeError, 'cannot convert map update'): + h.update([(1, 2), 1]) + + self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100}) + + def test_map_mut_7(self): + key = HashKey(123, 'aaa') + + h = self.Map({'a': 1, 'b': 2}, z=100) + self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100}) + + upd = {key: 1} + with HashKeyCrasher(error_on_hash=True): + with self.assertRaises(HashingError): + h.update(upd) + + upd = self.Map({key: 'zzz'}) + with HashKeyCrasher(error_on_hash=True): + with self.assertRaises(HashingError): + h.update(upd) + + upd = [(1, 2), (key, 'zzz')] + with HashKeyCrasher(error_on_hash=True): + with self.assertRaises(HashingError): + h.update(upd) + + self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100}) + + def test_map_mut_8(self): + key1 = HashKey(123, 'aaa') + key2 = HashKey(123, 'bbb') + + h = self.Map({key1: 123}) + self.assertEqual(dict(h.items()), {key1: 123}) + + upd = {key2: 1} + with HashKeyCrasher(error_on_eq=True): + with self.assertRaises(EqError): + h.update(upd) + + upd = self.Map({key2: 'zzz'}) + with HashKeyCrasher(error_on_eq=True): + with self.assertRaises(EqError): + h.update(upd) + + upd = [(1, 2), (key2, 'zzz')] + with HashKeyCrasher(error_on_eq=True): + with self.assertRaises(EqError): + h.update(upd) + + self.assertEqual(dict(h.items()), {key1: 123}) + + def test_map_mut_9(self): + key1 = HashKey(123, 'aaa') + + src = {key1: 123} + with HashKeyCrasher(error_on_hash=True): + with self.assertRaises(HashingError): + self.Map(src) + + src = self.Map({key1: 123}) + with HashKeyCrasher(error_on_hash=True): + with self.assertRaises(HashingError): + self.Map(src) + + src = [(1, 2), (key1, 123)] + with HashKeyCrasher(error_on_hash=True): + with self.assertRaises(HashingError): + self.Map(src) + def test_map_mut_stress(self): COLLECTION_SIZE = 7000 TEST_ITERS_EVERY = 647 
-- cgit v1.2.3