aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2018-04-07 16:17:34 -0400
committerYury Selivanov <yury@magic.io>2018-04-07 16:17:34 -0400
commit771954f43c70d3ba6546fafcaec40ba59a7c1d44 (patch)
tree681292cee5b3b12c88fdd5b2ef743f0e0fb630d9
parent40495427c5b1c6f42faec5b4b6418228aa6247c6 (diff)
downloadimmutables-771954f43c70d3ba6546fafcaec40ba59a7c1d44.tar.gz
immutables-771954f43c70d3ba6546fafcaec40ba59a7c1d44.zip
pymap: Streamline .assoc() method
-rw-r--r--immutables/map.py48
-rw-r--r--tests/test_map.py2
2 files changed, 23 insertions, 27 deletions
diff --git a/immutables/map.py b/immutables/map.py
index a7171b6..81612bc 100644
--- a/immutables/map.py
+++ b/immutables/map.py
@@ -52,7 +52,7 @@ class BitmapNode:
def clone(self):
return BitmapNode(self.size, self.bitmap, self.array.copy())
- def assoc(self, shift, hash, key, val, added_leaf):
+ def assoc(self, shift, hash, key, val):
bit = map_bitpos(hash, shift)
idx = map_bitindex(self.bitmap, bit)
@@ -64,22 +64,22 @@ class BitmapNode:
val_or_node = self.array[val_idx]
if key_or_null is None:
- sub_node = val_or_node.assoc(
- shift + 5, hash, key, val, added_leaf)
+ sub_node, added = val_or_node.assoc(
+ shift + 5, hash, key, val)
if val_or_node is sub_node:
- return self
+ return self, False
ret = self.clone()
ret.array[val_idx] = sub_node
- return ret
+ return ret, added
if key == key_or_null:
if val is val_or_node:
- return self
+ return self, False
ret = self.clone()
ret.array[val_idx] = val
- return ret
+ return ret, False
existing_key_hash = map_hash(key_or_null)
if existing_key_hash == hash:
@@ -87,21 +87,18 @@ class BitmapNode:
4, hash, [key_or_null, val_or_node, key, val])
else:
sub_node = BitmapNode(0, 0, [])
- sub_node = sub_node.assoc(
+ sub_node, _ = sub_node.assoc(
shift + 5, existing_key_hash,
- key_or_null, val_or_node, [False])
- sub_node = sub_node.assoc(
- shift + 5, hash, key, val, [False])
+ key_or_null, val_or_node)
+ sub_node, _ = sub_node.assoc(
+ shift + 5, hash, key, val)
ret = self.clone()
ret.array[key_idx] = None
ret.array[val_idx] = sub_node
- added_leaf[0] = True
- return ret
+ return ret, True
else:
- added_leaf[0] = True
-
key_idx = 2 * idx
val_idx = key_idx + 1
@@ -111,7 +108,7 @@ class BitmapNode:
new_array.append(key)
new_array.append(val)
new_array.extend(self.array[key_idx:])
- return BitmapNode(2 * (n + 1), self.bitmap | bit, new_array)
+ return BitmapNode(2 * (n + 1), self.bitmap | bit, new_array), True
def find(self, shift, hash, key):
bit = map_bitpos(hash, shift)
@@ -251,7 +248,7 @@ class CollisionNode:
return self.array[i + 1]
raise KeyError(key)
- def assoc(self, shift, hash, key, val, added_leaf):
+ def assoc(self, shift, hash, key, val):
if hash == self.hash:
key_idx = self.find_index(key)
@@ -260,21 +257,20 @@ class CollisionNode:
new_array.append(key)
new_array.append(val)
new_node = CollisionNode(self.size + 2, hash, new_array)
- added_leaf[0] = True
- return new_node
+ return new_node, True
val_idx = key_idx + 1
if self.array[val_idx] is val:
- return self
+ return self, False
new_array = self.array.copy()
new_array[val_idx] = val
- return CollisionNode(self.size, hash, new_array)
+ return CollisionNode(self.size, hash, new_array), False
else:
new_node = BitmapNode(
2, map_bitpos(self.hash, shift), [None, self])
- return new_node.assoc(shift, hash, key, val, added_leaf)
+ return new_node.assoc(shift, hash, key, val)
def without(self, shift, hash, key):
if hash != self.hash:
@@ -375,16 +371,14 @@ class Map:
return True
def set(self, key, val):
- added = [False]
-
new_count = self.__count
- new_root = self.__root.assoc(0, map_hash(key), key, val, added)
+ new_root, added = self.__root.assoc(0, map_hash(key), key, val)
if new_root is self.__root:
- assert not added[0]
+ assert not added
return self
- if added[0]:
+ if added:
new_count += 1
m = Map.__new__(Map)
diff --git a/tests/test_map.py b/tests/test_map.py
index c64258d..660f742 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -455,7 +455,9 @@ class BaseMapTest:
h = h.set(D, 'd')
h = h.set(E, 'e')
+ self.assertEqual(len(h), 5)
h = h.set(C, 'c') # trigger branch in CollisionNode.assoc
+ self.assertEqual(len(h), 5)
orig_len = len(h)