diff options
Diffstat (limited to 'immutables/map.py')
-rw-r--r-- | immutables/map.py | 855 |
1 files changed, 855 insertions, 0 deletions
diff --git a/immutables/map.py b/immutables/map.py new file mode 100644 index 0000000..0ad2858 --- /dev/null +++ b/immutables/map.py @@ -0,0 +1,855 @@ +import collections.abc +import itertools +import reprlib +import sys + + +__all__ = ('Map',) + + +# Thread-safe counter. +_mut_id = itertools.count(1).__next__ + + +# Python version of _map.c. The topmost comment there explains +# all datastructures and algorithms. +# The code here follows C code closely on purpose to make +# debugging and testing easier. + + +def map_hash(o): + x = hash(o) + if sys.hash_info.width > 32: + return (x & 0xffffffff) ^ ((x >> 32) & 0xffffffff) + else: + return x + + +def map_mask(hash, shift): + return (hash >> shift) & 0x01f + + +def map_bitpos(hash, shift): + return 1 << map_mask(hash, shift) + + +def map_bitcount(v): + v = v - ((v >> 1) & 0x55555555) + v = (v & 0x33333333) + ((v >> 2) & 0x33333333) + v = (v & 0x0F0F0F0F) + ((v >> 4) & 0x0F0F0F0F) + v = v + (v >> 8) + v = (v + (v >> 16)) & 0x3F + return v + + +def map_bitindex(bitmap, bit): + return map_bitcount(bitmap & (bit - 1)) + + +W_EMPTY, W_NEWNODE, W_NOT_FOUND = range(3) +void = object() + + +class _Unhashable: + __slots__ = () + __hash__ = None + + +_NULL = _Unhashable() +del _Unhashable + + +class BitmapNode: + + def __init__(self, size, bitmap, array, mutid): + self.size = size + self.bitmap = bitmap + assert isinstance(array, list) and len(array) == size + self.array = array + self.mutid = mutid + + def clone(self, mutid): + return BitmapNode(self.size, self.bitmap, self.array.copy(), mutid) + + def assoc(self, shift, hash, key, val, mutid): + bit = map_bitpos(hash, shift) + idx = map_bitindex(self.bitmap, bit) + + if self.bitmap & bit: + key_idx = 2 * idx + val_idx = key_idx + 1 + + key_or_null = self.array[key_idx] + val_or_node = self.array[val_idx] + + if key_or_null is _NULL: + sub_node, added = val_or_node.assoc( + shift + 5, hash, key, val, mutid) + if val_or_node is sub_node: + return self, added + + if mutid and mutid == self.mutid: + self.array[val_idx] = sub_node + return self, added + else: + ret = self.clone(mutid) + ret.array[val_idx] = sub_node + return ret, added + + if key == key_or_null: + if val is val_or_node: + return self, False + + if mutid and mutid == self.mutid: + self.array[val_idx] = val + return self, False + else: + ret = self.clone(mutid) + ret.array[val_idx] = val + return ret, False + + existing_key_hash = map_hash(key_or_null) + if existing_key_hash == hash: + sub_node = CollisionNode( + 4, hash, [key_or_null, val_or_node, key, val], mutid) + else: + sub_node = BitmapNode(0, 0, [], mutid) + sub_node, _ = sub_node.assoc( + shift + 5, existing_key_hash, + key_or_null, val_or_node, + mutid) + sub_node, _ = sub_node.assoc( + shift + 5, hash, key, val, + mutid) + + if mutid and mutid == self.mutid: + self.array[key_idx] = _NULL + self.array[val_idx] = sub_node + return self, True + else: + ret = self.clone(mutid) + ret.array[key_idx] = _NULL + ret.array[val_idx] = sub_node + return ret, True + + else: + key_idx = 2 * idx + val_idx = key_idx + 1 + + n = map_bitcount(self.bitmap) + + new_array = self.array[:key_idx] + new_array.append(key) + new_array.append(val) + new_array.extend(self.array[key_idx:]) + + if mutid and mutid == self.mutid: + self.size = 2 * (n + 1) + self.bitmap |= bit + self.array = new_array + return self, True + else: + return BitmapNode( + 2 * (n + 1), self.bitmap | bit, new_array, mutid), True + + def find(self, shift, hash, key): + bit = map_bitpos(hash, shift) + + if not (self.bitmap & bit): + raise KeyError + + idx = map_bitindex(self.bitmap, bit) + key_idx = idx * 2 + val_idx = key_idx + 1 + + key_or_null = self.array[key_idx] + val_or_node = self.array[val_idx] + + if key_or_null is _NULL: + return val_or_node.find(shift + 5, hash, key) + + if key == key_or_null: + return val_or_node + + raise KeyError(key) + + def without(self, shift, hash, key, mutid): + bit = map_bitpos(hash, shift) + if not (self.bitmap & bit): + return W_NOT_FOUND, None + + idx = map_bitindex(self.bitmap, bit) + key_idx = 2 * idx + val_idx = key_idx + 1 + + key_or_null = self.array[key_idx] + val_or_node = self.array[val_idx] + + if key_or_null is _NULL: + res, sub_node = val_or_node.without(shift + 5, hash, key, mutid) + + if res is W_EMPTY: + raise RuntimeError('unreachable code') # pragma: no cover + + elif res is W_NEWNODE: + if (type(sub_node) is BitmapNode and + sub_node.size == 2 and + sub_node.array[0] is not _NULL): + + if mutid and mutid == self.mutid: + self.array[key_idx] = sub_node.array[0] + self.array[val_idx] = sub_node.array[1] + return W_NEWNODE, self + else: + clone = self.clone(mutid) + clone.array[key_idx] = sub_node.array[0] + clone.array[val_idx] = sub_node.array[1] + return W_NEWNODE, clone + + if mutid and mutid == self.mutid: + self.array[val_idx] = sub_node + return W_NEWNODE, self + else: + clone = self.clone(mutid) + clone.array[val_idx] = sub_node + return W_NEWNODE, clone + + else: + assert sub_node is None + return res, None + + else: + if key == key_or_null: + if self.size == 2: + return W_EMPTY, None + + new_array = self.array[:key_idx] + new_array.extend(self.array[val_idx + 1:]) + + if mutid and mutid == self.mutid: + self.size -= 2 + self.bitmap &= ~bit + self.array = new_array + return W_NEWNODE, self + else: + new_node = BitmapNode( + self.size - 2, self.bitmap & ~bit, new_array, mutid) + return W_NEWNODE, new_node + + else: + return W_NOT_FOUND, None + + def keys(self): + for i in range(0, self.size, 2): + key_or_null = self.array[i] + + if key_or_null is _NULL: + val_or_node = self.array[i + 1] + yield from val_or_node.keys() + else: + yield key_or_null + + def values(self): + for i in range(0, self.size, 2): + key_or_null = self.array[i] + val_or_node = self.array[i + 1] + + if key_or_null is _NULL: + yield from val_or_node.values() + else: + yield val_or_node + + def items(self): + for i in range(0, self.size, 2): + key_or_null = self.array[i] + val_or_node = self.array[i + 1] + + if key_or_null is _NULL: + yield from val_or_node.items() + else: + yield key_or_null, val_or_node + + def dump(self, buf, level): # pragma: no cover + buf.append( + ' ' * (level + 1) + + 'BitmapNode(size={} count={} bitmap={} id={:0x}):'.format( + self.size, self.size / 2, bin(self.bitmap), id(self))) + + for i in range(0, self.size, 2): + key_or_null = self.array[i] + val_or_node = self.array[i + 1] + + pad = ' ' * (level + 2) + + if key_or_null is _NULL: + buf.append(pad + 'NULL:') + val_or_node.dump(buf, level + 2) + else: + buf.append(pad + '{!r}: {!r}'.format(key_or_null, val_or_node)) + + +class CollisionNode: + + def __init__(self, size, hash, array, mutid): + self.size = size + self.hash = hash + self.array = array + self.mutid = mutid + + def find_index(self, key): + for i in range(0, self.size, 2): + if self.array[i] == key: + return i + return -1 + + def find(self, shift, hash, key): + for i in range(0, self.size, 2): + if self.array[i] == key: + return self.array[i + 1] + raise KeyError(key) + + def assoc(self, shift, hash, key, val, mutid): + if hash == self.hash: + key_idx = self.find_index(key) + + if key_idx == -1: + new_array = self.array.copy() + new_array.append(key) + new_array.append(val) + + if mutid and mutid == self.mutid: + self.size += 2 + self.array = new_array + return self, True + else: + new_node = CollisionNode( + self.size + 2, hash, new_array, mutid) + return new_node, True + + val_idx = key_idx + 1 + if self.array[val_idx] is val: + return self, False + + if mutid and mutid == self.mutid: + self.array[val_idx] = val + return self, False + else: + new_array = self.array.copy() + new_array[val_idx] = val + return CollisionNode(self.size, hash, new_array, mutid), False + + else: + new_node = BitmapNode( + 2, map_bitpos(self.hash, shift), [_NULL, self], mutid) + return new_node.assoc(shift, hash, key, val, mutid) + + def without(self, shift, hash, key, mutid): + if hash != self.hash: + return W_NOT_FOUND, None + + key_idx = self.find_index(key) + if key_idx == -1: + return W_NOT_FOUND, None + + new_size = self.size - 2 + if new_size == 0: + # Shouldn't be ever reachable + return W_EMPTY, None # pragma: no cover + + if new_size == 2: + if key_idx == 0: + new_array = [self.array[2], self.array[3]] + else: + assert key_idx == 2 + new_array = [self.array[0], self.array[1]] + + new_node = BitmapNode( + 2, map_bitpos(hash, shift), new_array, mutid) + return W_NEWNODE, new_node + + new_array = self.array[:key_idx] + new_array.extend(self.array[key_idx + 2:]) + if mutid and mutid == self.mutid: + self.array = new_array + self.size -= 2 + return W_NEWNODE, self + else: + new_node = CollisionNode( + self.size - 2, self.hash, new_array, mutid) + return W_NEWNODE, new_node + + def keys(self): + for i in range(0, self.size, 2): + yield self.array[i] + + def values(self): + for i in range(1, self.size, 2): + yield self.array[i] + + def items(self): + for i in range(0, self.size, 2): + yield self.array[i], self.array[i + 1] + + def dump(self, buf, level): # pragma: no cover + pad = ' ' * (level + 1) + buf.append( + pad + 'CollisionNode(size={} id={:0x}):'.format( + self.size, id(self))) + + pad = ' ' * (level + 2) + for i in range(0, self.size, 2): + key = self.array[i] + val = self.array[i + 1] + + buf.append('{}{!r}: {!r}'.format(pad, key, val)) + + +class MapKeys: + + def __init__(self, c, m): + self.__count = c + self.__root = m + + def __len__(self): + return self.__count + + def __iter__(self): + return iter(self.__root.keys()) + + +class MapValues: + + def __init__(self, c, m): + self.__count = c + self.__root = m + + def __len__(self): + return self.__count + + def __iter__(self): + return iter(self.__root.values()) + + +class MapItems: + + def __init__(self, c, m): + self.__count = c + self.__root = m + + def __len__(self): + return self.__count + + def __iter__(self): + return iter(self.__root.items()) + + +class Map: + + def __init__(self, *args, **kw): + if not args: + col = None + elif len(args) == 1: + col = args[0] + else: + raise TypeError( + "immutables.Map expected at most 1 arguments, " + "got {}".format(len(args)) + ) + + self.__count = 0 + self.__root = BitmapNode(0, 0, [], 0) + self.__hash = -1 + + if isinstance(col, Map): + self.__count = col.__count + self.__root = col.__root + self.__hash = col.__hash + col = None + elif isinstance(col, MapMutation): + raise TypeError('cannot create Maps from MapMutations') + + if col or kw: + init = self.update(col, **kw) + self.__count = init.__count + self.__root = init.__root + + @classmethod + def _new(cls, count, root): + m = Map.__new__(Map) + m.__count = count + m.__root = root + m.__hash = -1 + return m + + def __reduce__(self): + return (type(self), (dict(self.items()),)) + + def __len__(self): + return self.__count + + def __eq__(self, other): + if not isinstance(other, Map): + return NotImplemented + + if len(self) != len(other): + return False + + for key, val in self.__root.items(): + try: + oval = other.__root.find(0, map_hash(key), key) + except KeyError: + return False + else: + if oval != val: + return False + + return True + + def update(self, *args, **kw): + if not args: + col = None + elif len(args) == 1: + col = args[0] + else: + raise TypeError( + "update expected at most 1 arguments, got {}".format(len(args)) + ) + + it = None + + if col is not None: + if hasattr(col, 'items'): + it = iter(col.items()) + else: + it = iter(col) + + if it is not None: + if kw: + it = iter(itertools.chain(it, kw.items())) + else: + if kw: + it = iter(kw.items()) + + if it is None: + + return self + + mutid = _mut_id() + root = self.__root + count = self.__count + + i = 0 + while True: + try: + tup = next(it) + except StopIteration: + break + + try: + tup = tuple(tup) + except TypeError: + raise TypeError( + 'cannot convert map update ' + 'sequence element #{} to a sequence'.format(i)) from None + key, val, *r = tup + if r: + raise ValueError( + 'map update sequence element #{} has length ' + '{}; 2 is required'.format(i, len(r) + 2)) + + root, added = root.assoc(0, map_hash(key), key, val, mutid) + if added: + count += 1 + + i += 1 + + return Map._new(count, root) + + def mutate(self): + return MapMutation(self.__count, self.__root) + + def set(self, key, val): + new_count = self.__count + new_root, added = self.__root.assoc(0, map_hash(key), key, val, 0) + + if new_root is self.__root: + assert not added + return self + + if added: + new_count += 1 + + return Map._new(new_count, new_root) + + def delete(self, key): + res, node = self.__root.without(0, map_hash(key), key, 0) + if res is W_EMPTY: + return Map() + elif res is W_NOT_FOUND: + raise KeyError(key) + else: + return Map._new(self.__count - 1, node) + + def get(self, key, default=None): + try: + return self.__root.find(0, map_hash(key), key) + except KeyError: + return default + + def __getitem__(self, key): + return self.__root.find(0, map_hash(key), key) + + def __contains__(self, key): + try: + self.__root.find(0, map_hash(key), key) + except KeyError: + return False + else: + return True + + def __iter__(self): + yield from self.__root.keys() + + def keys(self): + return MapKeys(self.__count, self.__root) + + def values(self): + return MapValues(self.__count, self.__root) + + def items(self): + return MapItems(self.__count, self.__root) + + def __hash__(self): + if self.__hash != -1: + return self.__hash + + MAX = sys.maxsize + MASK = 2 * MAX + 1 + + h = 1927868237 * (self.__count * 2 + 1) + h &= MASK + + for key, value in self.__root.items(): + hx = hash(key) + h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167 + h &= MASK + + hx = hash(value) + h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167 + h &= MASK + + h = h * 69069 + 907133923 + h &= MASK + + if h > MAX: + h -= MASK + 1 # pragma: no cover + if h == -1: + h = 590923713 # pragma: no cover + + self.__hash = h + return h + + @reprlib.recursive_repr("{...}") + def __repr__(self): + items = [] + for key, val in self.items(): + items.append("{!r}: {!r}".format(key, val)) + return 'immutables.Map({{{}}})'.format(', '.join(items)) + + def __dump__(self): # pragma: no cover + buf = [] + self.__root.dump(buf, 0) + return '\n'.join(buf) + + def __class_getitem__(cls, item): + return cls + + +class MapMutation: + + def __init__(self, count, root): + self.__count = count + self.__root = root + self.__mutid = _mut_id() + + def set(self, key, val): + self[key] = val + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.finish() + return False + + def __iter__(self): + raise TypeError('{} is not iterable'.format(type(self))) + + def __delitem__(self, key): + if self.__mutid == 0: + raise ValueError('mutation {!r} has been finished'.format(self)) + + res, new_root = self.__root.without( + 0, map_hash(key), key, self.__mutid) + if res is W_EMPTY: + self.__count = 0 + self.__root = BitmapNode(0, 0, [], self.__mutid) + elif res is W_NOT_FOUND: + raise KeyError(key) + else: + self.__root = new_root + self.__count -= 1 + + def __setitem__(self, key, val): + if self.__mutid == 0: + raise ValueError('mutation {!r} has been finished'.format(self)) + + self.__root, added = self.__root.assoc( + 0, map_hash(key), key, val, self.__mutid) + + if added: + self.__count += 1 + + def pop(self, key, *args): + if self.__mutid == 0: + raise ValueError('mutation {!r} has been finished'.format(self)) + + if len(args) > 1: + raise TypeError( + 'pop() accepts 1 to 2 positional arguments, ' + 'got {}'.format(len(args) + 1)) + elif len(args) == 1: + default = args[0] + else: + default = void + + val = self.get(key, default) + + try: + del self[key] + except KeyError: + if val is void: + raise + return val + else: + assert val is not void + return val + + def get(self, key, default=None): + try: + return self.__root.find(0, map_hash(key), key) + except KeyError: + return default + + def __getitem__(self, key): + return self.__root.find(0, map_hash(key), key) + + def __contains__(self, key): + try: + self.__root.find(0, map_hash(key), key) + except KeyError: + return False + else: + return True + + def update(self, *args, **kw): + if not args: + col = None + elif len(args) == 1: + col = args[0] + else: + raise TypeError( + "update expected at most 1 arguments, got {}".format(len(args)) + ) + + if self.__mutid == 0: + raise ValueError('mutation {!r} has been finished'.format(self)) + + it = None + if col is not None: + if hasattr(col, 'items'): + it = iter(col.items()) + else: + it = iter(col) + + if it is not None: + if kw: + it = iter(itertools.chain(it, kw.items())) + else: + if kw: + it = iter(kw.items()) + + if it is None: + return + + root = self.__root + count = self.__count + + i = 0 + while True: + try: + tup = next(it) + except StopIteration: + break + + try: + tup = tuple(tup) + except TypeError: + raise TypeError( + 'cannot convert map update ' + 'sequence element #{} to a sequence'.format(i)) from None + key, val, *r = tup + if r: + raise ValueError( + 'map update sequence element #{} has length ' + '{}; 2 is required'.format(i, len(r) + 2)) + + root, added = root.assoc(0, map_hash(key), key, val, self.__mutid) + if added: + count += 1 + + i += 1 + + self.__root = root + self.__count = count + + def finish(self): + self.__mutid = 0 + return Map._new(self.__count, self.__root) + + @reprlib.recursive_repr("{...}") + def __repr__(self): + items = [] + for key, val in self.__root.items(): + items.append("{!r}: {!r}".format(key, val)) + return 'immutables.MapMutation({{{}}})'.format(', '.join(items)) + + def __len__(self): + return self.__count + + def __reduce__(self): + raise TypeError("can't pickle {} objects".format(type(self).__name__)) + + def __hash__(self): + raise TypeError('unhashable type: {}'.format(type(self).__name__)) + + def __eq__(self, other): + if not isinstance(other, MapMutation): + return NotImplemented + + if len(self) != len(other): + return False + + for key, val in self.__root.items(): + try: + oval = other.__root.find(0, map_hash(key), key) + except KeyError: + return False + else: + if oval != val: + return False + + return True + + +collections.abc.Mapping.register(Map) |