From 451a84825d82e7fba4022857085ee2977f9a1d09 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 2 Apr 2018 22:43:52 -0400 Subject: Add pure Python implementation (compatible with PyPy) --- immutables/__init__.py | 5 +- immutables/map.py | 460 +++++++++++++++++++++++++++++++++++++++++++++++++ setup.py | 18 +- tests/test_map.py | 79 +++++---- 4 files changed, 525 insertions(+), 37 deletions(-) create mode 100644 immutables/map.py diff --git a/immutables/__init__.py b/immutables/__init__.py index 8ed9faa..d598413 100644 --- a/immutables/__init__.py +++ b/immutables/__init__.py @@ -1,4 +1,7 @@ -from ._map import Map +try: + from ._map import Map +except ImportError: + from .map import Map __all__ = 'Map', diff --git a/immutables/map.py b/immutables/map.py new file mode 100644 index 0000000..ffed188 --- /dev/null +++ b/immutables/map.py @@ -0,0 +1,460 @@ +import reprlib + + +def map_hash(o): + x = hash(o) + return (x & 0xffffffff) ^ ((x >> 32) & 0xffffffff) + + +def map_mask(hash, shift): + return (hash >> shift) & 0x01f + + +def map_bitpos(hash, shift): + return 1 << map_mask(hash, shift) + + +def map_bitcount(v): + v = v - ((v >> 1) & 0x55555555) + v = (v & 0x33333333) + ((v >> 2) & 0x33333333) + v = (v & 0x0F0F0F0F) + ((v >> 4) & 0x0F0F0F0F) + v = v + (v >> 8) + v = (v + (v >> 16)) & 0x3F + return v + + +def map_bitindex(bitmap, bit): + return map_bitcount(bitmap & (bit - 1)) + + +def shuffle_bits(h): + # used in Map.__hash__ + return ((h ^ 89869747) ^ (h << 16)) * 3644798167 + + +W_EMPTY, W_NEWNODE, W_NOT_FOUND = range(3) + + +class BitmapNode: + + def __init__(self, size, bitmap, array): + self.size = size + self.bitmap = bitmap + assert isinstance(array, list) and len(array) == size + self.array = array + + def clone(self): + return BitmapNode(self.size, self.bitmap, self.array.copy()) + + def assoc(self, shift, hash, key, val, added_leaf): + bit = map_bitpos(hash, shift) + idx = map_bitindex(self.bitmap, bit) + + if self.bitmap & bit: + key_idx = 2 * idx + 
val_idx = key_idx + 1 + + key_or_null = self.array[key_idx] + val_or_node = self.array[val_idx] + + if key_or_null is None: + sub_node = val_or_node.assoc( + shift + 5, hash, key, val, added_leaf) + if val_or_node is sub_node: + return self + + ret = self.clone() + ret.array[val_idx] = sub_node + return ret + + if key == key_or_null: + if val is val_or_node: + return self + + ret = self.clone() + ret.array[val_idx] = val + return ret + + existing_key_hash = map_hash(key_or_null) + if existing_key_hash == hash: + sub_node = CollisionNode( + 4, hash, [key_or_null, val_or_node, key, val]) + else: + sub_node = BitmapNode(0, 0, []) + sub_node = sub_node.assoc( + shift + 5, existing_key_hash, + key_or_null, val_or_node, [False]) + sub_node = sub_node.assoc( + shift + 5, hash, key, val, [False]) + + ret = self.clone() + ret.array[key_idx] = None + ret.array[val_idx] = sub_node + added_leaf[0] = True + return ret + + else: + added_leaf[0] = True + + key_idx = 2 * idx + val_idx = key_idx + 1 + + n = map_bitcount(self.bitmap) + + new_array = self.array[:key_idx] + new_array.append(key) + new_array.append(val) + new_array.extend(self.array[key_idx:]) + return BitmapNode(2 * (n + 1), self.bitmap | bit, new_array) + + def find(self, shift, hash, key): + bit = map_bitpos(hash, shift) + + if not (self.bitmap & bit): + raise KeyError + + idx = map_bitindex(self.bitmap, bit) + key_idx = idx * 2 + val_idx = key_idx + 1 + + key_or_null = self.array[key_idx] + val_or_node = self.array[val_idx] + + if key_or_null is None: + return val_or_node.find(shift + 5, hash, key) + + if key == key_or_null: + return val_or_node + + raise KeyError + + def without(self, shift, hash, key): + bit = map_bitpos(hash, shift) + if not (self.bitmap & bit): + return W_NOT_FOUND, None + + idx = map_bitindex(self.bitmap, bit) + key_idx = 2 * idx + val_idx = key_idx + 1 + + key_or_null = self.array[key_idx] + val_or_node = self.array[val_idx] + + if key_or_null is None: + res, sub_node = 
val_or_node.without(shift + 5, hash, key) + + if res is W_EMPTY: + raise RuntimeError('unreachable code') + + elif res is W_NEWNODE: + if (type(sub_node) is BitmapNode and + sub_node.size == 2 and + sub_node.array[0] is not None): + clone = self.clone() + clone.array[key_idx] = sub_node.array[0] + clone.array[val_idx] = sub_node.array[1] + return W_NEWNODE, clone + + clone = self.clone() + clone.array[val_idx] = sub_node + return W_NEWNODE, clone + + else: + assert sub_node is None + return res, None + + else: + if key == key_or_null: + new_array = self.array[:key_idx] + new_array.extend(self.array[val_idx + 1:]) + new_node = BitmapNode( + self.size - 2, self.bitmap & ~bit, new_array) + return W_NEWNODE, new_node + + else: + return W_NOT_FOUND, None + + def keys(self): + for i in range(0, self.size, 2): + key_or_null = self.array[i] + + if key_or_null is None: + val_or_node = self.array[i + 1] + yield from val_or_node.keys() + else: + yield key_or_null + + def values(self): + for i in range(0, self.size, 2): + key_or_null = self.array[i] + val_or_node = self.array[i + 1] + + if key_or_null is None: + yield from val_or_node.values() + else: + yield val_or_node + + def items(self): + for i in range(0, self.size, 2): + key_or_null = self.array[i] + val_or_node = self.array[i + 1] + + if key_or_null is None: + yield from val_or_node.items() + else: + yield key_or_null, val_or_node + + def dump(self, buf, level): + buf.append( + ' ' * (level + 1) + + 'BitmapNode(size={} count={} bitmap={} id={:0x}):'.format( + self.size, self.size / 2, bin(self.bitmap), id(self))) + + for i in range(0, self.size, 2): + key_or_null = self.array[i] + val_or_node = self.array[i + 1] + + pad = ' ' * (level + 2) + + if key_or_null is None: + buf.append(pad + 'None:') + val_or_node.dump(buf, level + 2) + else: + buf.append(pad + '{!r}: {!r}'.format(key_or_null, val_or_node)) + + +class CollisionNode: + + def __init__(self, size, hash, array): + self.size = size + self.hash = hash + self.array 
= array + + def find_index(self, key): + for i in range(0, self.size, 2): + if self.array[i] == key: + return i + return -1 + + def find(self, shift, hash, key): + for i in range(0, self.size, 2): + if self.array[i] == key: + return self.array[i + 1] + raise KeyError + + def assoc(self, shift, hash, key, val, added_leaf): + if hash == self.hash: + key_idx = self.find_index(key) + + if key_idx == -1: + new_array = self.array.copy() + new_array.append(key) + new_array.append(val) + new_node = CollisionNode(self.size + 2, hash, new_array) + added_leaf[0] = True + return new_node + + val_idx = key_idx + 1 + if self.array[val_idx] is val: + return self + + new_array = self.array.copy() + new_array[val_idx] = val + return CollisionNode(self.size, hash, new_array) + + else: + new_node = BitmapNode( + 2, map_bitpos(self.hash, shift), [None, self]) + return new_node.assoc(shift, hash, key, val, added_leaf) + + def without(self, shift, hash, key): + if hash != self.hash: + return W_NOT_FOUND, None + + key_idx = self.find_index(key) + if key_idx == -1: + return W_NOT_FOUND, None + + new_size = self.size - 2 + if new_size == 0: + return W_EMPTY, None + + if new_size == 2: + if key_idx == 0: + new_array = [self.array[2], self.array[3]] + else: + assert key_idx == 2 + new_array = [self.array[0], self.array[1]] + + new_node = BitmapNode(2, map_bitpos(hash, shift), new_array) + return W_NEWNODE, new_node + + new_array = self.array[:key_idx] + new_array.extend(self.array[key_idx + 2:]) + new_node = CollisionNode(self.size - 2, self.hash, new_array) + return W_NEWNODE, new_node + + def keys(self): + for i in range(0, self.size, 2): + yield self.array[i] + + def values(self): + for i in range(1, self.size, 2): + yield self.array[i] + + def items(self): + for i in range(0, self.size, 2): + yield self.array[i], self.array[i + 1] + + def dump(self, buf, level): + pad = ' ' * (level + 1) + buf.append( + pad + 'CollisionNode(size={} id={:0x}):'.format( + self.size, id(self))) + + pad = ' 
' * (level + 2) + for i in range(0, self.size, 2): + key = self.array[i] + val = self.array[i + 1] + + buf.append('{}{!r}: {!r}'.format(pad, key, val)) + + +class GenWrapper: + + def __init__(self, count, gen): + self.__count = count + self.__gen = gen + + def __len__(self): + return self.__count + + def __iter__(self): + return self + + def __next__(self): + return next(self.__gen) + + +class Map: + + def __init__(self): + self.__count = 0 + self.__root = BitmapNode(0, 0, []) + self.__hash = -1 + + def __len__(self): + return self.__count + + def __eq__(self, other): + if not isinstance(other, Map): + return NotImplemented + + if len(self) != len(other): + return False + + for key, val in self.__root.items(): + try: + oval = other.__root.find(0, map_hash(key), key) + except KeyError: + return False + else: + if oval != val: + return False + + return True + + def set(self, key, val): + added = [False] + + new_count = self.__count + new_root = self.__root.assoc(0, map_hash(key), key, val, added) + + if new_root is self.__root: + assert not added[0] + return self + + if added[0]: + new_count += 1 + + m = Map.__new__(Map) + m.__count = new_count + m.__root = new_root + m.__hash = -1 + return m + + def delete(self, key): + res, node = self.__root.without(0, map_hash(key), key) + if res is W_EMPTY: + return Map() + elif res is W_NOT_FOUND: + # raise KeyError(key) + return self + else: + m = Map.__new__(Map) + m.__count = self.__count - 1 + m.__root = node + m.__hash = -1 + return m + + def get(self, key, default=None): + try: + return self.__root.find(0, map_hash(key), key) + except KeyError: + return default + + def __getitem__(self, key): + return self.__root.find(0, map_hash(key), key) + + def __contains__(self, key): + try: + self.__root.find(0, map_hash(key), key) + except KeyError: + return False + else: + return True + + def __iter__(self): + yield from self.__root.keys() + + def keys(self): + return GenWrapper(self.__count, self.__root.keys()) + + def 
values(self): + return GenWrapper(self.__count, self.__root.values()) + + def items(self): + return GenWrapper(self.__count, self.__root.items()) + + def __hash__(self): + if self.__hash != -1: + return self.__hash + + h = 0 + for key, value in self.__root.items(): + h ^= shuffle_bits(hash(key)) + h ^= shuffle_bits(hash(value)) + + h ^= (self.__count * 2 + 1) * 1927868237 + + h ^= (h >> 11) ^ (h >> 25) + h = h * 69069 + 907133923 + + if h == -1: + h = -2 + + self.__hash = h + return h + + @reprlib.recursive_repr("{...}") + def __repr__(self): + items = [] + for key, val in self.items(): + items.append("{!r}: {!r}".format(key, val)) + return '<immutables.Map({{{}}}) at 0x{:0x}>'.format( + ', '.join(items), id(self)) + + def __dump__(self): + buf = [] + self.__root.dump(buf, 0) + return '\n'.join(buf) diff --git a/setup.py b/setup.py index b0cd4c9..40877c0 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,17 @@ with open(os.path.join( 'unable to read the version from immutables/__init__.py') +if platform.python_implementation() == 'CPython': + ext_modules = [ + setuptools.Extension( + "immutables._map", + ["immutables/_map.c"], + extra_compile_args=CFLAGS) + ] +else: + ext_modules = [] + + setuptools.setup( name='immutables', version=VERSION, @@ -42,11 +53,6 @@ setuptools.setup( packages=['immutables'], provides=['immutables'], include_package_data=True, - ext_modules=[ - setuptools.Extension( - "immutables._map", - ["immutables/_map.c"], - extra_compile_args=CFLAGS) - ], + ext_modules=ext_modules, test_suite='tests.suite', ) diff --git a/tests/test_map.py b/tests/test_map.py index 91050a2..03dc9a0 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -3,7 +3,7 @@ import random import unittest import weakref -from immutables import Map +from immutables.map import Map as PyMap class HashKey: @@ -88,7 +88,9 @@ class ReprError(Exception): pass -class MapTest(unittest.TestCase): +class BaseMapTest: + + Map = None def test_hashkey_helper_1(self): k1 = HashKey(10, 'aaa') @@ -105,11 +107,11 @@ class 
MapTest(unittest.TestCase): self.assertEqual(d[k2], 'b') def test_map_basics_1(self): - h = Map() + h = self.Map() h = None # NoQA def test_map_basics_2(self): - h = Map() + h = self.Map() self.assertEqual(len(h), 0) h2 = h.set('a', 'b') @@ -139,14 +141,14 @@ class MapTest(unittest.TestCase): h = h2 = h3 = None def test_map_basics_3(self): - h = Map() + h = self.Map() o = object() h1 = h.set('1', o) h2 = h1.set('1', o) self.assertIs(h1, h2) def test_map_basics_4(self): - h = Map() + h = self.Map() h1 = h.set('key', []) h2 = h1.set('key', []) self.assertIsNot(h1, h2) @@ -159,7 +161,7 @@ class MapTest(unittest.TestCase): k2 = HashKey(10, 'bbb') k3 = HashKey(10, 'ccc') - h = Map() + h = self.Map() h2 = h.set(k1, 'a') h3 = h2.set(k2, 'b') @@ -199,7 +201,7 @@ class MapTest(unittest.TestCase): RUN_XTIMES = 3 for _ in range(RUN_XTIMES): - h = Map() + h = self.Map() d = dict() for i in range(COLLECTION_SIZE): @@ -290,7 +292,7 @@ class MapTest(unittest.TestCase): Er = HashKey(103, 'Er', error_on_eq_to=D) - h = Map() + h = self.Map() h = h.set(A, 'a') h = h.set(B, 'b') h = h.set(C, 'c') @@ -335,7 +337,7 @@ class MapTest(unittest.TestCase): Er = HashKey(201001, 'Er', error_on_eq_to=B) - h = Map() + h = self.Map() h = h.set(A, 'a') h = h.set(B, 'b') h = h.set(C, 'c') @@ -384,7 +386,7 @@ class MapTest(unittest.TestCase): D = HashKey(100100, 'D') E = HashKey(104, 'E') - h = Map() + h = self.Map() h = h.set(A, 'a') h = h.set(B, 'b') h = h.set(C, 'c') @@ -420,7 +422,7 @@ class MapTest(unittest.TestCase): D = HashKey(100100, 'D') E = HashKey(100100, 'E') - h = Map() + h = self.Map() h = h.set(A, 'a') h = h.set(B, 'b') h = h.set(C, 'c') @@ -456,7 +458,7 @@ class MapTest(unittest.TestCase): self.assertEqual(len(h), 0) def test_map_delete_5(self): - h = Map() + h = self.Map() keys = [] for i in range(17): @@ -512,7 +514,7 @@ class MapTest(unittest.TestCase): E = HashKey(104, 'E') F = HashKey(110, 'F') - h = Map() + h = self.Map() h = h.set(A, 'a') h = h.set(B, 'b') h = h.set(C, 'c') 
@@ -533,7 +535,7 @@ class MapTest(unittest.TestCase): E = HashKey(100100, 'E') F = HashKey(110, 'F') - h = Map() + h = self.Map() h = h.set(A, 'a') h = h.set(B, 'b') h = h.set(C, 'c') @@ -554,7 +556,7 @@ class MapTest(unittest.TestCase): E = HashKey(100100, 'E') F = HashKey(110, 'F') - h = Map() + h = self.Map() h = h.set(A, 'a') h = h.set(B, 'b') h = h.set(C, 'c') @@ -566,7 +568,7 @@ class MapTest(unittest.TestCase): self.assertEqual(set(list(h)), {A, B, C, D, E, F}) def test_map_items_3(self): - h = Map() + h = self.Map() self.assertEqual(len(h.items()), 0) self.assertEqual(list(h.items()), []) @@ -577,13 +579,13 @@ class MapTest(unittest.TestCase): D = HashKey(100100, 'D') E = HashKey(120, 'E') - h1 = Map() + h1 = self.Map() h1 = h1.set(A, 'a') h1 = h1.set(B, 'b') h1 = h1.set(C, 'c') h1 = h1.set(D, 'd') - h2 = Map() + h2 = self.Map() h2 = h2.set(A, 'a') self.assertFalse(h1 == h2) @@ -621,10 +623,10 @@ class MapTest(unittest.TestCase): A = HashKey(100, 'A') Er = HashKey(100, 'Er', error_on_eq_to=A) - h1 = Map() + h1 = self.Map() h1 = h1.set(A, 'a') - h2 = Map() + h2 = self.Map() h2 = h2.set(Er, 'a') with self.assertRaisesRegex(ValueError, 'cannot compare'): @@ -636,7 +638,7 @@ class MapTest(unittest.TestCase): def test_map_gc_1(self): A = HashKey(100, 'A') - h = Map() + h = self.Map() h = h.set(0, 0) # empty Map node is memoized in _map.c ref = weakref.ref(h) @@ -659,7 +661,7 @@ class MapTest(unittest.TestCase): def test_map_gc_2(self): A = HashKey(100, 'A') - h = Map() + h = self.Map() h = h.set(A, 'a') h = h.set(A, h) @@ -681,7 +683,7 @@ class MapTest(unittest.TestCase): B = HashKey(101, 'B') - h = Map() + h = self.Map() h = h.set(A, 1) self.assertTrue(A in h) @@ -701,7 +703,7 @@ class MapTest(unittest.TestCase): B = HashKey(101, 'B') - h = Map() + h = self.Map() h = h.set(A, 1) self.assertEqual(h[A], 1) @@ -719,7 +721,7 @@ class MapTest(unittest.TestCase): h[AA] def test_repr_1(self): - h = Map() + h = self.Map() self.assertTrue(repr(h).startswith('