summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2018-11-19 22:02:25 -0500
committerYury Selivanov <yury@magic.io>2018-11-20 14:25:07 -0500
commita9dc2d83794a9cada687f6b92609fe6ef16c2bb9 (patch)
tree0d0eb43f576d9326f089254373569e0dbed8d3f9
parent666613d4a5400fd1ce0267a0bcf9cd70afae7392 (diff)
downloadimmutables-a9dc2d83794a9cada687f6b92609fe6ef16c2bb9.tar.gz
immutables-a9dc2d83794a9cada687f6b92609fe6ef16c2bb9.zip
Fix .keys() and other views to support being iterated more than once
-rw-r--r--immutables/_map.c151
-rw-r--r--immutables/_map.h20
-rw-r--r--immutables/map.py43
-rw-r--r--tests/test_map.py28
4 files changed, 194 insertions, 48 deletions
diff --git a/immutables/_map.c b/immutables/_map.c
index dbe5e01..8064754 100644
--- a/immutables/_map.c
+++ b/immutables/_map.c
@@ -2625,7 +2625,7 @@ error:
static int
map_baseiter_tp_clear(MapIterator *it)
{
- Py_CLEAR(it->hi_obj);
+ Py_CLEAR(it->mi_obj);
return 0;
}
@@ -2640,7 +2640,7 @@ map_baseiter_tp_dealloc(MapIterator *it)
static int
map_baseiter_tp_traverse(MapIterator *it, visitproc visit, void *arg)
{
- Py_VISIT(it->hi_obj);
+ Py_VISIT(it->mi_obj);
return 0;
}
@@ -2649,7 +2649,7 @@ map_baseiter_tp_iternext(MapIterator *it)
{
PyObject *key;
PyObject *val;
- map_iter_t res = map_iterator_next(&it->hi_iter, &key, &val);
+ map_iter_t res = map_iterator_next(&it->mi_iter, &key, &val);
switch (res) {
case I_END:
@@ -2657,7 +2657,7 @@ map_baseiter_tp_iternext(MapIterator *it)
return NULL;
case I_ITEM: {
- return (*(it->hi_yield))(key, val);
+ return (*(it->mi_yield))(key, val);
}
default: {
@@ -2666,37 +2666,86 @@ map_baseiter_tp_iternext(MapIterator *it)
}
}
+static int
+map_baseview_tp_clear(MapView *view)
+{
+ Py_CLEAR(view->mv_obj);
+ Py_CLEAR(view->mv_itertype);
+ return 0;
+}
+
+static void
+map_baseview_tp_dealloc(MapView *view)
+{
+ PyObject_GC_UnTrack(view);
+ (void)map_baseview_tp_clear(view);
+ PyObject_GC_Del(view);
+}
+
+static int
+map_baseview_tp_traverse(MapView *view, visitproc visit, void *arg)
+{
+ Py_VISIT(view->mv_obj);
+ return 0;
+}
+
static Py_ssize_t
-map_baseiter_tp_len(MapIterator *it)
+map_baseview_tp_len(MapView *view)
{
- return it->hi_obj->h_count;
+ return view->mv_obj->h_count;
}
-static PyMappingMethods MapIterator_as_mapping = {
- (lenfunc)map_baseiter_tp_len,
+static PyMappingMethods MapView_as_mapping = {
+ (lenfunc)map_baseview_tp_len,
};
static PyObject *
-map_baseiter_new(PyTypeObject *type, binaryfunc yield, MapObject *o)
+map_baseview_newiter(PyTypeObject *type, binaryfunc yield, MapObject *map)
{
- MapIterator *it = PyObject_GC_New(MapIterator, type);
- if (it == NULL) {
+ MapIterator *iter = PyObject_GC_New(MapIterator, type);
+ if (iter == NULL) {
+ return NULL;
+ }
+
+ Py_INCREF(map);
+ iter->mi_obj = map;
+ iter->mi_yield = yield;
+ map_iterator_init(&iter->mi_iter, map->h_root);
+
+ PyObject_GC_Track(iter);
+ return (PyObject *)iter;
+}
+
+static PyObject *
+map_baseview_iter(MapView *view)
+{
+ return map_baseview_newiter(
+ view->mv_itertype, view->mv_yield, view->mv_obj);
+}
+
+static PyObject *
+map_baseview_new(PyTypeObject *type, binaryfunc yield,
+ MapObject *o, PyTypeObject *itertype)
+{
+ MapView *view = PyObject_GC_New(MapView, type);
+ if (view == NULL) {
return NULL;
}
Py_INCREF(o);
- it->hi_obj = o;
- it->hi_yield = yield;
+ view->mv_obj = o;
+ view->mv_yield = yield;
- map_iterator_init(&it->hi_iter, o->h_root);
+ Py_INCREF(itertype);
+ view->mv_itertype = itertype;
- return (PyObject*)it;
+ PyObject_GC_Track(view);
+ return (PyObject *)view;
}
#define ITERATOR_TYPE_SHARED_SLOTS \
.tp_basicsize = sizeof(MapIterator), \
.tp_itemsize = 0, \
- .tp_as_mapping = &MapIterator_as_mapping, \
.tp_dealloc = (destructor)map_baseiter_tp_dealloc, \
.tp_getattro = PyObject_GenericGetAttr, \
.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, \
@@ -2706,12 +2755,30 @@ map_baseiter_new(PyTypeObject *type, binaryfunc yield, MapObject *o)
.tp_iternext = (iternextfunc)map_baseiter_tp_iternext,
+#define VIEW_TYPE_SHARED_SLOTS \
+ .tp_basicsize = sizeof(MapView), \
+ .tp_itemsize = 0, \
+ .tp_as_mapping = &MapView_as_mapping, \
+ .tp_dealloc = (destructor)map_baseview_tp_dealloc, \
+ .tp_getattro = PyObject_GenericGetAttr, \
+ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, \
+ .tp_traverse = (traverseproc)map_baseview_tp_traverse, \
+ .tp_clear = (inquiry)map_baseview_tp_clear, \
+ .tp_iter = (getiterfunc)map_baseview_iter, \
+
+
/////////////////////////////////// _MapItems_Type
PyTypeObject _MapItems_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
"items",
+ VIEW_TYPE_SHARED_SLOTS
+};
+
+PyTypeObject _MapItemsIter_Type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ "items_iterator",
ITERATOR_TYPE_SHARED_SLOTS
};
@@ -2722,10 +2789,11 @@ map_iter_yield_items(PyObject *key, PyObject *val)
}
static PyObject *
-map_new_iteritems(MapObject *o)
+map_new_items_view(MapObject *o)
{
- return map_baseiter_new(
- &_MapItems_Type, map_iter_yield_items, o);
+ return map_baseview_new(
+ &_MapItems_Type, map_iter_yield_items, o,
+ &_MapItemsIter_Type);
}
@@ -2735,6 +2803,12 @@ map_new_iteritems(MapObject *o)
PyTypeObject _MapKeys_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
"keys",
+ VIEW_TYPE_SHARED_SLOTS
+};
+
+PyTypeObject _MapKeysIter_Type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ "keys_iterator",
ITERATOR_TYPE_SHARED_SLOTS
};
@@ -2746,12 +2820,19 @@ map_iter_yield_keys(PyObject *key, PyObject *val)
}
static PyObject *
-map_new_iterkeys(MapObject *o)
+map_new_keys_iter(MapObject *o)
{
- return map_baseiter_new(
- &_MapKeys_Type, map_iter_yield_keys, o);
+ return map_baseview_newiter(
+ &_MapKeysIter_Type, map_iter_yield_keys, o);
}
+static PyObject *
+map_new_keys_view(MapObject *o)
+{
+ return map_baseview_new(
+ &_MapKeys_Type, map_iter_yield_keys, o,
+ &_MapKeysIter_Type);
+}
/////////////////////////////////// _MapValues_Type
@@ -2759,6 +2840,12 @@ map_new_iterkeys(MapObject *o)
PyTypeObject _MapValues_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
"values",
+ VIEW_TYPE_SHARED_SLOTS
+};
+
+PyTypeObject _MapValuesIter_Type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ "values_iterator",
ITERATOR_TYPE_SHARED_SLOTS
};
@@ -2770,10 +2857,11 @@ map_iter_yield_values(PyObject *key, PyObject *val)
}
static PyObject *
-map_new_itervalues(MapObject *o)
+map_new_values_view(MapObject *o)
{
- return map_baseiter_new(
- &_MapValues_Type, map_iter_yield_values, o);
+ return map_baseview_new(
+ &_MapValues_Type, map_iter_yield_values, o,
+ &_MapValuesIter_Type);
}
@@ -2922,7 +3010,7 @@ map_tp_len(BaseMapObject *self)
static PyObject *
map_tp_iter(MapObject *self)
{
- return map_new_iterkeys(self);
+ return map_new_keys_iter(self);
}
static PyObject *
@@ -3041,19 +3129,19 @@ map_py_update(MapObject *self, PyObject *args, PyObject *kwds)
static PyObject *
map_py_items(MapObject *self, PyObject *args)
{
- return map_new_iteritems(self);
+ return map_new_items_view(self);
}
static PyObject *
map_py_values(MapObject *self, PyObject *args)
{
- return map_new_itervalues(self);
+ return map_new_values_view(self);
}
static PyObject *
map_py_keys(MapObject *self, PyObject *args)
{
- return map_new_iterkeys(self);
+ return map_new_keys_view(self);
}
static PyObject *
@@ -3844,7 +3932,10 @@ PyInit__map(void)
(PyType_Ready(&_Map_CollisionNode_Type) < 0) ||
(PyType_Ready(&_MapKeys_Type) < 0) ||
(PyType_Ready(&_MapValues_Type) < 0) ||
- (PyType_Ready(&_MapItems_Type) < 0))
+ (PyType_Ready(&_MapItems_Type) < 0) ||
+ (PyType_Ready(&_MapKeysIter_Type) < 0) ||
+ (PyType_Ready(&_MapValuesIter_Type) < 0) ||
+ (PyType_Ready(&_MapItemsIter_Type) < 0))
{
return 0;
}
diff --git a/immutables/_map.h b/immutables/_map.h
index 483865c..dd12af9 100644
--- a/immutables/_map.h
+++ b/immutables/_map.h
@@ -72,14 +72,25 @@ typedef struct {
just a key for the 'Keys' iterator, and a value for the 'Values'
iterator.
*/
+
+typedef struct {
+ PyObject_HEAD
+ MapObject *mv_obj;
+ binaryfunc mv_yield;
+ PyTypeObject *mv_itertype;
+} MapView;
+
typedef struct {
PyObject_HEAD
- MapObject *hi_obj;
- MapIteratorState hi_iter;
- binaryfunc hi_yield;
+ MapObject *mi_obj;
+ binaryfunc mi_yield;
+ MapIteratorState mi_iter;
} MapIterator;
+/* PyTypes */
+
+
PyTypeObject _Map_Type;
PyTypeObject _MapMutation_Type;
PyTypeObject _Map_ArrayNode_Type;
@@ -88,6 +99,9 @@ PyTypeObject _Map_CollisionNode_Type;
PyTypeObject _MapKeys_Type;
PyTypeObject _MapValues_Type;
PyTypeObject _MapItems_Type;
+PyTypeObject _MapKeysIter_Type;
+PyTypeObject _MapValuesIter_Type;
+PyTypeObject _MapItemsIter_Type;
#endif
diff --git a/immutables/map.py b/immutables/map.py
index 9884956..f498d38 100644
--- a/immutables/map.py
+++ b/immutables/map.py
@@ -391,20 +391,43 @@ class CollisionNode:
buf.append('{}{!r}: {!r}'.format(pad, key, val))
-class GenWrapper:
+class MapKeys:
- def __init__(self, count, gen):
- self.__count = count
- self.__gen = gen
+ def __init__(self, c, m):
+ self.__count = c
+ self.__root = m
+
+ def __len__(self):
+ return self.__count
+
+ def __iter__(self):
+ return iter(self.__root.keys())
+
+
+class MapValues:
+
+ def __init__(self, c, m):
+ self.__count = c
+ self.__root = m
def __len__(self):
return self.__count
def __iter__(self):
- return self
+ return iter(self.__root.values())
+
+
+class MapItems:
+
+ def __init__(self, c, m):
+ self.__count = c
+ self.__root = m
- def __next__(self):
- return next(self.__gen)
+ def __len__(self):
+ return self.__count
+
+ def __iter__(self):
+ return iter(self.__root.items())
class Map:
@@ -544,13 +567,13 @@ class Map:
yield from self.__root.keys()
def keys(self):
- return GenWrapper(self.__count, self.__root.keys())
+ return MapKeys(self.__count, self.__root)
def values(self):
- return GenWrapper(self.__count, self.__root.values())
+ return MapValues(self.__count, self.__root)
def items(self):
- return GenWrapper(self.__count, self.__root.items())
+ return MapItems(self.__count, self.__root)
def __hash__(self):
if self.__hash != -1:
diff --git a/tests/test_map.py b/tests/test_map.py
index afa2e2e..13c6cb1 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -637,6 +637,17 @@ class BaseMapTest:
set(list(it)),
{(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})
+ def test_map_items_3(self):
+ h = self.Map()
+ self.assertEqual(len(h.items()), 0)
+ self.assertEqual(list(h.items()), [])
+
+ def test_map_items_4(self):
+ h = self.Map(a=1, b=2, c=3)
+ k = h.items()
+ self.assertEqual(set(k), {('a', 1), ('b', 2), ('c', 3)})
+ self.assertEqual(set(k), {('a', 1), ('b', 2), ('c', 3)})
+
def test_map_keys_1(self):
A = HashKey(100, 'A')
B = HashKey(101, 'B')
@@ -656,6 +667,12 @@ class BaseMapTest:
self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
self.assertEqual(set(list(h)), {A, B, C, D, E, F})
+ def test_map_keys_2(self):
+ h = self.Map(a=1, b=2, c=3)
+ k = h.keys()
+ self.assertEqual(set(k), {'a', 'b', 'c'})
+ self.assertEqual(set(k), {'a', 'b', 'c'})
+
def test_map_values_1(self):
A = HashKey(100, 'A')
B = HashKey(101, 'B')
@@ -674,10 +691,11 @@ class BaseMapTest:
self.assertEqual(set(list(h.values())), {'a', 'b', 'c', 'd', 'e', 'f'})
- def test_map_items_3(self):
- h = self.Map()
- self.assertEqual(len(h.items()), 0)
- self.assertEqual(list(h.items()), [])
+ def test_map_values_2(self):
+ h = self.Map(a=1, b=2, c=3)
+ k = h.values()
+ self.assertEqual(set(k), {1, 2, 3})
+ self.assertEqual(set(k), {1, 2, 3})
def test_map_eq_1(self):
A = HashKey(100, 'A')
@@ -776,7 +794,7 @@ class BaseMapTest:
h = h.set(A, h)
ref = weakref.ref(h)
- hi = h.items()
+ hi = iter(h.items())
next(hi)
del h, hi