aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2018-04-02 22:43:52 -0400
committerYury Selivanov <yury@magic.io>2018-04-02 23:51:14 -0400
commit451a84825d82e7fba4022857085ee2977f9a1d09 (patch)
treef9dd6105d5ae5040a864c799cc747a19faaab1c6
parent552544080cd9a46c6be612f35b924515d998dbe9 (diff)
downloadimmutables-451a84825d82e7fba4022857085ee2977f9a1d09.tar.gz
immutables-451a84825d82e7fba4022857085ee2977f9a1d09.zip
Add pure Python implementation (compatible with PyPy)
-rw-r--r--immutables/__init__.py5
-rw-r--r--immutables/map.py460
-rw-r--r--setup.py18
-rw-r--r--tests/test_map.py79
4 files changed, 525 insertions, 37 deletions
diff --git a/immutables/__init__.py b/immutables/__init__.py
index 8ed9faa..d598413 100644
--- a/immutables/__init__.py
+++ b/immutables/__init__.py
@@ -1,4 +1,7 @@
-from ._map import Map
+try:
+ from ._map import Map
+except ImportError:
+ from .map import Map
__all__ = 'Map',
diff --git a/immutables/map.py b/immutables/map.py
new file mode 100644
index 0000000..ffed188
--- /dev/null
+++ b/immutables/map.py
@@ -0,0 +1,460 @@
+import reprlib
+
+
+def map_hash(o):
+ x = hash(o)
+ return (x & 0xffffffff) ^ ((x >> 32) & 0xffffffff)
+
+
+def map_mask(hash, shift):
+ return (hash >> shift) & 0x01f
+
+
+def map_bitpos(hash, shift):
+ return 1 << map_mask(hash, shift)
+
+
+def map_bitcount(v):
+ v = v - ((v >> 1) & 0x55555555)
+ v = (v & 0x33333333) + ((v >> 2) & 0x33333333)
+ v = (v & 0x0F0F0F0F) + ((v >> 4) & 0x0F0F0F0F)
+ v = v + (v >> 8)
+ v = (v + (v >> 16)) & 0x3F
+ return v
+
+
+def map_bitindex(bitmap, bit):
+ return map_bitcount(bitmap & (bit - 1))
+
+
+def shuffle_bits(h):
+ # used in Map.__hash__
+ return ((h ^ 89869747) ^ (h << 16)) * 3644798167
+
+
+W_EMPTY, W_NEWNODE, W_NOT_FOUND = range(3)
+
+
+class BitmapNode:
+
+ def __init__(self, size, bitmap, array):
+ self.size = size
+ self.bitmap = bitmap
+ assert isinstance(array, list) and len(array) == size
+ self.array = array
+
+ def clone(self):
+ return BitmapNode(self.size, self.bitmap, self.array.copy())
+
+ def assoc(self, shift, hash, key, val, added_leaf):
+ bit = map_bitpos(hash, shift)
+ idx = map_bitindex(self.bitmap, bit)
+
+ if self.bitmap & bit:
+ key_idx = 2 * idx
+ val_idx = key_idx + 1
+
+ key_or_null = self.array[key_idx]
+ val_or_node = self.array[val_idx]
+
+ if key_or_null is None:
+ sub_node = val_or_node.assoc(
+ shift + 5, hash, key, val, added_leaf)
+ if val_or_node is sub_node:
+ return self
+
+ ret = self.clone()
+ ret.array[val_idx] = sub_node
+ return ret
+
+ if key == key_or_null:
+ if val is val_or_node:
+ return self
+
+ ret = self.clone()
+ ret.array[val_idx] = val
+ return ret
+
+ existing_key_hash = map_hash(key_or_null)
+ if existing_key_hash == hash:
+ sub_node = CollisionNode(
+ 4, hash, [key_or_null, val_or_node, key, val])
+ else:
+ sub_node = BitmapNode(0, 0, [])
+ sub_node = sub_node.assoc(
+ shift + 5, existing_key_hash,
+ key_or_null, val_or_node, [False])
+ sub_node = sub_node.assoc(
+ shift + 5, hash, key, val, [False])
+
+ ret = self.clone()
+ ret.array[key_idx] = None
+ ret.array[val_idx] = sub_node
+ added_leaf[0] = True
+ return ret
+
+ else:
+ added_leaf[0] = True
+
+ key_idx = 2 * idx
+ val_idx = key_idx + 1
+
+ n = map_bitcount(self.bitmap)
+
+ new_array = self.array[:key_idx]
+ new_array.append(key)
+ new_array.append(val)
+ new_array.extend(self.array[key_idx:])
+ return BitmapNode(2 * (n + 1), self.bitmap | bit, new_array)
+
+ def find(self, shift, hash, key):
+ bit = map_bitpos(hash, shift)
+
+ if not (self.bitmap & bit):
+ raise KeyError
+
+ idx = map_bitindex(self.bitmap, bit)
+ key_idx = idx * 2
+ val_idx = key_idx + 1
+
+ key_or_null = self.array[key_idx]
+ val_or_node = self.array[val_idx]
+
+ if key_or_null is None:
+ return val_or_node.find(shift + 5, hash, key)
+
+ if key == key_or_null:
+ return val_or_node
+
+ raise KeyError
+
+ def without(self, shift, hash, key):
+ bit = map_bitpos(hash, shift)
+ if not (self.bitmap & bit):
+ return W_NOT_FOUND, None
+
+ idx = map_bitindex(self.bitmap, bit)
+ key_idx = 2 * idx
+ val_idx = key_idx + 1
+
+ key_or_null = self.array[key_idx]
+ val_or_node = self.array[val_idx]
+
+ if key_or_null is None:
+ res, sub_node = val_or_node.without(shift + 5, hash, key)
+
+ if res is W_EMPTY:
+ raise RuntimeError('unreachable code')
+
+ elif res is W_NEWNODE:
+ if (type(sub_node) is BitmapNode and
+ sub_node.size == 2 and
+ sub_node.array[0] is not None):
+ clone = self.clone()
+ clone.array[key_idx] = sub_node.array[0]
+ clone.array[val_idx] = sub_node.array[1]
+ return W_NEWNODE, clone
+
+ clone = self.clone()
+ clone.array[val_idx] = sub_node
+ return W_NEWNODE, clone
+
+ else:
+ assert sub_node is None
+ return res, None
+
+ else:
+ if key == key_or_null:
+ new_array = self.array[:key_idx]
+ new_array.extend(self.array[val_idx + 1:])
+ new_node = BitmapNode(
+ self.size - 2, self.bitmap & ~bit, new_array)
+ return W_NEWNODE, new_node
+
+ else:
+ return W_NOT_FOUND, None
+
+ def keys(self):
+ for i in range(0, self.size, 2):
+ key_or_null = self.array[i]
+
+ if key_or_null is None:
+ val_or_node = self.array[i + 1]
+ yield from val_or_node.keys()
+ else:
+ yield key_or_null
+
+ def values(self):
+ for i in range(0, self.size, 2):
+ key_or_null = self.array[i]
+ val_or_node = self.array[i + 1]
+
+ if key_or_null is None:
+ yield from val_or_node.values()
+ else:
+ yield val_or_node
+
+ def items(self):
+ for i in range(0, self.size, 2):
+ key_or_null = self.array[i]
+ val_or_node = self.array[i + 1]
+
+ if key_or_null is None:
+ yield from val_or_node.items()
+ else:
+ yield key_or_null, val_or_node
+
+ def dump(self, buf, level):
+ buf.append(
+ ' ' * (level + 1) +
+ 'BitmapNode(size={} count={} bitmap={} id={:0x}):'.format(
+ self.size, self.size / 2, bin(self.bitmap), id(self)))
+
+ for i in range(0, self.size, 2):
+ key_or_null = self.array[i]
+ val_or_node = self.array[i + 1]
+
+ pad = ' ' * (level + 2)
+
+ if key_or_null is None:
+ buf.append(pad + 'None:')
+ val_or_node.dump(buf, level + 2)
+ else:
+ buf.append(pad + '{!r}: {!r}'.format(key_or_null, val_or_node))
+
+
+class CollisionNode:
+
+ def __init__(self, size, hash, array):
+ self.size = size
+ self.hash = hash
+ self.array = array
+
+ def find_index(self, key):
+ for i in range(0, self.size, 2):
+ if self.array[i] == key:
+ return i
+ return -1
+
+ def find(self, shift, hash, key):
+ for i in range(0, self.size, 2):
+ if self.array[i] == key:
+ return self.array[i + 1]
+ raise KeyError
+
+ def assoc(self, shift, hash, key, val, added_leaf):
+ if hash == self.hash:
+ key_idx = self.find_index(key)
+
+ if key_idx == -1:
+ new_array = self.array.copy()
+ new_array.append(key)
+ new_array.append(val)
+ new_node = CollisionNode(self.size + 2, hash, new_array)
+ added_leaf[0] = True
+ return new_node
+
+ val_idx = key_idx + 1
+ if self.array[val_idx] is val:
+ return self
+
+ new_array = self.array.copy()
+ new_array[val_idx] = val
+ return CollisionNode(self.size, hash, new_array)
+
+ else:
+ new_node = BitmapNode(
+ 2, map_bitpos(self.hash, shift), [None, self])
+ return new_node.assoc(shift, hash, key, val, added_leaf)
+
+ def without(self, shift, hash, key):
+ if hash != self.hash:
+ return W_NOT_FOUND, None
+
+ key_idx = self.find_index(key)
+ if key_idx == -1:
+ return W_NOT_FOUND, None
+
+ new_size = self.size - 2
+ if new_size == 0:
+ return W_EMPTY, None
+
+ if new_size == 2:
+ if key_idx == 0:
+ new_array = [self.array[2], self.array[3]]
+ else:
+ assert key_idx == 2
+ new_array = [self.array[0], self.array[1]]
+
+ new_node = BitmapNode(2, map_bitpos(hash, shift), new_array)
+ return W_NEWNODE, new_node
+
+ new_array = self.array[:key_idx]
+ new_array.extend(self.array[key_idx + 2:])
+ new_node = CollisionNode(self.size - 2, self.hash, new_array)
+ return W_NEWNODE, new_node
+
+ def keys(self):
+ for i in range(0, self.size, 2):
+ yield self.array[i]
+
+ def values(self):
+ for i in range(1, self.size, 2):
+ yield self.array[i]
+
+ def items(self):
+ for i in range(0, self.size, 2):
+ yield self.array[i], self.array[i + 1]
+
+ def dump(self, buf, level):
+ pad = ' ' * (level + 1)
+ buf.append(
+ pad + 'CollisionNode(size={} id={:0x}):'.format(
+ self.size, id(self)))
+
+ pad = ' ' * (level + 2)
+ for i in range(0, self.size, 2):
+ key = self.array[i]
+ val = self.array[i + 1]
+
+ buf.append('{}{!r}: {!r}'.format(pad, key, val))
+
+
+class GenWrapper:
+
+ def __init__(self, count, gen):
+ self.__count = count
+ self.__gen = gen
+
+ def __len__(self):
+ return self.__count
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return next(self.__gen)
+
+
+class Map:
+
+ def __init__(self):
+ self.__count = 0
+ self.__root = BitmapNode(0, 0, [])
+ self.__hash = -1
+
+ def __len__(self):
+ return self.__count
+
+ def __eq__(self, other):
+ if not isinstance(other, Map):
+ return NotImplemented
+
+ if len(self) != len(other):
+ return False
+
+ for key, val in self.__root.items():
+ try:
+ oval = other.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return False
+ else:
+ if oval != val:
+ return False
+
+ return True
+
+ def set(self, key, val):
+ added = [False]
+
+ new_count = self.__count
+ new_root = self.__root.assoc(0, map_hash(key), key, val, added)
+
+ if new_root is self.__root:
+ assert not added[0]
+ return self
+
+ if added[0]:
+ new_count += 1
+
+ m = Map.__new__(Map)
+ m.__count = new_count
+ m.__root = new_root
+ m.__hash = -1
+ return m
+
+ def delete(self, key):
+ res, node = self.__root.without(0, map_hash(key), key)
+ if res is W_EMPTY:
+ return Map()
+ elif res is W_NOT_FOUND:
+ # raise KeyError(key)
+ return self
+ else:
+ m = Map.__new__(Map)
+ m.__count = self.__count - 1
+ m.__root = node
+ m.__hash = -1
+ return m
+
+ def get(self, key, default=None):
+ try:
+ return self.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return default
+
+ def __getitem__(self, key):
+ return self.__root.find(0, map_hash(key), key)
+
+ def __contains__(self, key):
+ try:
+ self.__root.find(0, map_hash(key), key)
+ except KeyError:
+ return False
+ else:
+ return True
+
+ def __iter__(self):
+ yield from self.__root.keys()
+
+ def keys(self):
+ return GenWrapper(self.__count, self.__root.keys())
+
+ def values(self):
+ return GenWrapper(self.__count, self.__root.values())
+
+ def items(self):
+ return GenWrapper(self.__count, self.__root.items())
+
+ def __hash__(self):
+ if self.__hash != -1:
+ return self.__hash
+
+ h = 0
+ for key, value in self.__root.items():
+ h ^= shuffle_bits(hash(key))
+ h ^= shuffle_bits(hash(value))
+
+ h ^= (self.__count * 2 + 1) * 1927868237
+
+ h ^= (h >> 11) ^ (h >> 25)
+ h = h * 69069 + 907133923
+
+ if h == -1:
+ h = -2
+
+ self.__hash = h
+ return h
+
+ @reprlib.recursive_repr("{...}")
+ def __repr__(self):
+ items = []
+ for key, val in self.items():
+ items.append("{!r}: {!r}".format(key, val))
+ return '<immutables.Map({{{}}}) at 0x{:0x}>'.format(
+ ', '.join(items), id(self))
+
+ def __dump__(self):
+ buf = []
+ self.__root.dump(buf, 0)
+ return '\n'.join(buf)
diff --git a/setup.py b/setup.py
index b0cd4c9..40877c0 100644
--- a/setup.py
+++ b/setup.py
@@ -21,6 +21,17 @@ with open(os.path.join(
'unable to read the version from immutables/__init__.py')
+if platform.python_implementation() == 'CPython':
+ ext_modules = [
+ setuptools.Extension(
+ "immutables._map",
+ ["immutables/_map.c"],
+ extra_compile_args=CFLAGS)
+ ]
+else:
+ ext_modules = []
+
+
setuptools.setup(
name='immutables',
version=VERSION,
@@ -42,11 +53,6 @@ setuptools.setup(
packages=['immutables'],
provides=['immutables'],
include_package_data=True,
- ext_modules=[
- setuptools.Extension(
- "immutables._map",
- ["immutables/_map.c"],
- extra_compile_args=CFLAGS)
- ],
+ ext_modules=ext_modules,
test_suite='tests.suite',
)
diff --git a/tests/test_map.py b/tests/test_map.py
index 91050a2..03dc9a0 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -3,7 +3,7 @@ import random
import unittest
import weakref
-from immutables import Map
+from immutables.map import Map as PyMap
class HashKey:
@@ -88,7 +88,9 @@ class ReprError(Exception):
pass
-class MapTest(unittest.TestCase):
+class BaseMapTest:
+
+ Map = None
def test_hashkey_helper_1(self):
k1 = HashKey(10, 'aaa')
@@ -105,11 +107,11 @@ class MapTest(unittest.TestCase):
self.assertEqual(d[k2], 'b')
def test_map_basics_1(self):
- h = Map()
+ h = self.Map()
h = None # NoQA
def test_map_basics_2(self):
- h = Map()
+ h = self.Map()
self.assertEqual(len(h), 0)
h2 = h.set('a', 'b')
@@ -139,14 +141,14 @@ class MapTest(unittest.TestCase):
h = h2 = h3 = None
def test_map_basics_3(self):
- h = Map()
+ h = self.Map()
o = object()
h1 = h.set('1', o)
h2 = h1.set('1', o)
self.assertIs(h1, h2)
def test_map_basics_4(self):
- h = Map()
+ h = self.Map()
h1 = h.set('key', [])
h2 = h1.set('key', [])
self.assertIsNot(h1, h2)
@@ -159,7 +161,7 @@ class MapTest(unittest.TestCase):
k2 = HashKey(10, 'bbb')
k3 = HashKey(10, 'ccc')
- h = Map()
+ h = self.Map()
h2 = h.set(k1, 'a')
h3 = h2.set(k2, 'b')
@@ -199,7 +201,7 @@ class MapTest(unittest.TestCase):
RUN_XTIMES = 3
for _ in range(RUN_XTIMES):
- h = Map()
+ h = self.Map()
d = dict()
for i in range(COLLECTION_SIZE):
@@ -290,7 +292,7 @@ class MapTest(unittest.TestCase):
Er = HashKey(103, 'Er', error_on_eq_to=D)
- h = Map()
+ h = self.Map()
h = h.set(A, 'a')
h = h.set(B, 'b')
h = h.set(C, 'c')
@@ -335,7 +337,7 @@ class MapTest(unittest.TestCase):
Er = HashKey(201001, 'Er', error_on_eq_to=B)
- h = Map()
+ h = self.Map()
h = h.set(A, 'a')
h = h.set(B, 'b')
h = h.set(C, 'c')
@@ -384,7 +386,7 @@ class MapTest(unittest.TestCase):
D = HashKey(100100, 'D')
E = HashKey(104, 'E')
- h = Map()
+ h = self.Map()
h = h.set(A, 'a')
h = h.set(B, 'b')
h = h.set(C, 'c')
@@ -420,7 +422,7 @@ class MapTest(unittest.TestCase):
D = HashKey(100100, 'D')
E = HashKey(100100, 'E')
- h = Map()
+ h = self.Map()
h = h.set(A, 'a')
h = h.set(B, 'b')
h = h.set(C, 'c')
@@ -456,7 +458,7 @@ class MapTest(unittest.TestCase):
self.assertEqual(len(h), 0)
def test_map_delete_5(self):
- h = Map()
+ h = self.Map()
keys = []
for i in range(17):
@@ -512,7 +514,7 @@ class MapTest(unittest.TestCase):
E = HashKey(104, 'E')
F = HashKey(110, 'F')
- h = Map()
+ h = self.Map()
h = h.set(A, 'a')
h = h.set(B, 'b')
h = h.set(C, 'c')
@@ -533,7 +535,7 @@ class MapTest(unittest.TestCase):
E = HashKey(100100, 'E')
F = HashKey(110, 'F')
- h = Map()
+ h = self.Map()
h = h.set(A, 'a')
h = h.set(B, 'b')
h = h.set(C, 'c')
@@ -554,7 +556,7 @@ class MapTest(unittest.TestCase):
E = HashKey(100100, 'E')
F = HashKey(110, 'F')
- h = Map()
+ h = self.Map()
h = h.set(A, 'a')
h = h.set(B, 'b')
h = h.set(C, 'c')
@@ -566,7 +568,7 @@ class MapTest(unittest.TestCase):
self.assertEqual(set(list(h)), {A, B, C, D, E, F})
def test_map_items_3(self):
- h = Map()
+ h = self.Map()
self.assertEqual(len(h.items()), 0)
self.assertEqual(list(h.items()), [])
@@ -577,13 +579,13 @@ class MapTest(unittest.TestCase):
D = HashKey(100100, 'D')
E = HashKey(120, 'E')
- h1 = Map()
+ h1 = self.Map()
h1 = h1.set(A, 'a')
h1 = h1.set(B, 'b')
h1 = h1.set(C, 'c')
h1 = h1.set(D, 'd')
- h2 = Map()
+ h2 = self.Map()
h2 = h2.set(A, 'a')
self.assertFalse(h1 == h2)
@@ -621,10 +623,10 @@ class MapTest(unittest.TestCase):
A = HashKey(100, 'A')
Er = HashKey(100, 'Er', error_on_eq_to=A)
- h1 = Map()
+ h1 = self.Map()
h1 = h1.set(A, 'a')
- h2 = Map()
+ h2 = self.Map()
h2 = h2.set(Er, 'a')
with self.assertRaisesRegex(ValueError, 'cannot compare'):
@@ -636,7 +638,7 @@ class MapTest(unittest.TestCase):
def test_map_gc_1(self):
A = HashKey(100, 'A')
- h = Map()
+ h = self.Map()
h = h.set(0, 0) # empty Map node is memoized in _map.c
ref = weakref.ref(h)
@@ -659,7 +661,7 @@ class MapTest(unittest.TestCase):
def test_map_gc_2(self):
A = HashKey(100, 'A')
- h = Map()
+ h = self.Map()
h = h.set(A, 'a')
h = h.set(A, h)
@@ -681,7 +683,7 @@ class MapTest(unittest.TestCase):
B = HashKey(101, 'B')
- h = Map()
+ h = self.Map()
h = h.set(A, 1)
self.assertTrue(A in h)
@@ -701,7 +703,7 @@ class MapTest(unittest.TestCase):
B = HashKey(101, 'B')
- h = Map()
+ h = self.Map()
h = h.set(A, 1)
self.assertEqual(h[A], 1)
@@ -719,7 +721,7 @@ class MapTest(unittest.TestCase):
h[AA]
def test_repr_1(self):
- h = Map()
+ h = self.Map()
self.assertTrue(repr(h).startswith('<immutables.Map({}) at 0x'))
h = h.set(1, 2).set(2, 3).set(3, 4)
@@ -727,7 +729,7 @@ class MapTest(unittest.TestCase):
'<immutables.Map({1: 2, 2: 3, 3: 4}) at 0x'))
def test_repr_2(self):
- h = Map()
+ h = self.Map()
A = HashKey(100, 'A')
with self.assertRaises(ReprError):
@@ -749,7 +751,7 @@ class MapTest(unittest.TestCase):
def __repr__(self):
return repr(self.val)
- h = Map()
+ h = self.Map()
k = Key()
h = h.set(k, 1)
k.val = h
@@ -758,7 +760,7 @@ class MapTest(unittest.TestCase):
'<immutables.Map({{...}: 1}) at 0x'))
def test_hash_1(self):
- h = Map()
+ h = self.Map()
self.assertNotEqual(hash(h), -1)
self.assertEqual(hash(h), hash(h))
@@ -771,7 +773,7 @@ class MapTest(unittest.TestCase):
hash(h.set('a', 'b').set(1, 2)))
def test_hash_2(self):
- h = Map()
+ h = self.Map()
A = HashKey(100, 'A')
m = h.set(1, 2).set(A, 3).set(3, 4)
@@ -785,5 +787,22 @@ class MapTest(unittest.TestCase):
hash(m)
+class PyMapTest(BaseMapTest, unittest.TestCase):
+
+ Map = PyMap
+
+
+try:
+ from immutables._map import Map as CMap
+except ImportError:
+ CMap = None
+
+
+@unittest.skipIf(CMap is None, 'C Map is not available')
+class CMapTest(BaseMapTest, unittest.TestCase):
+
+ Map = CMap
+
+
if __name__ == "__main__":
unittest.main()