aboutsummaryrefslogtreecommitdiff
path: root/immutables/map.py
diff options
context:
space:
mode:
Diffstat (limited to 'immutables/map.py')
-rw-r--r--immutables/map.py855
1 files changed, 855 insertions, 0 deletions
diff --git a/immutables/map.py b/immutables/map.py
new file mode 100644
index 0000000..0ad2858
--- /dev/null
+++ b/immutables/map.py
@@ -0,0 +1,855 @@
+import collections.abc
+import itertools
+import reprlib
+import sys
+
+
+__all__ = ('Map',)
+
+
+# Thread-safe counter.
+_mut_id = itertools.count(1).__next__
+
+
+# Python version of _map.c. The topmost comment there explains
+# all datastructures and algorithms.
+# The code here follows C code closely on purpose to make
+# debugging and testing easier.
+
+
+def map_hash(o):
+ x = hash(o)
+ if sys.hash_info.width > 32:
+ return (x & 0xffffffff) ^ ((x >> 32) & 0xffffffff)
+ else:
+ return x
+
+
+def map_mask(hash, shift):
+ return (hash >> shift) & 0x01f
+
+
+def map_bitpos(hash, shift):
+ return 1 << map_mask(hash, shift)
+
+
+def map_bitcount(v):
+ v = v - ((v >> 1) & 0x55555555)
+ v = (v & 0x33333333) + ((v >> 2) & 0x33333333)
+ v = (v & 0x0F0F0F0F) + ((v >> 4) & 0x0F0F0F0F)
+ v = v + (v >> 8)
+ v = (v + (v >> 16)) & 0x3F
+ return v
+
+
+def map_bitindex(bitmap, bit):
+ return map_bitcount(bitmap & (bit - 1))
+
+
+W_EMPTY, W_NEWNODE, W_NOT_FOUND = range(3)
+void = object()
+
+
+class _Unhashable:
+ __slots__ = ()
+ __hash__ = None
+
+
+_NULL = _Unhashable()
+del _Unhashable
+
+
+class BitmapNode:
+
+ def __init__(self, size, bitmap, array, mutid):
+ self.size = size
+ self.bitmap = bitmap
+ assert isinstance(array, list) and len(array) == size
+ self.array = array
+ self.mutid = mutid
+
+ def clone(self, mutid):
+ return BitmapNode(self.size, self.bitmap, self.array.copy(), mutid)
+
+ def assoc(self, shift, hash, key, val, mutid):
+ bit = map_bitpos(hash, shift)
+ idx = map_bitindex(self.bitmap, bit)
+
+ if self.bitmap & bit:
+ key_idx = 2 * idx
+ val_idx = key_idx + 1
+
+ key_or_null = self.array[key_idx]
+ val_or_node = self.array[val_idx]
+
+ if key_or_null is _NULL:
+ sub_node, added = val_or_node.assoc(
+ shift + 5, hash, key, val, mutid)
+ if val_or_node is sub_node:
+ return self, added
+
+ if mutid and mutid == self.mutid:
+ self.array[val_idx] = sub_node
+ return self, added
+ else:
+ ret = self.clone(mutid)
+ ret.array[val_idx] = sub_node
+ return ret, added
+
+ if key == key_or_null:
+ if val is val_or_node:
+ return self, False
+
+ if mutid and mutid == self.mutid:
+ self.array[val_idx] = val
+ return self, False
+ else:
+ ret = self.clone(mutid)
+ ret.array[val_idx] = val
+ return ret, False
+
+ existing_key_hash = map_hash(key_or_null)
+ if existing_key_hash == hash:
+ sub_node = CollisionNode(
+ 4, hash, [key_or_null, val_or_node, key, val], mutid)
+ else:
+ sub_node = BitmapNode(0, 0, [], mutid)
+ sub_node, _ = sub_node.assoc(
+ shift + 5, existing_key_hash,
+ key_or_null, val_or_node,
+ mutid)
+ sub_node, _ = sub_node.assoc(
+ shift + 5, hash, key, val,
+ mutid)
+
+ if mutid and mutid == self.mutid:
+ self.array[key_idx] = _NULL
+ self.array[val_idx] = sub_node
+ return self, True
+ else:
+ ret = self.clone(mutid)
+ ret.array[key_idx] = _NULL
+ ret.array[val_idx] = sub_node
+ return ret, True
+
+ else:
+ key_idx = 2 * idx
+ val_idx = key_idx + 1
+
+ n = map_bitcount(self.bitmap)
+
+ new_array = self.array[:key_idx]
+ new_array.append(key)
+ new_array.append(val)
+ new_array.extend(self.array[key_idx:])
+
+ if mutid and mutid == self.mutid:
+ self.size = 2 * (n + 1)
+ self.bitmap |= bit
+ self.array = new_array
+ return self, True
+ else:
+ return BitmapNode(
+ 2 * (n + 1), self.bitmap | bit, new_array, mutid), True
+
+ def find(self, shift, hash, key):
+ bit = map_bitpos(hash, shift)
+
+ if not (self.bitmap & bit):
+ raise KeyError
+
+ idx = map_bitindex(self.bitmap, bit)
+ key_idx = idx * 2
+ val_idx = key_idx + 1
+
+ key_or_null = self.array[key_idx]
+ val_or_node = self.array[val_idx]
+
+ if key_or_null is _NULL:
+ return val_or_node.find(shift + 5, hash, key)
+
+ if key == key_or_null:
+ return val_or_node
+
+ raise KeyError(key)
+
+ def without(self, shift, hash, key, mutid):
+ bit = map_bitpos(hash, shift)
+ if not (self.bitmap & bit):
+ return W_NOT_FOUND, None
+
+ idx = map_bitindex(self.bitmap, bit)
+ key_idx = 2 * idx
+ val_idx = key_idx + 1
+
+ key_or_null = self.array[key_idx]
+ val_or_node = self.array[val_idx]
+
+ if key_or_null is _NULL:
+ res, sub_node = val_or_node.without(shift + 5, hash, key, mutid)
+
+ if res is W_EMPTY:
+ raise RuntimeError('unreachable code') # pragma: no cover
+
+ elif res is W_NEWNODE:
+ if (type(sub_node) is BitmapNode and
+ sub_node.size == 2 and
+ sub_node.array[0] is not _NULL):
+
+ if mutid and mutid == self.mutid:
+ self.array[key_idx] = sub_node.array[0]
+ self.array[val_idx] = sub_node.array[1]
+ return W_NEWNODE, self
+ else:
+ clone = self.clone(mutid)
+ clone.array[key_idx] = sub_node.array[0]
+ clone.array[val_idx] = sub_node.array[1]
+ return W_NEWNODE, clone
+
+ if mutid and mutid == self.mutid:
+ self.array[val_idx] = sub_node
+ return W_NEWNODE, self
+ else:
+ clone = self.clone(mutid)
+ clone.array[val_idx] = sub_node
+ return W_NEWNODE, clone
+
+ else:
+ assert sub_node is None
+ return res, None
+
+ else:
+ if key == key_or_null:
+ if self.size == 2:
+ return W_EMPTY, None
+
+ new_array = self.array[:key_idx]
+ new_array.extend(self.array[val_idx + 1:])
+
+ if mutid and mutid == self.mutid:
+ self.size -= 2
+ self.bitmap &= ~bit
+ self.array = new_array
+ return W_NEWNODE, self
+ else:
+ new_node = BitmapNode(
+ self.size - 2, self.bitmap & ~bit, new_array, mutid)
+ return W_NEWNODE, new_node
+
+ else:
+ return W_NOT_FOUND, None
+
+ def keys(self):
+ for i in range(0, self.size, 2):
+ key_or_null = self.array[i]
+
+ if key_or_null is _NULL:
+ val_or_node = self.array[i + 1]
+ yield from val_or_node.keys()
+ else:
+ yield key_or_null
+
+ def values(self):
+ for i in range(0, self.size, 2):
+ key_or_null = self.array[i]
+ val_or_node = self.array[i + 1]
+
+ if key_or_null is _NULL:
+ yield from val_or_node.values()
+ else:
+ yield val_or_node
+
+ def items(self):
+ for i in range(0, self.size, 2):
+ key_or_null = self.array[i]
+ val_or_node = self.array[i + 1]
+
+ if key_or_null is _NULL:
+ yield from val_or_node.items()
+ else:
+ yield key_or_null, val_or_node
+
+ def dump(self, buf, level): # pragma: no cover
+ buf.append(
+ ' ' * (level + 1) +
+ 'BitmapNode(size={} count={} bitmap={} id={:0x}):'.format(
+ self.size, self.size / 2, bin(self.bitmap), id(self)))
+
+ for i in range(0, self.size, 2):
+ key_or_null = self.array[i]
+ val_or_node = self.array[i + 1]
+
+ pad = ' ' * (level + 2)
+
+ if key_or_null is _NULL:
+ buf.append(pad + 'NULL:')
+ val_or_node.dump(buf, level + 2)
+ else:
+ buf.append(pad + '{!r}: {!r}'.format(key_or_null, val_or_node))
+
+
+class CollisionNode:
+
+ def __init__(self, size, hash, array, mutid):
+ self.size = size
+ self.hash = hash
+ self.array = array
+ self.mutid = mutid
+
+ def find_index(self, key):
+ for i in range(0, self.size, 2):
+ if self.array[i] == key:
+ return i
+ return -1
+
+ def find(self, shift, hash, key):
+ for i in range(0, self.size, 2):
+ if self.array[i] == key:
+ return self.array[i + 1]
+ raise KeyError(key)
+
+ def assoc(self, shift, hash, key, val, mutid):
+ if hash == self.hash:
+ key_idx = self.find_index(key)
+
+ if key_idx == -1:
+ new_array = self.array.copy()
+ new_array.append(key)
+ new_array.append(val)
+
+ if mutid and mutid == self.mutid:
+ self.size += 2
+ self.array = new_array
+ return self, True
+ else:
+ new_node = CollisionNode(
+ self.size + 2, hash, new_array, mutid)
+ return new_node, True
+
+ val_idx = key_idx + 1
+ if self.array[val_idx] is val:
+ return self, False
+
+ if mutid and mutid == self.mutid:
+ self.array[val_idx] = val
+ return self, False
+ else:
+ new_array = self.array.copy()
+ new_array[val_idx] = val
+ return CollisionNode(self.size, hash, new_array, mutid), False
+
+ else:
+ new_node = BitmapNode(
+ 2, map_bitpos(self.hash, shift), [_NULL, self], mutid)
+ return new_node.assoc(shift, hash, key, val, mutid)
+
+ def without(self, shift, hash, key, mutid):
+ if hash != self.hash:
+ return W_NOT_FOUND, None
+
+ key_idx = self.find_index(key)
+ if key_idx == -1:
+ return W_NOT_FOUND, None
+
+ new_size = self.size - 2
+ if new_size == 0:
+ # Shouldn't be ever reachable
+ return W_EMPTY, None # pragma: no cover
+
+ if new_size == 2:
+ if key_idx == 0:
+ new_array = [self.array[2], self.array[3]]
+ else:
+ assert key_idx == 2
+ new_array = [self.array[0], self.array[1]]
+
+ new_node = BitmapNode(
+ 2, map_bitpos(hash, shift), new_array, mutid)
+ return W_NEWNODE, new_node
+
+ new_array = self.array[:key_idx]
+ new_array.extend(self.array[key_idx + 2:])
+ if mutid and mutid == self.mutid:
+ self.array = new_array
+ self.size -= 2
+ return W_NEWNODE, self
+ else:
+ new_node = CollisionNode(
+ self.size - 2, self.hash, new_array, mutid)
+ return W_NEWNODE, new_node
+
+ def keys(self):
+ for i in range(0, self.size, 2):
+ yield self.array[i]
+
+ def values(self):
+ for i in range(1, self.size, 2):
+ yield self.array[i]
+
+ def items(self):
+ for i in range(0, self.size, 2):
+ yield self.array[i], self.array[i + 1]
+
+ def dump(self, buf, level): # pragma: no cover
+ pad = ' ' * (level + 1)
+ buf.append(
+ pad + 'CollisionNode(size={} id={:0x}):'.format(
+ self.size, id(self)))
+
+ pad = ' ' * (level + 2)
+ for i in range(0, self.size, 2):
+ key = self.array[i]
+ val = self.array[i + 1]
+
+ buf.append('{}{!r}: {!r}'.format(pad, key, val))
+
+
+class MapKeys:
+
+ def __init__(self, c, m):
+ self.__count = c
+ self.__root = m
+
+ def __len__(self):
+ return self.__count
+
+ def __iter__(self):
+ return iter(self.__root.keys())
+
+
+class MapValues:
+
+ def __init__(self, c, m):
+ self.__count = c
+ self.__root = m
+
+ def __len__(self):
+ return self.__count
+
+ def __iter__(self):
+ return iter(self.__root.values())
+
+
+class MapItems:
+
+ def __init__(self, c, m):
+ self.__count = c
+ self.__root = m
+
+ def __len__(self):
+ return self.__count
+
+ def __iter__(self):
+ return iter(self.__root.items())
+
+
+class Map:
+
+ def __init__(self, *args, **kw):
+ if not args:
+ col = None
+ elif len(args) == 1:
+ col = args[0]
+ else:
+ raise TypeError(
+ "immutables.Map expected at most 1 arguments, "
+ "got {}".format(len(args))
+ )
+
+ self.__count = 0
+ self.__root = BitmapNode(0, 0, [], 0)
+ self.__hash = -1
+
+ if isinstance(col, Map):
+ self.__count = col.__count
+ self.__root = col.__root
+ self.__hash = col.__hash
+ col = None
+ elif isinstance(col, MapMutation):
+ raise TypeError('cannot create Maps from MapMutations')
+
+ if col or kw:
+ init = self.update(col, **kw)
+ self.__count = init.__count
+ self.__root = init.__root
+
+ @classmethod
+ def _new(cls, count, root):
+ m = Map.__new__(Map)
+ m.__count = count
+ m.__root = root
+ m.__hash = -1
+ return m
+
+ def __reduce__(self):
+ return (type(self), (dict(self.items()),))
+
+ def __len__(self):
+ return self.__count
+
+ def __eq__(self, other):
+ if not isinstance(other, Map):
+ return NotImplemented
+
+ if len(self) != len(other):
+ return False
+
+ for key, val in self.__root.items():
+ try:
+ oval = other.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return False
+ else:
+ if oval != val:
+ return False
+
+ return True
+
+ def update(self, *args, **kw):
+ if not args:
+ col = None
+ elif len(args) == 1:
+ col = args[0]
+ else:
+ raise TypeError(
+ "update expected at most 1 arguments, got {}".format(len(args))
+ )
+
+ it = None
+
+ if col is not None:
+ if hasattr(col, 'items'):
+ it = iter(col.items())
+ else:
+ it = iter(col)
+
+ if it is not None:
+ if kw:
+ it = iter(itertools.chain(it, kw.items()))
+ else:
+ if kw:
+ it = iter(kw.items())
+
+ if it is None:
+
+ return self
+
+ mutid = _mut_id()
+ root = self.__root
+ count = self.__count
+
+ i = 0
+ while True:
+ try:
+ tup = next(it)
+ except StopIteration:
+ break
+
+ try:
+ tup = tuple(tup)
+ except TypeError:
+ raise TypeError(
+ 'cannot convert map update '
+ 'sequence element #{} to a sequence'.format(i)) from None
+ key, val, *r = tup
+ if r:
+ raise ValueError(
+ 'map update sequence element #{} has length '
+ '{}; 2 is required'.format(i, len(r) + 2))
+
+ root, added = root.assoc(0, map_hash(key), key, val, mutid)
+ if added:
+ count += 1
+
+ i += 1
+
+ return Map._new(count, root)
+
+ def mutate(self):
+ return MapMutation(self.__count, self.__root)
+
+ def set(self, key, val):
+ new_count = self.__count
+ new_root, added = self.__root.assoc(0, map_hash(key), key, val, 0)
+
+ if new_root is self.__root:
+ assert not added
+ return self
+
+ if added:
+ new_count += 1
+
+ return Map._new(new_count, new_root)
+
+ def delete(self, key):
+ res, node = self.__root.without(0, map_hash(key), key, 0)
+ if res is W_EMPTY:
+ return Map()
+ elif res is W_NOT_FOUND:
+ raise KeyError(key)
+ else:
+ return Map._new(self.__count - 1, node)
+
+ def get(self, key, default=None):
+ try:
+ return self.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return default
+
+ def __getitem__(self, key):
+ return self.__root.find(0, map_hash(key), key)
+
+ def __contains__(self, key):
+ try:
+ self.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return False
+ else:
+ return True
+
+ def __iter__(self):
+ yield from self.__root.keys()
+
+ def keys(self):
+ return MapKeys(self.__count, self.__root)
+
+ def values(self):
+ return MapValues(self.__count, self.__root)
+
+ def items(self):
+ return MapItems(self.__count, self.__root)
+
+ def __hash__(self):
+ if self.__hash != -1:
+ return self.__hash
+
+ MAX = sys.maxsize
+ MASK = 2 * MAX + 1
+
+ h = 1927868237 * (self.__count * 2 + 1)
+ h &= MASK
+
+ for key, value in self.__root.items():
+ hx = hash(key)
+ h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167
+ h &= MASK
+
+ hx = hash(value)
+ h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167
+ h &= MASK
+
+ h = h * 69069 + 907133923
+ h &= MASK
+
+ if h > MAX:
+ h -= MASK + 1 # pragma: no cover
+ if h == -1:
+ h = 590923713 # pragma: no cover
+
+ self.__hash = h
+ return h
+
+ @reprlib.recursive_repr("{...}")
+ def __repr__(self):
+ items = []
+ for key, val in self.items():
+ items.append("{!r}: {!r}".format(key, val))
+ return 'immutables.Map({{{}}})'.format(', '.join(items))
+
+ def __dump__(self): # pragma: no cover
+ buf = []
+ self.__root.dump(buf, 0)
+ return '\n'.join(buf)
+
+ def __class_getitem__(cls, item):
+ return cls
+
+
+class MapMutation:
+
+ def __init__(self, count, root):
+ self.__count = count
+ self.__root = root
+ self.__mutid = _mut_id()
+
+ def set(self, key, val):
+ self[key] = val
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exc):
+ self.finish()
+ return False
+
+ def __iter__(self):
+ raise TypeError('{} is not iterable'.format(type(self)))
+
+ def __delitem__(self, key):
+ if self.__mutid == 0:
+ raise ValueError('mutation {!r} has been finished'.format(self))
+
+ res, new_root = self.__root.without(
+ 0, map_hash(key), key, self.__mutid)
+ if res is W_EMPTY:
+ self.__count = 0
+ self.__root = BitmapNode(0, 0, [], self.__mutid)
+ elif res is W_NOT_FOUND:
+ raise KeyError(key)
+ else:
+ self.__root = new_root
+ self.__count -= 1
+
+ def __setitem__(self, key, val):
+ if self.__mutid == 0:
+ raise ValueError('mutation {!r} has been finished'.format(self))
+
+ self.__root, added = self.__root.assoc(
+ 0, map_hash(key), key, val, self.__mutid)
+
+ if added:
+ self.__count += 1
+
+ def pop(self, key, *args):
+ if self.__mutid == 0:
+ raise ValueError('mutation {!r} has been finished'.format(self))
+
+ if len(args) > 1:
+ raise TypeError(
+ 'pop() accepts 1 to 2 positional arguments, '
+ 'got {}'.format(len(args) + 1))
+ elif len(args) == 1:
+ default = args[0]
+ else:
+ default = void
+
+ val = self.get(key, default)
+
+ try:
+ del self[key]
+ except KeyError:
+ if val is void:
+ raise
+ return val
+ else:
+ assert val is not void
+ return val
+
+ def get(self, key, default=None):
+ try:
+ return self.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return default
+
+ def __getitem__(self, key):
+ return self.__root.find(0, map_hash(key), key)
+
+ def __contains__(self, key):
+ try:
+ self.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return False
+ else:
+ return True
+
+ def update(self, *args, **kw):
+ if not args:
+ col = None
+ elif len(args) == 1:
+ col = args[0]
+ else:
+ raise TypeError(
+ "update expected at most 1 arguments, got {}".format(len(args))
+ )
+
+ if self.__mutid == 0:
+ raise ValueError('mutation {!r} has been finished'.format(self))
+
+ it = None
+ if col is not None:
+ if hasattr(col, 'items'):
+ it = iter(col.items())
+ else:
+ it = iter(col)
+
+ if it is not None:
+ if kw:
+ it = iter(itertools.chain(it, kw.items()))
+ else:
+ if kw:
+ it = iter(kw.items())
+
+ if it is None:
+ return
+
+ root = self.__root
+ count = self.__count
+
+ i = 0
+ while True:
+ try:
+ tup = next(it)
+ except StopIteration:
+ break
+
+ try:
+ tup = tuple(tup)
+ except TypeError:
+ raise TypeError(
+ 'cannot convert map update '
+ 'sequence element #{} to a sequence'.format(i)) from None
+ key, val, *r = tup
+ if r:
+ raise ValueError(
+ 'map update sequence element #{} has length '
+ '{}; 2 is required'.format(i, len(r) + 2))
+
+ root, added = root.assoc(0, map_hash(key), key, val, self.__mutid)
+ if added:
+ count += 1
+
+ i += 1
+
+ self.__root = root
+ self.__count = count
+
+ def finish(self):
+ self.__mutid = 0
+ return Map._new(self.__count, self.__root)
+
+ @reprlib.recursive_repr("{...}")
+ def __repr__(self):
+ items = []
+ for key, val in self.__root.items():
+ items.append("{!r}: {!r}".format(key, val))
+ return 'immutables.MapMutation({{{}}})'.format(', '.join(items))
+
+ def __len__(self):
+ return self.__count
+
+ def __reduce__(self):
+ raise TypeError("can't pickle {} objects".format(type(self).__name__))
+
+ def __hash__(self):
+ raise TypeError('unhashable type: {}'.format(type(self).__name__))
+
+ def __eq__(self, other):
+ if not isinstance(other, MapMutation):
+ return NotImplemented
+
+ if len(self) != len(other):
+ return False
+
+ for key, val in self.__root.items():
+ try:
+ oval = other.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return False
+ else:
+ if oval != val:
+ return False
+
+ return True
+
+
+collections.abc.Mapping.register(Map)