aboutsummaryrefslogtreecommitdiff
path: root/immutables/map.py
diff options
context:
space:
mode:
Diffstat (limited to 'immutables/map.py')
-rw-r--r--immutables/map.py272
1 files changed, 213 insertions, 59 deletions
diff --git a/immutables/map.py b/immutables/map.py
index 81612bc..8c7d862 100644
--- a/immutables/map.py
+++ b/immutables/map.py
@@ -1,4 +1,5 @@
import collections.abc
+import itertools
import reprlib
import sys
@@ -6,6 +7,10 @@ import sys
__all__ = ('Map',)
+# Thread-safe counter.
+_mut_id = itertools.count(1).__next__
+
+
# Python version of _map.c. The topmost comment there explains
# all datastructures and algorithms.
# The code here follows C code closely on purpose to make
@@ -43,16 +48,17 @@ W_EMPTY, W_NEWNODE, W_NOT_FOUND = range(3)
class BitmapNode:
- def __init__(self, size, bitmap, array):
+ def __init__(self, size, bitmap, array, mutid):
self.size = size
self.bitmap = bitmap
assert isinstance(array, list) and len(array) == size
self.array = array
+ self.mutid = mutid
- def clone(self):
- return BitmapNode(self.size, self.bitmap, self.array.copy())
+ def clone(self, mutid):
+ return BitmapNode(self.size, self.bitmap, self.array.copy(), mutid)
- def assoc(self, shift, hash, key, val):
+ def assoc(self, shift, hash, key, val, mutid):
bit = map_bitpos(hash, shift)
idx = map_bitindex(self.bitmap, bit)
@@ -65,38 +71,53 @@ class BitmapNode:
if key_or_null is None:
sub_node, added = val_or_node.assoc(
- shift + 5, hash, key, val)
+ shift + 5, hash, key, val, mutid)
if val_or_node is sub_node:
- return self, False
+ return self, added
- ret = self.clone()
- ret.array[val_idx] = sub_node
- return ret, added
+ if mutid and mutid == self.mutid:
+ self.array[val_idx] = sub_node
+ return self, added
+ else:
+ ret = self.clone(mutid)
+ ret.array[val_idx] = sub_node
+ return ret, added
if key == key_or_null:
if val is val_or_node:
return self, False
- ret = self.clone()
- ret.array[val_idx] = val
- return ret, False
+ if mutid and mutid == self.mutid:
+ self.array[val_idx] = val
+ return self, False
+ else:
+ ret = self.clone(mutid)
+ ret.array[val_idx] = val
+ return ret, False
existing_key_hash = map_hash(key_or_null)
if existing_key_hash == hash:
sub_node = CollisionNode(
- 4, hash, [key_or_null, val_or_node, key, val])
+ 4, hash, [key_or_null, val_or_node, key, val], mutid)
else:
- sub_node = BitmapNode(0, 0, [])
+ sub_node = BitmapNode(0, 0, [], mutid)
sub_node, _ = sub_node.assoc(
shift + 5, existing_key_hash,
- key_or_null, val_or_node)
+ key_or_null, val_or_node,
+ mutid)
sub_node, _ = sub_node.assoc(
- shift + 5, hash, key, val)
+ shift + 5, hash, key, val,
+ mutid)
- ret = self.clone()
- ret.array[key_idx] = None
- ret.array[val_idx] = sub_node
- return ret, True
+ if mutid and mutid == self.mutid:
+ self.array[key_idx] = None
+ self.array[val_idx] = sub_node
+ return self, True
+ else:
+ ret = self.clone(mutid)
+ ret.array[key_idx] = None
+ ret.array[val_idx] = sub_node
+ return ret, True
else:
key_idx = 2 * idx
@@ -108,7 +129,15 @@ class BitmapNode:
new_array.append(key)
new_array.append(val)
new_array.extend(self.array[key_idx:])
- return BitmapNode(2 * (n + 1), self.bitmap | bit, new_array), True
+
+ if mutid and mutid == self.mutid:
+ self.size = 2 * (n + 1)
+ self.bitmap |= bit
+ self.array = new_array
+ return self, True
+ else:
+ return BitmapNode(
+ 2 * (n + 1), self.bitmap | bit, new_array, mutid), True
def find(self, shift, hash, key):
bit = map_bitpos(hash, shift)
@@ -131,7 +160,7 @@ class BitmapNode:
raise KeyError(key)
- def without(self, shift, hash, key):
+ def without(self, shift, hash, key, mutid):
bit = map_bitpos(hash, shift)
if not (self.bitmap & bit):
return W_NOT_FOUND, None
@@ -144,7 +173,7 @@ class BitmapNode:
val_or_node = self.array[val_idx]
if key_or_null is None:
- res, sub_node = val_or_node.without(shift + 5, hash, key)
+ res, sub_node = val_or_node.without(shift + 5, hash, key, mutid)
if res is W_EMPTY:
raise RuntimeError('unreachable code') # pragma: no cover
@@ -153,14 +182,24 @@ class BitmapNode:
if (type(sub_node) is BitmapNode and
sub_node.size == 2 and
sub_node.array[0] is not None):
- clone = self.clone()
- clone.array[key_idx] = sub_node.array[0]
- clone.array[val_idx] = sub_node.array[1]
- return W_NEWNODE, clone
- clone = self.clone()
- clone.array[val_idx] = sub_node
- return W_NEWNODE, clone
+ if mutid and mutid == self.mutid:
+ self.array[key_idx] = sub_node.array[0]
+ self.array[val_idx] = sub_node.array[1]
+ return W_NEWNODE, self
+ else:
+ clone = self.clone(mutid)
+ clone.array[key_idx] = sub_node.array[0]
+ clone.array[val_idx] = sub_node.array[1]
+ return W_NEWNODE, clone
+
+ if mutid and mutid == self.mutid:
+ self.array[val_idx] = sub_node
+ return W_NEWNODE, self
+ else:
+ clone = self.clone(mutid)
+ clone.array[val_idx] = sub_node
+ return W_NEWNODE, clone
else:
assert sub_node is None
@@ -173,9 +212,16 @@ class BitmapNode:
new_array = self.array[:key_idx]
new_array.extend(self.array[val_idx + 1:])
- new_node = BitmapNode(
- self.size - 2, self.bitmap & ~bit, new_array)
- return W_NEWNODE, new_node
+
+ if mutid and mutid == self.mutid:
+ self.size -= 2
+ self.bitmap &= ~bit
+ self.array = new_array
+ return W_NEWNODE, self
+ else:
+ new_node = BitmapNode(
+ self.size - 2, self.bitmap & ~bit, new_array, mutid)
+ return W_NEWNODE, new_node
else:
return W_NOT_FOUND, None
@@ -231,10 +277,11 @@ class BitmapNode:
class CollisionNode:
- def __init__(self, size, hash, array):
+ def __init__(self, size, hash, array, mutid):
self.size = size
self.hash = hash
self.array = array
+ self.mutid = mutid
def find_index(self, key):
for i in range(0, self.size, 2):
@@ -248,7 +295,7 @@ class CollisionNode:
return self.array[i + 1]
raise KeyError(key)
- def assoc(self, shift, hash, key, val):
+ def assoc(self, shift, hash, key, val, mutid):
if hash == self.hash:
key_idx = self.find_index(key)
@@ -256,23 +303,34 @@ class CollisionNode:
new_array = self.array.copy()
new_array.append(key)
new_array.append(val)
- new_node = CollisionNode(self.size + 2, hash, new_array)
- return new_node, True
+
+ if mutid and mutid == self.mutid:
+ self.size += 2
+ self.array = new_array
+ return self, True
+ else:
+ new_node = CollisionNode(
+ self.size + 2, hash, new_array, mutid)
+ return new_node, True
val_idx = key_idx + 1
if self.array[val_idx] is val:
return self, False
- new_array = self.array.copy()
- new_array[val_idx] = val
- return CollisionNode(self.size, hash, new_array), False
+ if mutid and mutid == self.mutid:
+ self.array[val_idx] = val
+ return self, False
+ else:
+ new_array = self.array.copy()
+ new_array[val_idx] = val
+ return CollisionNode(self.size, hash, new_array, mutid), False
else:
new_node = BitmapNode(
- 2, map_bitpos(self.hash, shift), [None, self])
- return new_node.assoc(shift, hash, key, val)
+ 2, map_bitpos(self.hash, shift), [None, self], mutid)
+ return new_node.assoc(shift, hash, key, val, mutid)
- def without(self, shift, hash, key):
+ def without(self, shift, hash, key, mutid):
if hash != self.hash:
return W_NOT_FOUND, None
@@ -292,13 +350,20 @@ class CollisionNode:
assert key_idx == 2
new_array = [self.array[0], self.array[1]]
- new_node = BitmapNode(2, map_bitpos(hash, shift), new_array)
+ new_node = BitmapNode(
+ 2, map_bitpos(hash, shift), new_array, mutid)
return W_NEWNODE, new_node
new_array = self.array[:key_idx]
new_array.extend(self.array[key_idx + 2:])
- new_node = CollisionNode(self.size - 2, self.hash, new_array)
- return W_NEWNODE, new_node
+ if mutid and mutid == self.mutid:
+ self.array = new_array
+ self.size -= 2
+ return W_NEWNODE, self
+ else:
+ new_node = CollisionNode(
+ self.size - 2, self.hash, new_array, mutid)
+ return W_NEWNODE, new_node
def keys(self):
for i in range(0, self.size, 2):
@@ -346,9 +411,17 @@ class Map:
def __init__(self):
self.__count = 0
- self.__root = BitmapNode(0, 0, [])
+ self.__root = BitmapNode(0, 0, [], 0)
self.__hash = -1
+ @classmethod
+ def _new(cls, count, root):
+ m = Map.__new__(Map)
+ m.__count = count
+ m.__root = root
+ m.__hash = -1
+ return m
+
def __len__(self):
return self.__count
@@ -370,9 +443,12 @@ class Map:
return True
+ def mutate(self):
+ return MapMutation(self.__count, self.__root)
+
def set(self, key, val):
new_count = self.__count
- new_root, added = self.__root.assoc(0, map_hash(key), key, val)
+ new_root, added = self.__root.assoc(0, map_hash(key), key, val, 0)
if new_root is self.__root:
assert not added
@@ -381,24 +457,16 @@ class Map:
if added:
new_count += 1
- m = Map.__new__(Map)
- m.__count = new_count
- m.__root = new_root
- m.__hash = -1
- return m
+ return Map._new(new_count, new_root)
def delete(self, key):
- res, node = self.__root.without(0, map_hash(key), key)
+ res, node = self.__root.without(0, map_hash(key), key, 0)
if res is W_EMPTY:
return Map()
elif res is W_NOT_FOUND:
raise KeyError(key)
else:
- m = Map.__new__(Map)
- m.__count = self.__count - 1
- m.__root = node
- m.__hash = -1
- return m
+ return Map._new(self.__count - 1, node)
def get(self, key, default=None):
try:
@@ -473,4 +541,90 @@ class Map:
return '\n'.join(buf)
+class MapMutation:
+
+ def __init__(self, count, root):
+ self.__count = count
+ self.__root = root
+ self.__mutid = _mut_id()
+
+ def set(self, key, val):
+ if self.__mutid == 0:
+ raise ValueError(f'mutation {self!r} has been finalized')
+
+ self.__root, added = self.__root.assoc(
+ 0, map_hash(key), key, val, self.__mutid)
+
+ if added:
+ self.__count += 1
+
+ def delete(self, key):
+ if self.__mutid == 0:
+ raise ValueError(f'mutation {self!r} has been finalized')
+
+ res, new_root = self.__root.without(
+ 0, map_hash(key), key, self.__mutid)
+ if res is W_EMPTY:
+ self.__count = 0
+ self.__root = BitmapNode(0, 0, [], self.__mutid)
+ elif res is W_NOT_FOUND:
+ raise KeyError(key)
+ else:
+ self.__root = new_root
+ self.__count -= 1
+
+ def get(self, key, default=None):
+ try:
+ return self.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return default
+
+ def __getitem__(self, key):
+ return self.__root.find(0, map_hash(key), key)
+
+ def __contains__(self, key):
+ try:
+ self.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return False
+ else:
+ return True
+
+ def finalize(self):
+ self.__mutid = 0
+ return Map._new(self.__count, self.__root)
+
+ @reprlib.recursive_repr("{...}")
+ def __repr__(self):
+ items = []
+ for key, val in self.__root.items():
+ items.append("{!r}: {!r}".format(key, val))
+ return '<immutables.MapMutation({{{}}}) at 0x{:0x}>'.format(
+ ', '.join(items), id(self))
+
+ def __len__(self):
+ return self.__count
+
+ def __hash__(self):
+ raise TypeError(f'unhashable type: {type(self).__name__}')
+
+ def __eq__(self, other):
+ if not isinstance(other, MapMutation):
+ return NotImplemented
+
+ if len(self) != len(other):
+ return False
+
+ for key, val in self.__root.items():
+ try:
+ oval = other.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return False
+ else:
+ if oval != val:
+ return False
+
+ return True
+
+
collections.abc.Mapping.register(Map)