summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--immutables/map.py31
-rw-r--r--tests/test_none_keys.py511
2 files changed, 530 insertions, 12 deletions
diff --git a/immutables/map.py b/immutables/map.py
index ac7ebd7..7c16139 100644
--- a/immutables/map.py
+++ b/immutables/map.py
@@ -46,6 +46,13 @@ def map_bitindex(bitmap, bit):
W_EMPTY, W_NEWNODE, W_NOT_FOUND = range(3)
void = object()
+class _Unhashable:
+ __slots__ = ()
+ __hash__ = None
+
+_NULL = _Unhashable()
+del _Unhashable
+
class BitmapNode:
@@ -70,7 +77,7 @@ class BitmapNode:
key_or_null = self.array[key_idx]
val_or_node = self.array[val_idx]
- if key_or_null is None:
+ if key_or_null is _NULL:
sub_node, added = val_or_node.assoc(
shift + 5, hash, key, val, mutid)
if val_or_node is sub_node:
@@ -111,12 +118,12 @@ class BitmapNode:
mutid)
if mutid and mutid == self.mutid:
- self.array[key_idx] = None
+ self.array[key_idx] = _NULL
self.array[val_idx] = sub_node
return self, True
else:
ret = self.clone(mutid)
- ret.array[key_idx] = None
+ ret.array[key_idx] = _NULL
ret.array[val_idx] = sub_node
return ret, True
@@ -153,7 +160,7 @@ class BitmapNode:
key_or_null = self.array[key_idx]
val_or_node = self.array[val_idx]
- if key_or_null is None:
+ if key_or_null is _NULL:
return val_or_node.find(shift + 5, hash, key)
if key == key_or_null:
@@ -173,7 +180,7 @@ class BitmapNode:
key_or_null = self.array[key_idx]
val_or_node = self.array[val_idx]
- if key_or_null is None:
+ if key_or_null is _NULL:
res, sub_node = val_or_node.without(shift + 5, hash, key, mutid)
if res is W_EMPTY:
@@ -182,7 +189,7 @@ class BitmapNode:
elif res is W_NEWNODE:
if (type(sub_node) is BitmapNode and
sub_node.size == 2 and
- sub_node.array[0] is not None):
+ sub_node.array[0] is not _NULL):
if mutid and mutid == self.mutid:
self.array[key_idx] = sub_node.array[0]
@@ -231,7 +238,7 @@ class BitmapNode:
for i in range(0, self.size, 2):
key_or_null = self.array[i]
- if key_or_null is None:
+ if key_or_null is _NULL:
val_or_node = self.array[i + 1]
yield from val_or_node.keys()
else:
@@ -242,7 +249,7 @@ class BitmapNode:
key_or_null = self.array[i]
val_or_node = self.array[i + 1]
- if key_or_null is None:
+ if key_or_null is _NULL:
yield from val_or_node.values()
else:
yield val_or_node
@@ -252,7 +259,7 @@ class BitmapNode:
key_or_null = self.array[i]
val_or_node = self.array[i + 1]
- if key_or_null is None:
+ if key_or_null is _NULL:
yield from val_or_node.items()
else:
yield key_or_null, val_or_node
@@ -269,8 +276,8 @@ class BitmapNode:
pad = ' ' * (level + 2)
- if key_or_null is None:
- buf.append(pad + 'None:')
+ if key_or_null is _NULL:
+ buf.append(pad + 'NULL:')
val_or_node.dump(buf, level + 2)
else:
buf.append(pad + '{!r}: {!r}'.format(key_or_null, val_or_node))
@@ -328,7 +335,7 @@ class CollisionNode:
else:
new_node = BitmapNode(
- 2, map_bitpos(self.hash, shift), [None, self], mutid)
+ 2, map_bitpos(self.hash, shift), [_NULL, self], mutid)
return new_node.assoc(shift, hash, key, val, mutid)
def without(self, shift, hash, key, mutid):
diff --git a/tests/test_none_keys.py b/tests/test_none_keys.py
new file mode 100644
index 0000000..3662e9c
--- /dev/null
+++ b/tests/test_none_keys.py
@@ -0,0 +1,511 @@
+import unittest
+
+from immutables.map import map_hash, map_mask, Map as PyMap
+from tests.test_map import HashKey
+
+
+none_hash = map_hash(None)
+assert(none_hash != 1)
+assert((none_hash >> 32) == 0)
+
+not_collision = 0xffffffff & (~none_hash)
+
+mask = 0x7ffffffff
+none_collisions = [none_hash & (mask >> shift)
+ for shift in reversed(range(0, 32, 5))]
+assert(len(none_collisions) == 7)
+none_collisions = [h | (not_collision & (mask << shift))
+ for shift, h in zip(range(5, 37, 5), none_collisions)]
+
+
+class NoneCollision(HashKey):
+ def __init__(self, name, level):
+ if name is None:
+ raise ValueError("Can't have a NoneCollision with a None value")
+ super().__init__(none_collisions[level], name)
+
+ def __eq__(self, other):
+ if other is None:
+ return False
+ return super().__eq__(other)
+
+ __hash__ = HashKey.__hash__
+
+
+class BaseNoneTest:
+ Map = None
+
+ def test_none_collisions(self):
+ collisions = [NoneCollision('a', level) for level in range(7)]
+ indices = [map_mask(none_hash, shift) for shift in range(0, 32, 5)]
+
+ for i, c in enumerate(collisions[:-1], 1):
+ self.assertNotEqual(c, None)
+ c_hash = map_hash(c)
+ self.assertNotEqual(c_hash, none_hash)
+ for j, idx in enumerate(indices[:i]):
+ self.assertEqual(map_mask(c_hash, j*5), idx)
+ for j, idx in enumerate(indices[i:], i):
+ self.assertNotEqual(map_mask(c_hash, j*5), idx)
+
+ c = collisions[-1]
+ self.assertNotEqual(c, None)
+ c_hash = map_hash(c)
+ self.assertEqual(c_hash, none_hash)
+ for i, idx in enumerate(indices):
+ self.assertEqual(map_mask(c_hash, i*5), idx)
+
+ def test_none_as_key(self):
+ m = self.Map({None: 1})
+
+ self.assertEqual(len(m), 1)
+ self.assertTrue(None in m)
+ self.assertEqual(m[None], 1)
+ self.assertTrue(repr(m).startswith('<immutables.Map({None: 1}) at 0x'))
+
+ for level in range(7):
+ key = NoneCollision('a', level)
+ self.assertFalse(key in m)
+ with self.assertRaises(KeyError):
+ m.delete(key)
+
+ m = m.delete(None)
+ self.assertEqual(len(m), 0)
+ self.assertFalse(None in m)
+ self.assertTrue(repr(m).startswith('<immutables.Map({}) at 0x'))
+
+ self.assertEqual(m, self.Map())
+
+ with self.assertRaises(KeyError):
+ m.delete(None)
+
+ def test_none_set(self):
+ m = self.Map().set(None, 2)
+
+ self.assertEqual(len(m), 1)
+ self.assertTrue(None in m)
+ self.assertEqual(m[None], 2)
+
+ m = m.set(None, 1)
+
+ self.assertEqual(len(m), 1)
+ self.assertTrue(None in m)
+ self.assertEqual(m[None], 1)
+
+ m = m.delete(None)
+
+ self.assertEqual(len(m), 0)
+ self.assertEqual(m, self.Map())
+ self.assertFalse(None in m)
+
+ with self.assertRaises(KeyError):
+ m.delete(None)
+
+ def test_none_collision_1(self):
+ for level in range(7):
+ key = NoneCollision('a', level)
+ m = self.Map({None: 1, key: 2})
+
+ self.assertEqual(len(m), 2)
+ self.assertTrue(None in m)
+ self.assertEqual(m[None], 1)
+ self.assertTrue(key in m)
+ self.assertEqual(m[key], 2)
+
+ m2 = m.delete(None)
+ self.assertEqual(len(m2), 1)
+ self.assertTrue(key in m2)
+ self.assertEqual(m2[key], 2)
+ self.assertFalse(None in m2)
+ with self.assertRaises(KeyError):
+ m2.delete(None)
+
+ m3 = m2.delete(key)
+ self.assertEqual(len(m3), 0)
+ self.assertFalse(None in m3)
+ self.assertFalse(key in m3)
+ self.assertEqual(m3, self.Map())
+ self.assertTrue(repr(m3).startswith('<immutables.Map({}) at 0x'))
+ with self.assertRaises(KeyError):
+ m3.delete(None)
+ with self.assertRaises(KeyError):
+ m3.delete(key)
+
+ m2 = m.delete(key)
+ self.assertEqual(len(m2), 1)
+ self.assertTrue(None in m2)
+ self.assertEqual(m2[None], 1)
+ self.assertFalse(key in m2)
+ with self.assertRaises(KeyError):
+ m2.delete(key)
+
+ m4 = m2.delete(None)
+ self.assertEqual(len(m4), 0)
+ self.assertFalse(None in m4)
+ self.assertFalse(key in m4)
+ self.assertEqual(m4, self.Map())
+ self.assertTrue(repr(m4).startswith('<immutables.Map({}) at 0x'))
+ with self.assertRaises(KeyError):
+ m4.delete(None)
+ with self.assertRaises(KeyError):
+ m4.delete(key)
+
+ self.assertEqual(m3, m4)
+
+ def test_none_collision_2(self):
+ key = HashKey(not_collision, 'a')
+ m = self.Map().set(None, 1).set(key, 2)
+
+ self.assertEqual(len(m), 2)
+ self.assertTrue(key in m)
+ self.assertTrue(None in m)
+ self.assertEqual(m[key], 2)
+ self.assertEqual
+
+ m = m.set(None, 0)
+ self.assertEqual(len(m), 2)
+ self.assertTrue(key in m)
+ self.assertTrue(None in m)
+
+ for level in range(7):
+ key2 = NoneCollision('b', level)
+ self.assertFalse(key2 in m)
+ m2 = m.set(key2, 1)
+
+ self.assertEqual(len(m2), 3)
+ self.assertTrue(key in m2)
+ self.assertTrue(None in m2)
+ self.assertTrue(key2 in m2)
+ self.assertEqual(m2[key], 2)
+ self.assertEqual(m2[None], 0)
+ self.assertEqual(m2[key2], 1)
+
+ m2 = m2.set(None, 1)
+ self.assertEqual(len(m2), 3)
+ self.assertTrue(key in m2)
+ self.assertTrue(None in m2)
+ self.assertTrue(key2 in m2)
+ self.assertEqual(m2[key], 2)
+ self.assertEqual(m2[None], 1)
+ self.assertEqual(m2[key2], 1)
+
+ m2 = m2.set(None, 2)
+ self.assertEqual(len(m2), 3)
+ self.assertTrue(key in m2)
+ self.assertTrue(None in m2)
+ self.assertTrue(key2 in m2)
+ self.assertEqual(m2[key], 2)
+ self.assertEqual(m2[None], 2)
+ self.assertEqual(m2[key2], 1)
+
+ m3 = m2.delete(key)
+ self.assertEqual(len(m3), 2)
+ self.assertTrue(None in m3)
+ self.assertTrue(key2 in m3)
+ self.assertFalse(key in m3)
+ self.assertEqual(m3[None], 2)
+ self.assertEqual(m3[key2], 1)
+ with self.assertRaises(KeyError):
+ m3.delete(key)
+
+ m3 = m2.delete(key2)
+ self.assertEqual(len(m3), 2)
+ self.assertTrue(None in m3)
+ self.assertTrue(key in m3)
+ self.assertFalse(key2 in m3)
+ self.assertEqual(m3[None], 2)
+ self.assertEqual(m3[key], 2)
+ with self.assertRaises(KeyError):
+ m3.delete(key2)
+
+ m3 = m2.delete(None)
+ self.assertEqual(len(m3), 2)
+ self.assertTrue(key in m3)
+ self.assertTrue(key2 in m3)
+ self.assertFalse(None in m3)
+ self.assertEqual(m3[key], 2)
+ self.assertEqual(m3[key2], 1)
+ with self.assertRaises(KeyError):
+ m3.delete(None)
+
+ m2 = m.delete(None)
+ self.assertEqual(len(m2), 1)
+ self.assertFalse(None in m2)
+ self.assertTrue(key in m2)
+ self.assertEqual(m2[key], 2)
+ with self.assertRaises(KeyError):
+ m2.delete(None)
+
+ m2 = m.delete(key)
+ self.assertEqual(len(m2), 1)
+ self.assertFalse(key in m2)
+ self.assertTrue(None in m2)
+ self.assertEqual(m2[None], 0)
+ with self.assertRaises(KeyError):
+ m2.delete(key)
+
+ def test_none_collision_3(self):
+ for level in range(7):
+ key = NoneCollision('a', level)
+ m = self.Map({key: 2})
+
+ self.assertEqual(len(m), 1)
+ self.assertFalse(None in m)
+ self.assertTrue(key in m)
+ self.assertEqual(m[key], 2)
+ with self.assertRaises(KeyError):
+ m.delete(None)
+
+ m = m.set(None, 1)
+ self.assertEqual(len(m), 2)
+ self.assertTrue(key in m)
+ self.assertEqual(m[key], 2)
+ self.assertTrue(None in m)
+ self.assertEqual(m[None], 1)
+
+ m = m.set(None, 0)
+ self.assertEqual(len(m), 2)
+ self.assertTrue(key in m)
+ self.assertEqual(m[key], 2)
+ self.assertTrue(None in m)
+ self.assertEqual(m[None], 0)
+
+ m2 = m.delete(key)
+ self.assertEqual(len(m2), 1)
+ self.assertTrue(None in m2)
+ self.assertEqual(m2[None], 0)
+ self.assertFalse(key in m2)
+ with self.assertRaises(KeyError):
+ m2.delete(key)
+
+ m2 = m.delete(None)
+ self.assertEqual(len(m2), 1)
+ self.assertTrue(key in m2)
+ self.assertEqual(m2[key], 2)
+ self.assertFalse(None in m2)
+ with self.assertRaises(KeyError):
+ m2.delete(None)
+
+ def test_collision_4(self):
+ key2 = NoneCollision('a', 2)
+ key4 = NoneCollision('b', 4)
+ m = self.Map({key2: 2, key4: 4})
+
+ self.assertEqual(len(m), 2)
+ self.assertTrue(key2 in m)
+ self.assertTrue(key4 in m)
+ self.assertEqual(m[key2], 2)
+ self.assertEqual(m[key4], 4)
+ self.assertFalse(None in m)
+
+ m2 = m.set(None, 9)
+
+ self.assertEqual(len(m2), 3)
+ self.assertTrue(key2 in m2)
+ self.assertTrue(key4 in m2)
+ self.assertTrue(None in m2)
+ self.assertEqual(m2[key2], 2)
+ self.assertEqual(m2[key4], 4)
+ self.assertEqual(m2[None], 9)
+
+ m3 = m2.set(None, 0)
+ self.assertEqual(len(m3), 3)
+ self.assertTrue(key2 in m3)
+ self.assertTrue(key4 in m3)
+ self.assertTrue(None in m3)
+ self.assertEqual(m3[key2], 2)
+ self.assertEqual(m3[key4], 4)
+ self.assertEqual(m3[None], 0)
+
+ m3 = m2.set(key2, 0)
+ self.assertEqual(len(m3), 3)
+ self.assertTrue(key2 in m3)
+ self.assertTrue(key4 in m3)
+ self.assertTrue(None in m3)
+ self.assertEqual(m3[key2], 0)
+ self.assertEqual(m3[key4], 4)
+ self.assertEqual(m3[None], 9)
+
+ m3 = m2.set(key4, 0)
+ self.assertEqual(len(m3), 3)
+ self.assertTrue(key2 in m3)
+ self.assertTrue(key4 in m3)
+ self.assertTrue(None in m3)
+ self.assertEqual(m3[key2], 2)
+ self.assertEqual(m3[key4], 0)
+ self.assertEqual(m3[None], 9)
+
+ m3 = m2.delete(None)
+ self.assertEqual(m3, m)
+ self.assertEqual(len(m3), 2)
+ self.assertTrue(key2 in m3)
+ self.assertTrue(key4 in m3)
+ self.assertEqual(m3[key2], 2)
+ self.assertEqual(m3[key4], 4)
+ self.assertFalse(None in m3)
+ with self.assertRaises(KeyError):
+ m3.delete(None)
+
+ m3 = m2.delete(key2)
+ self.assertEqual(len(m3), 2)
+ self.assertTrue(None in m3)
+ self.assertTrue(key4 in m3)
+ self.assertEqual(m3[None], 9)
+ self.assertEqual(m3[key4], 4)
+ self.assertFalse(key2 in m3)
+ with self.assertRaises(KeyError):
+ m3.delete(key2)
+
+ m3 = m2.delete(key4)
+ self.assertEqual(len(m3), 2)
+ self.assertTrue(None in m3)
+ self.assertTrue(key2 in m3)
+ self.assertEqual(m3[None], 9)
+ self.assertEqual(m3[key2], 2)
+ self.assertFalse(key4 in m3)
+ with self.assertRaises(KeyError):
+ m3.delete(key4)
+
+ def test_none_mutation(self):
+ key2 = NoneCollision('a', 2)
+ key4 = NoneCollision('b', 4)
+ key = NoneCollision('c', -1)
+ m = self.Map({key: -1, key2: 2, key4: 4, None: 9})
+
+ with m.mutate() as mm:
+ self.assertEqual(len(mm), 4)
+ self.assertTrue(key in mm)
+ self.assertTrue(key2 in mm)
+ self.assertTrue(key4 in mm)
+ self.assertTrue(None in mm)
+ self.assertEqual(mm[key2], 2)
+ self.assertEqual(mm[key4], 4)
+ self.assertEqual(mm[key], -1)
+ self.assertEqual(mm[None], 9)
+
+ for k in m:
+ mm[k] = -mm[k]
+
+ self.assertEqual(len(mm), 4)
+ self.assertTrue(key in mm)
+ self.assertTrue(key2 in mm)
+ self.assertTrue(key4 in mm)
+ self.assertTrue(None in mm)
+ self.assertEqual(mm[key2], -2)
+ self.assertEqual(mm[key4], -4)
+ self.assertEqual(mm[key], 1)
+ self.assertEqual(mm[None], -9)
+
+ for k in m:
+ del mm[k]
+ self.assertEqual(len(mm), 3)
+ self.assertFalse(k in mm)
+ for n in m:
+ if n != k:
+ self.assertTrue(n in mm)
+ self.assertEqual(mm[n], -m[n])
+ with self.assertRaises(KeyError):
+ del mm[k]
+ mm[k] = -m[k]
+ self.assertEqual(len(mm), 4)
+ self.assertTrue(k in mm)
+ self.assertEqual(mm[k], -m[k])
+
+ for k in m:
+ mm[k] = -mm[k]
+
+ self.assertEqual(len(mm), 4)
+ self.assertTrue(key in mm)
+ self.assertTrue(key2 in mm)
+ self.assertTrue(key4 in mm)
+ self.assertTrue(None in mm)
+ self.assertEqual(mm[key2], 2)
+ self.assertEqual(mm[key4], 4)
+ self.assertEqual(mm[key], -1)
+ self.assertEqual(mm[None], 9)
+
+ for k in m:
+ mm[k] = -mm[k]
+
+ self.assertEqual(len(mm), 4)
+ self.assertTrue(key in mm)
+ self.assertTrue(key2 in mm)
+ self.assertTrue(key4 in mm)
+ self.assertTrue(None in mm)
+ self.assertEqual(mm[key2], -2)
+ self.assertEqual(mm[key4], -4)
+ self.assertEqual(mm[key], 1)
+ self.assertEqual(mm[None], -9)
+
+ m2 = mm.finish()
+
+ self.assertEqual(set(m), set(m2))
+ self.assertEqual(len(m2), 4)
+ self.assertTrue(key in m2)
+ self.assertTrue(key2 in m2)
+ self.assertTrue(key4 in m2)
+ self.assertTrue(None in m2)
+ self.assertEqual(m2[key2], -2)
+ self.assertEqual(m2[key4], -4)
+ self.assertEqual(m2[key], 1)
+ self.assertEqual(m2[None], -9)
+
+ for k, v in m.items():
+ self.assertTrue(k in m2)
+ self.assertEqual(m2[k], -v)
+
+ def test_iterators(self):
+ key2 = NoneCollision('a', 2)
+ key4 = NoneCollision('b', 4)
+ key = NoneCollision('c', -1)
+ m = self.Map({key: -1, key2: 2, key4: 4, None: 9})
+
+ self.assertEqual(len(m), 4)
+ self.assertTrue(key in m)
+ self.assertTrue(key2 in m)
+ self.assertTrue(key4 in m)
+ self.assertTrue(None in m)
+ self.assertEqual(m[key2], 2)
+ self.assertEqual(m[key4], 4)
+ self.assertEqual(m[key], -1)
+ self.assertEqual(m[None], 9)
+
+ s = set(m)
+ self.assertEqual(len(s), 4)
+ self.assertEqual(s, set([None, key, key2, key4]))
+
+ sk = set(m.keys())
+ self.assertEqual(s, sk)
+
+ sv = set(m.values())
+ self.assertEqual(len(sv), 4)
+ self.assertEqual(sv, set([-1, 2, 4, 9]))
+
+ si = set(m.items())
+ self.assertEqual(len(si), 4)
+ self.assertEqual(si,
+ set([(key, -1), (key2, 2), (key4, 4), (None, 9)]))
+
+ d = {key: -1, key2: 2, key4: 4, None: 9}
+ self.assertEqual(dict(m.items()), d)
+
+
+class PyMapNoneTest(BaseNoneTest, unittest.TestCase):
+
+ Map = PyMap
+
+
+try:
+ from immutables._map import Map as CMap
+except ImportError:
+ CMap = None
+
+
+@unittest.skipIf(CMap is None, 'C Map is not available')
+class CMapNoneTest(BaseNoneTest, unittest.TestCase):
+
+ Map = CMap
+
+
+if __name__ == "__main__":
+ unittest.main()