diff options
Diffstat (limited to 'immutables/map.py')
-rw-r--r-- | immutables/map.py | 272 |
1 files changed, 213 insertions, 59 deletions
diff --git a/immutables/map.py b/immutables/map.py index 81612bc..8c7d862 100644 --- a/immutables/map.py +++ b/immutables/map.py @@ -1,4 +1,5 @@ import collections.abc +import itertools import reprlib import sys @@ -6,6 +7,10 @@ import sys __all__ = ('Map',) +# Thread-safe counter. +_mut_id = itertools.count(1).__next__ + + # Python version of _map.c. The topmost comment there explains # all datastructures and algorithms. # The code here follows C code closely on purpose to make @@ -43,16 +48,17 @@ W_EMPTY, W_NEWNODE, W_NOT_FOUND = range(3) class BitmapNode: - def __init__(self, size, bitmap, array): + def __init__(self, size, bitmap, array, mutid): self.size = size self.bitmap = bitmap assert isinstance(array, list) and len(array) == size self.array = array + self.mutid = mutid - def clone(self): - return BitmapNode(self.size, self.bitmap, self.array.copy()) + def clone(self, mutid): + return BitmapNode(self.size, self.bitmap, self.array.copy(), mutid) - def assoc(self, shift, hash, key, val): + def assoc(self, shift, hash, key, val, mutid): bit = map_bitpos(hash, shift) idx = map_bitindex(self.bitmap, bit) @@ -65,38 +71,53 @@ class BitmapNode: if key_or_null is None: sub_node, added = val_or_node.assoc( - shift + 5, hash, key, val) + shift + 5, hash, key, val, mutid) if val_or_node is sub_node: - return self, False + return self, added - ret = self.clone() - ret.array[val_idx] = sub_node - return ret, added + if mutid and mutid == self.mutid: + self.array[val_idx] = sub_node + return self, added + else: + ret = self.clone(mutid) + ret.array[val_idx] = sub_node + return ret, added if key == key_or_null: if val is val_or_node: return self, False - ret = self.clone() - ret.array[val_idx] = val - return ret, False + if mutid and mutid == self.mutid: + self.array[val_idx] = val + return self, False + else: + ret = self.clone(mutid) + ret.array[val_idx] = val + return ret, False existing_key_hash = map_hash(key_or_null) if existing_key_hash == hash: sub_node = CollisionNode( - 4, hash, [key_or_null, val_or_node, key, val]) + 4, hash, [key_or_null, val_or_node, key, val], mutid) else: - sub_node = BitmapNode(0, 0, []) + sub_node = BitmapNode(0, 0, [], mutid) sub_node, _ = sub_node.assoc( shift + 5, existing_key_hash, - key_or_null, val_or_node) + key_or_null, val_or_node, + mutid) sub_node, _ = sub_node.assoc( - shift + 5, hash, key, val) + shift + 5, hash, key, val, + mutid) - ret = self.clone() - ret.array[key_idx] = None - ret.array[val_idx] = sub_node - return ret, True + if mutid and mutid == self.mutid: + self.array[key_idx] = None + self.array[val_idx] = sub_node + return self, True + else: + ret = self.clone(mutid) + ret.array[key_idx] = None + ret.array[val_idx] = sub_node + return ret, True else: key_idx = 2 * idx @@ -108,7 +129,15 @@ class BitmapNode: new_array.append(key) new_array.append(val) new_array.extend(self.array[key_idx:]) - return BitmapNode(2 * (n + 1), self.bitmap | bit, new_array), True + + if mutid and mutid == self.mutid: + self.size = 2 * (n + 1) + self.bitmap |= bit + self.array = new_array + return self, True + else: + return BitmapNode( + 2 * (n + 1), self.bitmap | bit, new_array, mutid), True def find(self, shift, hash, key): bit = map_bitpos(hash, shift) @@ -131,7 +160,7 @@ class BitmapNode: raise KeyError(key) - def without(self, shift, hash, key): + def without(self, shift, hash, key, mutid): bit = map_bitpos(hash, shift) if not (self.bitmap & bit): return W_NOT_FOUND, None @@ -144,7 +173,7 @@ class BitmapNode: val_or_node = self.array[val_idx] if key_or_null is None: - res, sub_node = val_or_node.without(shift + 5, hash, key) + res, sub_node = val_or_node.without(shift + 5, hash, key, mutid) if res is W_EMPTY: raise RuntimeError('unreachable code') # pragma: no cover @@ -153,14 +182,24 @@ class BitmapNode: if (type(sub_node) is BitmapNode and sub_node.size == 2 and sub_node.array[0] is not None): - clone = self.clone() - clone.array[key_idx] = sub_node.array[0] - clone.array[val_idx] = sub_node.array[1] - return W_NEWNODE, clone - clone = self.clone() - clone.array[val_idx] = sub_node - return W_NEWNODE, clone + if mutid and mutid == self.mutid: + self.array[key_idx] = sub_node.array[0] + self.array[val_idx] = sub_node.array[1] + return W_NEWNODE, self + else: + clone = self.clone(mutid) + clone.array[key_idx] = sub_node.array[0] + clone.array[val_idx] = sub_node.array[1] + return W_NEWNODE, clone + + if mutid and mutid == self.mutid: + self.array[val_idx] = sub_node + return W_NEWNODE, self + else: + clone = self.clone(mutid) + clone.array[val_idx] = sub_node + return W_NEWNODE, clone else: assert sub_node is None @@ -173,9 +212,16 @@ class BitmapNode: new_array = self.array[:key_idx] new_array.extend(self.array[val_idx + 1:]) - new_node = BitmapNode( - self.size - 2, self.bitmap & ~bit, new_array) - return W_NEWNODE, new_node + + if mutid and mutid == self.mutid: + self.size -= 2 + self.bitmap &= ~bit + self.array = new_array + return W_NEWNODE, self + else: + new_node = BitmapNode( + self.size - 2, self.bitmap & ~bit, new_array, mutid) + return W_NEWNODE, new_node else: return W_NOT_FOUND, None @@ -231,10 +277,11 @@ class BitmapNode: class CollisionNode: - def __init__(self, size, hash, array): + def __init__(self, size, hash, array, mutid): self.size = size self.hash = hash self.array = array + self.mutid = mutid def find_index(self, key): for i in range(0, self.size, 2): @@ -248,7 +295,7 @@ class CollisionNode: return self.array[i + 1] raise KeyError(key) - def assoc(self, shift, hash, key, val): + def assoc(self, shift, hash, key, val, mutid): if hash == self.hash: key_idx = self.find_index(key) @@ -256,23 +303,34 @@ class CollisionNode: new_array = self.array.copy() new_array.append(key) new_array.append(val) - new_node = CollisionNode(self.size + 2, hash, new_array) - return new_node, True + + if mutid and mutid == self.mutid: + self.size += 2 + self.array = new_array + return self, True + else: + new_node = CollisionNode( + self.size + 2, hash, new_array, mutid) + return new_node, True val_idx = key_idx + 1 if self.array[val_idx] is val: return self, False - new_array = self.array.copy() - new_array[val_idx] = val - return CollisionNode(self.size, hash, new_array), False + if mutid and mutid == self.mutid: + self.array[val_idx] = val + return self, False + else: + new_array = self.array.copy() + new_array[val_idx] = val + return CollisionNode(self.size, hash, new_array, mutid), False else: new_node = BitmapNode( - 2, map_bitpos(self.hash, shift), [None, self]) - return new_node.assoc(shift, hash, key, val) + 2, map_bitpos(self.hash, shift), [None, self], mutid) + return new_node.assoc(shift, hash, key, val, mutid) - def without(self, shift, hash, key): + def without(self, shift, hash, key, mutid): if hash != self.hash: return W_NOT_FOUND, None @@ -292,13 +350,20 @@ class CollisionNode: assert key_idx == 2 new_array = [self.array[0], self.array[1]] - new_node = BitmapNode(2, map_bitpos(hash, shift), new_array) + new_node = BitmapNode( + 2, map_bitpos(hash, shift), new_array, mutid) return W_NEWNODE, new_node new_array = self.array[:key_idx] new_array.extend(self.array[key_idx + 2:]) - new_node = CollisionNode(self.size - 2, self.hash, new_array) - return W_NEWNODE, new_node + if mutid and mutid == self.mutid: + self.array = new_array + self.size -= 2 + return W_NEWNODE, self + else: + new_node = CollisionNode( + self.size - 2, self.hash, new_array, mutid) + return W_NEWNODE, new_node def keys(self): for i in range(0, self.size, 2): @@ -346,9 +411,17 @@ class Map: def __init__(self): self.__count = 0 - self.__root = BitmapNode(0, 0, []) + self.__root = BitmapNode(0, 0, [], 0) self.__hash = -1 + @classmethod + def _new(cls, count, root): + m = Map.__new__(Map) + m.__count = count + m.__root = root + m.__hash = -1 + return m + def __len__(self): return self.__count @@ -370,9 +443,12 @@ class Map: return True + def mutate(self): + return MapMutation(self.__count, self.__root) + def set(self, key, val): new_count = self.__count - new_root, added = self.__root.assoc(0, map_hash(key), key, val) + new_root, added = self.__root.assoc(0, map_hash(key), key, val, 0) if new_root is self.__root: assert not added @@ -381,24 +457,16 @@ class Map: if added: new_count += 1 - m = Map.__new__(Map) - m.__count = new_count - m.__root = new_root - m.__hash = -1 - return m + return Map._new(new_count, new_root) def delete(self, key): - res, node = self.__root.without(0, map_hash(key), key) + res, node = self.__root.without(0, map_hash(key), key, 0) if res is W_EMPTY: return Map() elif res is W_NOT_FOUND: raise KeyError(key) else: - m = Map.__new__(Map) - m.__count = self.__count - 1 - m.__root = node - m.__hash = -1 - return m + return Map._new(self.__count - 1, node) def get(self, key, default=None): try: @@ -473,4 +541,90 @@ class Map: return '\n'.join(buf) +class MapMutation: + + def __init__(self, count, root): + self.__count = count + self.__root = root + self.__mutid = _mut_id() + + def set(self, key, val): + if self.__mutid == 0: + raise ValueError(f'mutation {self!r} has been finalized') + + self.__root, added = self.__root.assoc( + 0, map_hash(key), key, val, self.__mutid) + + if added: + self.__count += 1 + + def delete(self, key): + if self.__mutid == 0: + raise ValueError(f'mutation {self!r} has been finalized') + + res, new_root = self.__root.without( + 0, map_hash(key), key, self.__mutid) + if res is W_EMPTY: + self.__count = 0 + self.__root = BitmapNode(0, 0, [], self.__mutid) + elif res is W_NOT_FOUND: + raise KeyError(key) + else: + self.__root = new_root + self.__count -= 1 + + def get(self, key, default=None): + try: + return self.__root.find(0, map_hash(key), key) + except KeyError: + return default + + def __getitem__(self, key): + return self.__root.find(0, map_hash(key), key) + + def __contains__(self, key): + try: + self.__root.find(0, map_hash(key), key) + except KeyError: + return False + else: + return True + + def finalize(self): + self.__mutid = 0 + return Map._new(self.__count, self.__root) + + @reprlib.recursive_repr("{...}") + def __repr__(self): + items = [] + for key, val in self.__root.items(): + items.append("{!r}: {!r}".format(key, val)) + return '<immutables.MapMutation({{{}}}) at 0x{:0x}>'.format( + ', '.join(items), id(self)) + + def __len__(self): + return self.__count + + def __hash__(self): + raise TypeError(f'unhashable type: {type(self).__name__}') + + def __eq__(self, other): + if not isinstance(other, MapMutation): + return NotImplemented + + if len(self) != len(other): + return False + + for key, val in self.__root.items(): + try: + oval = other.__root.find(0, map_hash(key), key) + except KeyError: + return False + else: + if oval != val: + return False + + return True + + collections.abc.Mapping.register(Map) |