From a9dc2d83794a9cada687f6b92609fe6ef16c2bb9 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 19 Nov 2018 22:02:25 -0500 Subject: Fix .keys() and other views to support being iterated more than once --- immutables/_map.c | 151 +++++++++++++++++++++++++++++++++++++++++++----------- immutables/_map.h | 20 ++++++-- immutables/map.py | 43 ++++++++++++---- tests/test_map.py | 28 ++++++++-- 4 files changed, 194 insertions(+), 48 deletions(-) diff --git a/immutables/_map.c b/immutables/_map.c index dbe5e01..8064754 100644 --- a/immutables/_map.c +++ b/immutables/_map.c @@ -2625,7 +2625,7 @@ error: static int map_baseiter_tp_clear(MapIterator *it) { - Py_CLEAR(it->hi_obj); + Py_CLEAR(it->mi_obj); return 0; } @@ -2640,7 +2640,7 @@ map_baseiter_tp_dealloc(MapIterator *it) static int map_baseiter_tp_traverse(MapIterator *it, visitproc visit, void *arg) { - Py_VISIT(it->hi_obj); + Py_VISIT(it->mi_obj); return 0; } @@ -2649,7 +2649,7 @@ map_baseiter_tp_iternext(MapIterator *it) { PyObject *key; PyObject *val; - map_iter_t res = map_iterator_next(&it->hi_iter, &key, &val); + map_iter_t res = map_iterator_next(&it->mi_iter, &key, &val); switch (res) { case I_END: @@ -2657,7 +2657,7 @@ map_baseiter_tp_iternext(MapIterator *it) return NULL; case I_ITEM: { - return (*(it->hi_yield))(key, val); + return (*(it->mi_yield))(key, val); } default: { @@ -2666,37 +2666,86 @@ map_baseiter_tp_iternext(MapIterator *it) } } +static int +map_baseview_tp_clear(MapView *view) +{ + Py_CLEAR(view->mv_obj); + Py_CLEAR(view->mv_itertype); + return 0; +} + +static void +map_baseview_tp_dealloc(MapView *view) +{ + PyObject_GC_UnTrack(view); + (void)map_baseview_tp_clear(view); + PyObject_GC_Del(view); +} + +static int +map_baseview_tp_traverse(MapView *view, visitproc visit, void *arg) +{ + Py_VISIT(view->mv_obj); + return 0; +} + static Py_ssize_t -map_baseiter_tp_len(MapIterator *it) +map_baseview_tp_len(MapView *view) { - return it->hi_obj->h_count; + return view->mv_obj->h_count; } -static PyMappingMethods MapIterator_as_mapping = { - (lenfunc)map_baseiter_tp_len, +static PyMappingMethods MapView_as_mapping = { + (lenfunc)map_baseview_tp_len, }; static PyObject * -map_baseiter_new(PyTypeObject *type, binaryfunc yield, MapObject *o) +map_baseview_newiter(PyTypeObject *type, binaryfunc yield, MapObject *map) { - MapIterator *it = PyObject_GC_New(MapIterator, type); - if (it == NULL) { + MapIterator *iter = PyObject_GC_New(MapIterator, type); + if (iter == NULL) { + return NULL; + } + + Py_INCREF(map); + iter->mi_obj = map; + iter->mi_yield = yield; + map_iterator_init(&iter->mi_iter, map->h_root); + + PyObject_GC_Track(iter); + return (PyObject *)iter; +} + +static PyObject * +map_baseview_iter(MapView *view) +{ + return map_baseview_newiter( + view->mv_itertype, view->mv_yield, view->mv_obj); +} + +static PyObject * +map_baseview_new(PyTypeObject *type, binaryfunc yield, + MapObject *o, PyTypeObject *itertype) +{ + MapView *view = PyObject_GC_New(MapView, type); + if (view == NULL) { return NULL; } Py_INCREF(o); - it->hi_obj = o; - it->hi_yield = yield; + view->mv_obj = o; + view->mv_yield = yield; - map_iterator_init(&it->hi_iter, o->h_root); + Py_INCREF(itertype); + view->mv_itertype = itertype; - return (PyObject*)it; + PyObject_GC_Track(view); + return (PyObject *)view; } #define ITERATOR_TYPE_SHARED_SLOTS \ .tp_basicsize = sizeof(MapIterator), \ .tp_itemsize = 0, \ - .tp_as_mapping = &MapIterator_as_mapping, \ .tp_dealloc = (destructor)map_baseiter_tp_dealloc, \ .tp_getattro = PyObject_GenericGetAttr, \ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, \ @@ -2706,12 +2755,30 @@ map_baseiter_new(PyTypeObject *type, binaryfunc yield, MapObject *o) .tp_iternext = (iternextfunc)map_baseiter_tp_iternext, +#define VIEW_TYPE_SHARED_SLOTS \ + .tp_basicsize = sizeof(MapView), \ + .tp_itemsize = 0, \ + .tp_as_mapping = &MapView_as_mapping, \ + .tp_dealloc = (destructor)map_baseview_tp_dealloc, \ + .tp_getattro = PyObject_GenericGetAttr, \ + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, \ + .tp_traverse = (traverseproc)map_baseview_tp_traverse, \ + .tp_clear = (inquiry)map_baseview_tp_clear, \ + .tp_iter = (getiterfunc)map_baseview_iter, \ + + /////////////////////////////////// _MapItems_Type PyTypeObject _MapItems_Type = { PyVarObject_HEAD_INIT(NULL, 0) "items", + VIEW_TYPE_SHARED_SLOTS +}; + +PyTypeObject _MapItemsIter_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "items_iterator", ITERATOR_TYPE_SHARED_SLOTS }; @@ -2722,10 +2789,11 @@ map_iter_yield_items(PyObject *key, PyObject *val) } static PyObject * -map_new_iteritems(MapObject *o) +map_new_items_view(MapObject *o) { - return map_baseiter_new( - &_MapItems_Type, map_iter_yield_items, o); + return map_baseview_new( + &_MapItems_Type, map_iter_yield_items, o, + &_MapItemsIter_Type); } @@ -2735,6 +2803,12 @@ map_new_iteritems(MapObject *o) PyTypeObject _MapKeys_Type = { PyVarObject_HEAD_INIT(NULL, 0) "keys", + VIEW_TYPE_SHARED_SLOTS +}; + +PyTypeObject _MapKeysIter_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "keys_iterator", ITERATOR_TYPE_SHARED_SLOTS }; @@ -2746,12 +2820,19 @@ map_iter_yield_keys(PyObject *key, PyObject *val) } static PyObject * -map_new_iterkeys(MapObject *o) +map_new_keys_iter(MapObject *o) { - return map_baseiter_new( - &_MapKeys_Type, map_iter_yield_keys, o); + return map_baseview_newiter( + &_MapKeysIter_Type, map_iter_yield_keys, o); } +static PyObject * +map_new_keys_view(MapObject *o) +{ + return map_baseview_new( + &_MapKeys_Type, map_iter_yield_keys, o, + &_MapKeysIter_Type); +} /////////////////////////////////// _MapValues_Type @@ -2759,6 +2840,12 @@ map_new_iterkeys(MapObject *o) PyTypeObject _MapValues_Type = { PyVarObject_HEAD_INIT(NULL, 0) "values", + VIEW_TYPE_SHARED_SLOTS +}; + +PyTypeObject _MapValuesIter_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "values_iterator", ITERATOR_TYPE_SHARED_SLOTS }; @@ -2770,10 +2857,11 @@ map_iter_yield_values(PyObject *key, PyObject *val) } static PyObject * -map_new_itervalues(MapObject *o) +map_new_values_view(MapObject *o) { - return map_baseiter_new( - &_MapValues_Type, map_iter_yield_values, o); + return map_baseview_new( + &_MapValues_Type, map_iter_yield_values, o, + &_MapValuesIter_Type); } @@ -2922,7 +3010,7 @@ map_tp_len(BaseMapObject *self) static PyObject * map_tp_iter(MapObject *self) { - return map_new_iterkeys(self); + return map_new_keys_iter(self); } static PyObject * @@ -3041,19 +3129,19 @@ map_py_update(MapObject *self, PyObject *args, PyObject *kwds) static PyObject * map_py_items(MapObject *self, PyObject *args) { - return map_new_iteritems(self); + return map_new_items_view(self); } static PyObject * map_py_values(MapObject *self, PyObject *args) { - return map_new_itervalues(self); + return map_new_values_view(self); } static PyObject * map_py_keys(MapObject *self, PyObject *args) { - return map_new_iterkeys(self); + return map_new_keys_view(self); } static PyObject * @@ -3844,7 +3932,10 @@ PyInit__map(void) (PyType_Ready(&_Map_CollisionNode_Type) < 0) || (PyType_Ready(&_MapKeys_Type) < 0) || (PyType_Ready(&_MapValues_Type) < 0) || - (PyType_Ready(&_MapItems_Type) < 0)) + (PyType_Ready(&_MapItems_Type) < 0) || + (PyType_Ready(&_MapKeysIter_Type) < 0) || + (PyType_Ready(&_MapValuesIter_Type) < 0) || + (PyType_Ready(&_MapItemsIter_Type) < 0)) { return 0; } diff --git a/immutables/_map.h b/immutables/_map.h index 483865c..dd12af9 100644 --- a/immutables/_map.h +++ b/immutables/_map.h @@ -72,14 +72,25 @@ typedef struct { just a key for the 'Keys' iterator, and a value for the 'Values' iterator. */ + +typedef struct { + PyObject_HEAD + MapObject *mv_obj; + binaryfunc mv_yield; + PyTypeObject *mv_itertype; +} MapView; + typedef struct { PyObject_HEAD - MapObject *hi_obj; - MapIteratorState hi_iter; - binaryfunc hi_yield; + MapObject *mi_obj; + binaryfunc mi_yield; + MapIteratorState mi_iter; } MapIterator; +/* PyTypes */ + + PyTypeObject _Map_Type; PyTypeObject _MapMutation_Type; PyTypeObject _Map_ArrayNode_Type; @@ -88,6 +99,9 @@ PyTypeObject _Map_CollisionNode_Type; PyTypeObject _MapKeys_Type; PyTypeObject _MapValues_Type; PyTypeObject _MapItems_Type; +PyTypeObject _MapKeysIter_Type; +PyTypeObject _MapValuesIter_Type; +PyTypeObject _MapItemsIter_Type; #endif diff --git a/immutables/map.py b/immutables/map.py index 9884956..f498d38 100644 --- a/immutables/map.py +++ b/immutables/map.py @@ -391,20 +391,43 @@ class CollisionNode: buf.append('{}{!r}: {!r}'.format(pad, key, val)) -class GenWrapper: +class MapKeys: - def __init__(self, count, gen): - self.__count = count - self.__gen = gen + def __init__(self, c, m): + self.__count = c + self.__root = m + + def __len__(self): + return self.__count + + def __iter__(self): + return iter(self.__root.keys()) + + +class MapValues: + + def __init__(self, c, m): + self.__count = c + self.__root = m def __len__(self): return self.__count def __iter__(self): - return self + return iter(self.__root.values()) + + +class MapItems: + + def __init__(self, c, m): + self.__count = c + self.__root = m - def __next__(self): - return next(self.__gen) + def __len__(self): + return self.__count + + def __iter__(self): + return iter(self.__root.items()) class Map: @@ -544,13 +567,13 @@ class Map: yield from self.__root.keys() def keys(self): - return GenWrapper(self.__count, self.__root.keys()) + return MapKeys(self.__count, self.__root) def values(self): - return GenWrapper(self.__count, self.__root.values()) + return MapValues(self.__count, self.__root) def items(self): - return GenWrapper(self.__count, self.__root.items()) + return MapItems(self.__count, self.__root) def __hash__(self): if self.__hash != -1: diff --git a/tests/test_map.py b/tests/test_map.py index afa2e2e..13c6cb1 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -637,6 +637,17 @@ class BaseMapTest: set(list(it)), {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')}) + def test_map_items_3(self): + h = self.Map() + self.assertEqual(len(h.items()), 0) + self.assertEqual(list(h.items()), []) + + def test_map_items_4(self): + h = self.Map(a=1, b=2, c=3) + k = h.items() + self.assertEqual(set(k), {('a', 1), ('b', 2), ('c', 3)}) + self.assertEqual(set(k), {('a', 1), ('b', 2), ('c', 3)}) + def test_map_keys_1(self): A = HashKey(100, 'A') B = HashKey(101, 'B') @@ -656,6 +667,12 @@ class BaseMapTest: self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F}) self.assertEqual(set(list(h)), {A, B, C, D, E, F}) + def test_map_keys_2(self): + h = self.Map(a=1, b=2, c=3) + k = h.keys() + self.assertEqual(set(k), {'a', 'b', 'c'}) + self.assertEqual(set(k), {'a', 'b', 'c'}) + def test_map_values_1(self): A = HashKey(100, 'A') B = HashKey(101, 'B') @@ -674,10 +691,11 @@ class BaseMapTest: self.assertEqual(set(list(h.values())), {'a', 'b', 'c', 'd', 'e', 'f'}) - def test_map_items_3(self): - h = self.Map() - self.assertEqual(len(h.items()), 0) - self.assertEqual(list(h.items()), []) + def test_map_values_2(self): + h = self.Map(a=1, b=2, c=3) + k = h.values() + self.assertEqual(set(k), {1, 2, 3}) + self.assertEqual(set(k), {1, 2, 3}) def test_map_eq_1(self): A = HashKey(100, 'A') @@ -776,7 +794,7 @@ class BaseMapTest: h = h.set(A, h) ref = weakref.ref(h) - hi = h.items() + hi = iter(h.items()) next(hi) del h, hi -- cgit v1.2.3