diff options
-rw-r--r-- | immutables/map.py | 48 | ||||
-rw-r--r-- | tests/test_map.py | 2 |
2 files changed, 23 insertions, 27 deletions
diff --git a/immutables/map.py b/immutables/map.py index a7171b6..81612bc 100644 --- a/immutables/map.py +++ b/immutables/map.py @@ -52,7 +52,7 @@ class BitmapNode: def clone(self): return BitmapNode(self.size, self.bitmap, self.array.copy()) - def assoc(self, shift, hash, key, val, added_leaf): + def assoc(self, shift, hash, key, val): bit = map_bitpos(hash, shift) idx = map_bitindex(self.bitmap, bit) @@ -64,22 +64,22 @@ class BitmapNode: val_or_node = self.array[val_idx] if key_or_null is None: - sub_node = val_or_node.assoc( - shift + 5, hash, key, val, added_leaf) + sub_node, added = val_or_node.assoc( + shift + 5, hash, key, val) if val_or_node is sub_node: - return self + return self, False ret = self.clone() ret.array[val_idx] = sub_node - return ret + return ret, added if key == key_or_null: if val is val_or_node: - return self + return self, False ret = self.clone() ret.array[val_idx] = val - return ret + return ret, False existing_key_hash = map_hash(key_or_null) if existing_key_hash == hash: @@ -87,21 +87,18 @@ class BitmapNode: 4, hash, [key_or_null, val_or_node, key, val]) else: sub_node = BitmapNode(0, 0, []) - sub_node = sub_node.assoc( + sub_node, _ = sub_node.assoc( shift + 5, existing_key_hash, - key_or_null, val_or_node, [False]) - sub_node = sub_node.assoc( - shift + 5, hash, key, val, [False]) + key_or_null, val_or_node) + sub_node, _ = sub_node.assoc( + shift + 5, hash, key, val) ret = self.clone() ret.array[key_idx] = None ret.array[val_idx] = sub_node - added_leaf[0] = True - return ret + return ret, True else: - added_leaf[0] = True - key_idx = 2 * idx val_idx = key_idx + 1 @@ -111,7 +108,7 @@ class BitmapNode: new_array.append(key) new_array.append(val) new_array.extend(self.array[key_idx:]) - return BitmapNode(2 * (n + 1), self.bitmap | bit, new_array) + return BitmapNode(2 * (n + 1), self.bitmap | bit, new_array), True def find(self, shift, hash, key): bit = map_bitpos(hash, shift) @@ -251,7 +248,7 @@ class CollisionNode: return self.array[i + 1] raise KeyError(key) - def assoc(self, shift, hash, key, val, added_leaf): + def assoc(self, shift, hash, key, val): if hash == self.hash: key_idx = self.find_index(key) @@ -260,21 +257,20 @@ class CollisionNode: new_array.append(key) new_array.append(val) new_node = CollisionNode(self.size + 2, hash, new_array) - added_leaf[0] = True - return new_node + return new_node, True val_idx = key_idx + 1 if self.array[val_idx] is val: - return self + return self, False new_array = self.array.copy() new_array[val_idx] = val - return CollisionNode(self.size, hash, new_array) + return CollisionNode(self.size, hash, new_array), False else: new_node = BitmapNode( 2, map_bitpos(self.hash, shift), [None, self]) - return new_node.assoc(shift, hash, key, val, added_leaf) + return new_node.assoc(shift, hash, key, val) def without(self, shift, hash, key): if hash != self.hash: @@ -375,16 +371,14 @@ class Map: return True def set(self, key, val): - added = [False] - new_count = self.__count - new_root = self.__root.assoc(0, map_hash(key), key, val, added) + new_root, added = self.__root.assoc(0, map_hash(key), key, val) if new_root is self.__root: - assert not added[0] + assert not added return self - if added[0]: + if added: new_count += 1 m = Map.__new__(Map) diff --git a/tests/test_map.py b/tests/test_map.py index c64258d..660f742 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -455,7 +455,9 @@ class BaseMapTest: h = h.set(D, 'd') h = h.set(E, 'e') + self.assertEqual(len(h), 5) h = h.set(C, 'c') # trigger branch in CollisionNode.assoc + self.assertEqual(len(h), 5) orig_len = len(h) |