about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Yury Selivanov <yury@magic.io> 2018-11-19 21:00:42 -0500
committer Yury Selivanov <yury@magic.io> 2018-11-20 14:25:07 -0500
commit 666613d4a5400fd1ce0267a0bcf9cd70afae7392 (patch)
tree 199eeb041f7b17152885dab6715a6c940dcf452a
parent 309f2991557673c67c3d8fae995c8b23cc0d4d7c (diff)
download immutables-666613d4a5400fd1ce0267a0bcf9cd70afae7392.tar.gz
immutables-666613d4a5400fd1ce0267a0bcf9cd70afae7392.zip
Implement Map.update(); support initializing Map from dict/iter/map
-rw-r--r-- immutables/_map.c 431
-rw-r--r-- immutables/map.py 57
-rw-r--r-- tests/test_map.py 137
3 files changed, 583 insertions, 42 deletions
diff --git a/immutables/_map.c b/immutables/_map.c
index b276d8a..dbe5e01 100644
--- a/immutables/_map.c
+++ b/immutables/_map.c
@@ -312,7 +312,6 @@ typedef struct {
static volatile uint64_t mutid_counter = 1;
static MapNode_Bitmap *_empty_bitmap_node;
-static MapObject *_empty_map;
/* Create a new HAMT immutable mapping. */
@@ -380,6 +379,19 @@ map_node_collision_new(int32_t hash, Py_ssize_t size, uint64_t mutid);
static inline Py_ssize_t
map_node_collision_count(MapNode_Collision *node);
+static int
+map_node_update(uint64_t mutid,
+ PyObject *seq,
+ MapNode *root, Py_ssize_t count,
+ MapNode **new_root, Py_ssize_t *new_count);
+
+
+static int
+map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src);
+
+static MapObject *
+map_update(uint64_t mutid, MapObject *o, PyObject *src);
+
#ifdef NDEBUG
static void
@@ -2570,13 +2582,6 @@ map_alloc(void)
static MapObject *
map_new(void)
{
- if (_empty_map != NULL) {
- /* HAMT is an immutable object so we can easily cache an
- empty instance. */
- Py_INCREF(_empty_map);
- return _empty_map;
- }
-
MapObject *o = map_alloc();
if (o == NULL) {
return NULL;
@@ -2588,11 +2593,6 @@ map_new(void)
return NULL;
}
- if (_empty_map == NULL) {
- Py_INCREF(o);
- _empty_map = o;
- }
-
return o;
}
@@ -2787,17 +2787,45 @@ map_dump(MapObject *self);
static PyObject *
map_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
- if (kwds != NULL && PyDict_Size(kwds) != 0) {
- PyErr_SetString(PyExc_TypeError, "__new__ takes no keyword arguments");
- return NULL;
+ return (PyObject*)map_new();
+}
+
+
+/* tp_init slot for Map: populate `self` from an optional positional
+ * source object and/or keyword arguments.
+ *
+ * At most one positional argument is accepted (PyArg_UnpackTuple with
+ * min=0, max=1).  The positional source is merged first, then `kwds`,
+ * both through map_update_inplace() and both under the same mutation id.
+ *
+ * Returns 0 on success and -1 on failure.
+ *
+ * NOTE(review): map_update_inplace() replaces self->h_root, so calling
+ * __init__ again on an existing Map would change it in place -- confirm
+ * whether that needs guarding.
+ */
+static int
+map_tp_init(MapObject *self, PyObject *args, PyObject *kwds)
+{
+ PyObject *arg = NULL;
+ /* 0 means "no mutation id allocated yet". */
+ uint64_t mutid = 0;
+
+ if (!PyArg_UnpackTuple(args, "immutables.Map", 0, 1, &arg)) {
+ return -1;
}
- if (args != NULL && PyTuple_Size(args) != 0) {
- PyErr_SetString(PyExc_TypeError, "__new__ takes no positional arguments");
- return NULL;
+
+ if (arg != NULL) {
+ mutid = mutid_counter++;
+ if (map_update_inplace(mutid, self, arg)) {
+ return -1;
+ }
}
- return (PyObject*)map_new();
+
+ if (kwds != NULL) {
+ if (!PyArg_ValidateKeywordArguments(kwds)) {
+ return -1;
+ }
+
+ /* Reuse the id taken for the positional source, if there was one. */
+ if (!mutid) {
+ mutid = mutid_counter++;
+ }
+
+ if (map_update_inplace(mutid, self, kwds)) {
+ return -1;
+ }
+ }
+
+ return 0;
}
+
static int
map_tp_clear(BaseMapObject *self)
{
@@ -2967,6 +2995,50 @@ map_py_mutate(MapObject *self, PyObject *args)
}
+/* Map.update([src], **kwds): return a Map with the pairs from the optional
+ * positional source and from the keyword arguments applied on top of
+ * `self`.  `self` is not modified; with no positional source and no
+ * `kwds`, a new reference to `self` itself is returned.
+ *
+ * Returns NULL on failure.
+ */
static PyObject *
+map_py_update(MapObject *self, PyObject *args, PyObject *kwds)
+{
+    PyObject *src = NULL;
+    MapObject *result;
+    uint64_t mutid = 0;  /* 0 == no mutation id allocated yet */
+
+    if (!PyArg_UnpackTuple(args, "update", 0, 1, &src)) {
+        return NULL;
+    }
+
+    if (src == NULL) {
+        /* Nothing positional to merge: start from `self`. */
+        Py_INCREF(self);
+        result = self;
+    }
+    else {
+        mutid = mutid_counter++;
+        result = map_update(mutid, self, src);
+        if (result == NULL) {
+            return NULL;
+        }
+    }
+
+    if (kwds == NULL) {
+        return (PyObject *)result;
+    }
+
+    if (!PyArg_ValidateKeywordArguments(kwds)) {
+        Py_DECREF(result);
+        return NULL;
+    }
+
+    if (mutid == 0) {
+        mutid = mutid_counter++;
+    }
+
+    /* The intermediate map is dropped whether or not the merge succeeds;
+       on failure `merged` is NULL and is returned as the error marker. */
+    MapObject *merged = map_update(mutid, result, kwds);
+    Py_DECREF(result);
+    return (PyObject *)merged;
+}
+
+static PyObject *
map_py_items(MapObject *self, PyObject *args)
{
return map_new_iteritems(self);
@@ -3155,6 +3227,7 @@ static PyMethodDef Map_methods[] = {
{"items", (PyCFunction)map_py_items, METH_NOARGS, NULL},
{"keys", (PyCFunction)map_py_keys, METH_NOARGS, NULL},
{"values", (PyCFunction)map_py_values, METH_NOARGS, NULL},
+ {"update", (PyCFunction)map_py_update, METH_VARARGS | METH_KEYWORDS, NULL},
{"__dump__", (PyCFunction)map_py_dump, METH_NOARGS, NULL},
{NULL, NULL}
};
@@ -3192,6 +3265,7 @@ PyTypeObject _Map_Type = {
.tp_traverse = (traverseproc)map_tp_traverse,
.tp_clear = (inquiry)map_tp_clear,
.tp_new = map_tp_new,
+ .tp_init = (initproc)map_tp_init,
.tp_weaklistoffset = offsetof(MapObject, h_weakreflist),
.tp_hash = (hashfunc)map_py_hash,
.tp_repr = (reprfunc)map_py_repr,
@@ -3201,6 +3275,322 @@ PyTypeObject _Map_Type = {
/////////////////////////////////// MapMutation
+/* Associate every (key, value) pair of `map` into the tree rooted at
+ * `root`, starting from an item count of `count`.
+ *
+ * On success returns 0 and stores the resulting root (a reference owned
+ * by the caller) in *new_root and the resulting item count in *new_count.
+ * On failure returns -1 and leaves both outputs untouched.
+ */
+static int
+map_node_update_from_map(uint64_t mutid,
+                         MapObject *map,
+                         MapNode *root, Py_ssize_t count,
+                         MapNode **new_root, Py_ssize_t *new_count)
+{
+    assert(Map_Check(map));
+
+    MapIteratorState it_state;
+    MapNode *cur_root = root;
+    Py_ssize_t cur_count = count;
+
+    /* cur_root always holds a reference owned by this function. */
+    Py_INCREF(cur_root);
+
+    map_iterator_init(&it_state, map->h_root);
+    for (;;) {
+        PyObject *key;
+        PyObject *val;
+        int added_leaf;
+
+        map_iter_t status = map_iterator_next(&it_state, &key, &val);
+        if (status == I_END) {
+            break;
+        }
+        if (status != I_ITEM) {
+            continue;
+        }
+
+        int32_t key_hash = map_hash(key);
+        if (key_hash == -1) {
+            goto err;
+        }
+
+        MapNode *next_root = map_node_assoc(
+            cur_root,
+            0, key_hash, key, val, &added_leaf,
+            mutid);
+
+        if (next_root == NULL) {
+            goto err;
+        }
+
+        if (added_leaf) {
+            cur_count++;
+        }
+
+        Py_SETREF(cur_root, next_root);
+    }
+
+    *new_root = cur_root;
+    *new_count = cur_count;
+
+    return 0;
+
+err:
+    Py_DECREF(cur_root);
+    return -1;
+}
+
+
+/* Associate every (key, value) pair of the dict `dct` into the tree
+ * rooted at `root`, starting from an item count of `count`.
+ *
+ * On success returns 0 and stores the resulting root (a reference owned
+ * by the caller) in *new_root and the resulting item count in *new_count.
+ * On failure returns -1 with an exception set and leaves both outputs
+ * untouched.
+ */
+static int
+map_node_update_from_dict(uint64_t mutid,
+                          PyObject *dct,
+                          MapNode *root, Py_ssize_t count,
+                          MapNode **new_root, Py_ssize_t *new_count)
+{
+    assert(PyDict_Check(dct));
+
+    PyObject *it = PyObject_GetIter(dct);
+    if (it == NULL) {
+        return -1;
+    }
+
+    MapNode *last_root;
+    Py_ssize_t last_count;
+
+    Py_INCREF(root);
+    last_root = root;
+    last_count = count;
+
+    PyObject *key;
+
+    while ((key = PyIter_Next(it))) {
+        PyObject *val;
+        int added_leaf;
+        int32_t key_hash;
+
+        key_hash = map_hash(key);
+        if (key_hash == -1) {
+            Py_DECREF(key);
+            goto err;
+        }
+
+        val = PyDict_GetItemWithError(dct, key);
+        if (val == NULL) {
+            /* PyDict_GetItemWithError() returns NULL *without* setting an
+               exception when the key is simply absent, which can happen
+               if hashing the key mutated the dict.  Never report failure
+               without an exception set. */
+            if (!PyErr_Occurred()) {
+                PyErr_SetString(PyExc_RuntimeError,
+                                "dict mutated during update");
+            }
+            Py_DECREF(key);
+            goto err;
+        }
+
+        /* `val` is borrowed from the dict; hold a reference while
+           map_node_assoc() runs, as map_node_update_from_seq() does. */
+        Py_INCREF(val);
+
+        MapNode *iter_root = map_node_assoc(
+            last_root,
+            0, key_hash, key, val, &added_leaf,
+            mutid);
+
+        Py_DECREF(val);
+        Py_DECREF(key);
+
+        if (iter_root == NULL) {
+            goto err;
+        }
+
+        if (added_leaf) {
+            last_count++;
+        }
+
+        Py_SETREF(last_root, iter_root);
+    }
+
+    /* PyIter_Next() returns NULL both on exhaustion and on error. */
+    if (PyErr_Occurred()) {
+        goto err;
+    }
+
+    Py_DECREF(it);
+
+    *new_root = last_root;
+    *new_count = last_count;
+
+    return 0;
+
+err:
+    Py_DECREF(it);
+    Py_DECREF(last_root);
+    return -1;
+}
+
+
+/* Associate the pairs produced by iterating `seq` into the tree rooted at
+ * `root`, starting from an item count of `count`.  Every element of `seq`
+ * must itself be convertible to a sequence of exactly two items
+ * (key, value).
+ *
+ * On success returns 0 and stores the resulting root (a reference owned
+ * by the caller) in *new_root and the resulting item count in *new_count.
+ * On failure returns -1 and leaves both outputs untouched; a malformed
+ * element raises TypeError (not a sequence) or ValueError (wrong length)
+ * naming the element index.
+ */
+static int
+map_node_update_from_seq(uint64_t mutid,
+                         PyObject *seq,
+                         MapNode *root, Py_ssize_t count,
+                         MapNode **new_root, Py_ssize_t *new_count)
+{
+    PyObject *it;
+    Py_ssize_t i;
+    PyObject *item = NULL;
+    PyObject *fast = NULL;
+
+    MapNode *last_root;
+    Py_ssize_t last_count;
+
+    it = PyObject_GetIter(seq);
+    if (it == NULL) {
+        return -1;
+    }
+
+    Py_INCREF(root);
+    last_root = root;
+    last_count = count;
+
+    for (i = 0; ; i++) {
+        PyObject *key, *val;
+        Py_ssize_t n;
+        int32_t key_hash;
+        int added_leaf;
+
+        item = PyIter_Next(it);
+        if (item == NULL) {
+            if (PyErr_Occurred()) {
+                goto err;
+            }
+            break;
+        }
+
+        fast = PySequence_Fast(item, "");
+        if (fast == NULL) {
+            if (PyErr_ExceptionMatches(PyExc_TypeError)) {
+                PyErr_Format(PyExc_TypeError,
+                             "cannot convert map update "
+                             "sequence element #%zd to a sequence",
+                             i);
+            }
+            goto err;
+        }
+
+        n = PySequence_Fast_GET_SIZE(fast);
+        if (n != 2) {
+            PyErr_Format(PyExc_ValueError,
+                         "map update sequence element #%zd "
+                         "has length %zd; 2 is required",
+                         i, n);
+            goto err;
+        }
+
+        /* GET_ITEM returns borrowed references; own them while user code
+           (__hash__/__eq__) may run. */
+        key = PySequence_Fast_GET_ITEM(fast, 0);
+        val = PySequence_Fast_GET_ITEM(fast, 1);
+        Py_INCREF(key);
+        Py_INCREF(val);
+
+        key_hash = map_hash(key);
+        if (key_hash == -1) {
+            Py_DECREF(key);
+            Py_DECREF(val);
+            goto err;
+        }
+
+        MapNode *iter_root = map_node_assoc(
+            last_root,
+            0, key_hash, key, val, &added_leaf,
+            mutid);
+
+        Py_DECREF(key);
+        Py_DECREF(val);
+
+        if (iter_root == NULL) {
+            goto err;
+        }
+
+        if (added_leaf) {
+            last_count++;
+        }
+
+        Py_SETREF(last_root, iter_root);
+
+        /* Release AND reset to NULL: if PyIter_Next() fails on the next
+           round, the `err` path must not decref these stale pointers a
+           second time. */
+        Py_CLEAR(fast);
+        Py_CLEAR(item);
+    }
+
+    Py_DECREF(it);
+
+    *new_root = last_root;
+    *new_count = last_count;
+
+    return 0;
+
+err:
+    Py_DECREF(last_root);
+    Py_XDECREF(item);
+    Py_XDECREF(fast);
+    Py_DECREF(it);
+    return -1;
+}
+
+
+/* Merge `src` into the tree rooted at `root`, dispatching on the type of
+ * `src`: another Map, a dict, or (the fallback) any iterable of pairs.
+ *
+ * Returns whatever the selected helper returns (0 on success, -1 on
+ * failure); outputs go to *new_root and *new_count.
+ */
+static int
+map_node_update(uint64_t mutid,
+                PyObject *src,
+                MapNode *root, Py_ssize_t count,
+                MapNode **new_root, Py_ssize_t *new_count)
+{
+    if (Map_Check(src)) {
+        return map_node_update_from_map(
+            mutid, (MapObject *)src, root, count, new_root, new_count);
+    }
+
+    if (PyDict_Check(src)) {
+        return map_node_update_from_dict(
+            mutid, src, root, count, new_root, new_count);
+    }
+
+    /* Everything else is treated as an iterable of (key, value) pairs. */
+    return map_node_update_from_seq(
+        mutid, src, root, count, new_root, new_count);
+}
+
+
+/* Merge `src` into `o` itself: on success the root and item count of `o`
+ * are replaced with the merged ones.
+ *
+ * Returns 0 on success, -1 on failure (in which case `o` is unchanged).
+ */
+static int
+map_update_inplace(uint64_t mutid, MapObject *o, PyObject *src)
+{
+    MapNode *merged_root = NULL;
+    Py_ssize_t merged_count;
+
+    if (map_node_update(mutid, src,
+                        o->h_root, o->h_count,
+                        &merged_root, &merged_count))
+    {
+        return -1;
+    }
+
+    assert(merged_root);
+
+    /* Swap in the new root, releasing the previous one. */
+    Py_SETREF(o->h_root, merged_root);
+    o->h_count = merged_count;
+
+    return 0;
+}
+
+
+/* Build a new map object holding the contents of `o` with `src` merged
+ * in.  `o` is not modified.
+ *
+ * Returns the object obtained from map_alloc() on success, or NULL on
+ * failure.
+ */
+static MapObject *
+map_update(uint64_t mutid, MapObject *o, PyObject *src)
+{
+    MapNode *merged_root = NULL;
+    Py_ssize_t merged_count;
+
+    if (map_node_update(mutid, src,
+                        o->h_root, o->h_count,
+                        &merged_root, &merged_count))
+    {
+        return NULL;
+    }
+
+    assert(merged_root);
+
+    MapObject *result = map_alloc();
+    if (result == NULL) {
+        /* The merged tree has no owner yet; release it here. */
+        Py_DECREF(merged_root);
+        return NULL;
+    }
+
+    Py_XSETREF(result->h_root, merged_root);
+    result->h_count = merged_count;
+
+    return result;
+}
+
+
static PyObject *
mapmut_py_set(MapMutationObject *o, PyObject *args)
{
@@ -3425,7 +3815,6 @@ PyTypeObject _Map_CollisionNode_Type = {
static void
module_free(void *m)
{
- Py_CLEAR(_empty_map);
Py_CLEAR(_empty_bitmap_node);
}
diff --git a/immutables/map.py b/immutables/map.py
index 8c7d862..9884956 100644
--- a/immutables/map.py
+++ b/immutables/map.py
@@ -409,11 +409,16 @@ class GenWrapper:
class Map:
- def __init__(self):
+ def __init__(self, col=None, **kw):
+ # Build an empty map, then merge the optional source `col` (mapping or
+ # iterable of pairs) and the keyword arguments `kw` via update().
self.__count = 0
self.__root = BitmapNode(0, 0, [], 0)
self.__hash = -1
+ # Test `col is not None`, not truthiness: a falsy non-iterable such as
+ # Map(0) must reach update() and raise TypeError instead of being
+ # silently accepted as "no source".
+ if col is not None or kw:
+ init = self.update(col, **kw)
+ self.__count = init.__count
+ self.__root = init.__root
+
@classmethod
def _new(cls, count, root):
m = Map.__new__(Map)
@@ -443,6 +448,56 @@ class Map:
return True
+    def update(self, col=None, **kw):
+        """Return a Map with the pairs from *col* and *kw* applied.
+
+        *col* may be an object with an ``items()`` method or any iterable
+        of two-item elements; keyword arguments are applied after it.
+        With nothing to merge, ``self`` is returned unchanged.
+
+        Raises TypeError if an element cannot be converted to a sequence,
+        and ValueError if an element does not have exactly two items.
+        """
+        it = None
+        if col is not None:
+            if hasattr(col, 'items'):
+                it = iter(col.items())
+            else:
+                it = iter(col)
+
+        if it is not None:
+            if kw:
+                it = iter(itertools.chain(it, kw.items()))
+        else:
+            if kw:
+                it = iter(kw.items())
+
+        if it is None:
+            return self
+
+        mutid = _mut_id()
+        root = self.__root
+        count = self.__count
+
+        for i, item in enumerate(it):
+            try:
+                tup = tuple(item)
+            except TypeError:
+                raise TypeError(
+                    f'cannot convert map update '
+                    f'sequence element #{i} to a sequence') from None
+
+            # Check the length explicitly so that too-short elements get
+            # the same message as too-long ones, rather than a generic
+            # unpacking error.
+            if len(tup) != 2:
+                raise ValueError(
+                    f'map update sequence element #{i} has length '
+                    f'{len(tup)}; 2 is required')
+
+            key, val = tup
+            root, added = root.assoc(0, map_hash(key), key, val, mutid)
+            if added:
+                count += 1
+
+        return Map._new(count, root)
+
def mutate(self):
return MapMutation(self.__count, self.__root)
diff --git a/tests/test_map.py b/tests/test_map.py
index e2ba2b4..afa2e2e 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -60,7 +60,7 @@ class KeyStr(str):
return super().__eq__(other)
-class HaskKeyCrasher:
+class HashKeyCrasher:
def __init__(self, *, error_on_hash=False, error_on_eq=False,
error_on_repr=False):
@@ -93,13 +93,6 @@ class BaseMapTest:
Map = None
- def test_init_no_args(self):
- with self.assertRaisesRegex(TypeError, 'positional argument'):
- self.Map(dict(a=1))
-
- with self.assertRaisesRegex(TypeError, 'keyword argument'):
- self.Map(a=1)
-
def test_hashkey_helper_1(self):
k1 = HashKey(10, 'aaa')
k2 = HashKey(10, 'bbb')
@@ -260,14 +253,14 @@ class BaseMapTest:
key = KeyStr(i)
if not (i % CRASH_HASH_EVERY):
- with HaskKeyCrasher(error_on_hash=True):
+ with HashKeyCrasher(error_on_hash=True):
with self.assertRaises(HashingError):
h.set(key, i)
h = h.set(key, i)
if not (i % CRASH_EQ_EVERY):
- with HaskKeyCrasher(error_on_eq=True):
+ with HashKeyCrasher(error_on_eq=True):
with self.assertRaises(EqError):
h.get(KeyStr(i)) # really trigger __eq__
@@ -289,12 +282,12 @@ class BaseMapTest:
key = KeyStr(i)
if not (iter_i % CRASH_HASH_EVERY):
- with HaskKeyCrasher(error_on_hash=True):
+ with HashKeyCrasher(error_on_hash=True):
with self.assertRaises(HashingError):
h.delete(key)
if not (iter_i % CRASH_EQ_EVERY):
- with HaskKeyCrasher(error_on_eq=True):
+ with HashKeyCrasher(error_on_eq=True):
with self.assertRaises(EqError):
h.delete(KeyStr(i))
@@ -807,11 +800,11 @@ class BaseMapTest:
self.assertFalse(B in h)
with self.assertRaises(EqError):
- with HaskKeyCrasher(error_on_eq=True):
+ with HashKeyCrasher(error_on_eq=True):
AA in h
with self.assertRaises(HashingError):
- with HaskKeyCrasher(error_on_hash=True):
+ with HashKeyCrasher(error_on_hash=True):
AA in h
def test_map_getitem_1(self):
@@ -830,11 +823,11 @@ class BaseMapTest:
h[B]
with self.assertRaises(EqError):
- with HaskKeyCrasher(error_on_eq=True):
+ with HashKeyCrasher(error_on_eq=True):
h[AA]
with self.assertRaises(HashingError):
- with HaskKeyCrasher(error_on_hash=True):
+ with HashKeyCrasher(error_on_hash=True):
h[AA]
def test_repr_1(self):
@@ -850,11 +843,11 @@ class BaseMapTest:
A = HashKey(100, 'A')
with self.assertRaises(ReprError):
- with HaskKeyCrasher(error_on_repr=True):
+ with HashKeyCrasher(error_on_repr=True):
repr(h.set(1, 2).set(A, 3).set(3, 4))
with self.assertRaises(ReprError):
- with HaskKeyCrasher(error_on_repr=True):
+ with HashKeyCrasher(error_on_repr=True):
repr(h.set(1, 2).set(2, A).set(3, 4))
def test_repr_3(self):
@@ -895,12 +888,12 @@ class BaseMapTest:
m = h.set(1, 2).set(A, 3).set(3, 4)
with self.assertRaises(HashingError):
- with HaskKeyCrasher(error_on_hash=True):
+ with HashKeyCrasher(error_on_hash=True):
hash(m)
m = h.set(1, 2).set(2, A).set(3, 4)
with self.assertRaises(HashingError):
- with HaskKeyCrasher(error_on_hash=True):
+ with HashKeyCrasher(error_on_hash=True):
hash(m)
def test_abc_1(self):
@@ -983,6 +976,110 @@ class BaseMapTest:
hm2.delete('a')
self.assertNotEqual(hm1, hm2)
+ def test_map_mut_5(self):
+ # Construction from a dict plus keyword arguments, and update() with
+ # keywords, a list of pairs, no arguments, and another Map.
+ h = self.Map({'a': 1, 'b': 2}, z=100)
+ self.assertTrue(isinstance(h, self.Map))
+ self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100})
+
+ # update() returns a new map; the receiver keeps its old contents.
+ h2 = h.update(z=200, y=-1)
+ self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100})
+ self.assertEqual(dict(h2.items()), {'a': 1, 'b': 2, 'z': 200, 'y': -1})
+
+ h3 = h2.update([(1, 2), (3, 4)])
+ self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100})
+ self.assertEqual(dict(h2.items()), {'a': 1, 'b': 2, 'z': 200, 'y': -1})
+ self.assertEqual(dict(h3.items()),
+ {'a': 1, 'b': 2, 'z': 200, 'y': -1, 1: 2, 3: 4})
+
+ # With nothing to merge, the very same object comes back.
+ h4 = h3.update()
+ self.assertIs(h4, h3)
+
+ h5 = h4.update(self.Map({'zzz': 'yyz'}))
+
+ self.assertEqual(dict(h5.items()),
+ {'a': 1, 'b': 2, 'z': 200, 'y': -1, 1: 2, 3: 4,
+ 'zzz': 'yyz'})
+
+ def test_map_mut_6(self):
+ # Malformed update() sources raise, and the original map is untouched.
+ h = self.Map({'a': 1, 'b': 2}, z=100)
+ self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100})
+
+ # Source is not iterable at all.
+ with self.assertRaisesRegex(TypeError, 'not iterable'):
+ h.update(1)
+
+ # An element with three items instead of two.
+ with self.assertRaisesRegex(ValueError, 'map update sequence element'):
+ h.update([(1, 2), (3, 4, 5)])
+
+ # An element that is not a sequence.
+ with self.assertRaisesRegex(TypeError, 'cannot convert map update'):
+ h.update([(1, 2), 1])
+
+ self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100})
+
+ def test_map_mut_7(self):
+ # A hashing error raised during update() propagates for each source
+ # kind (dict, Map, list of pairs) and leaves the original map intact.
+ key = HashKey(123, 'aaa')
+
+ h = self.Map({'a': 1, 'b': 2}, z=100)
+ self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100})
+
+ upd = {key: 1}
+ with HashKeyCrasher(error_on_hash=True):
+ with self.assertRaises(HashingError):
+ h.update(upd)
+
+ upd = self.Map({key: 'zzz'})
+ with HashKeyCrasher(error_on_hash=True):
+ with self.assertRaises(HashingError):
+ h.update(upd)
+
+ upd = [(1, 2), (key, 'zzz')]
+ with HashKeyCrasher(error_on_hash=True):
+ with self.assertRaises(HashingError):
+ h.update(upd)
+
+ self.assertEqual(dict(h.items()), {'a': 1, 'b': 2, 'z': 100})
+
+ def test_map_mut_8(self):
+ # key1 and key2 are built with the same first HashKey argument (123),
+ # so updating with key2 has to compare it against key1; an error
+ # raised by that comparison propagates for each source kind and
+ # leaves the original map intact.
+ key1 = HashKey(123, 'aaa')
+ key2 = HashKey(123, 'bbb')
+
+ h = self.Map({key1: 123})
+ self.assertEqual(dict(h.items()), {key1: 123})
+
+ upd = {key2: 1}
+ with HashKeyCrasher(error_on_eq=True):
+ with self.assertRaises(EqError):
+ h.update(upd)
+
+ upd = self.Map({key2: 'zzz'})
+ with HashKeyCrasher(error_on_eq=True):
+ with self.assertRaises(EqError):
+ h.update(upd)
+
+ upd = [(1, 2), (key2, 'zzz')]
+ with HashKeyCrasher(error_on_eq=True):
+ with self.assertRaises(EqError):
+ h.update(upd)
+
+ self.assertEqual(dict(h.items()), {key1: 123})
+
+ def test_map_mut_9(self):
+ # A hashing error raised while constructing a Map from a source
+ # propagates for each source kind (dict, Map, list of pairs).
+ key1 = HashKey(123, 'aaa')
+
+ src = {key1: 123}
+ with HashKeyCrasher(error_on_hash=True):
+ with self.assertRaises(HashingError):
+ self.Map(src)
+
+ src = self.Map({key1: 123})
+ with HashKeyCrasher(error_on_hash=True):
+ with self.assertRaises(HashingError):
+ self.Map(src)
+
+ src = [(1, 2), (key1, 123)]
+ with HashKeyCrasher(error_on_hash=True):
+ with self.assertRaises(HashingError):
+ self.Map(src)
+
def test_map_mut_stress(self):
COLLECTION_SIZE = 7000
TEST_ITERS_EVERY = 647