summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--immutables/_map.c89
-rw-r--r--tests/test_map.py54
2 files changed, 142 insertions, 1 deletions
diff --git a/immutables/_map.c b/immutables/_map.c
index e1efe19..6f76318 100644
--- a/immutables/_map.c
+++ b/immutables/_map.c
@@ -2854,6 +2854,94 @@ map_py_dump(MapObject *self, PyObject *args)
}
+static PyObject *
+map_py_repr(MapObject *self)
+{
+ Py_ssize_t i;
+ _PyUnicodeWriter writer;
+
+
+ i = Py_ReprEnter((PyObject *)self);
+ if (i != 0) {
+ return i > 0 ? PyUnicode_FromString("{...}") : NULL;
+ }
+
+ _PyUnicodeWriter_Init(&writer);
+
+ if (_PyUnicodeWriter_WriteASCIIString(
+ &writer, "<immutables.Map({", 17) < 0)
+ {
+ goto error;
+ }
+
+ MapIteratorState iter;
+ map_iter_t iter_res;
+ map_iterator_init(&iter, self->h_root);
+ int second = 0;
+ do {
+ PyObject *v_key;
+ PyObject *v_val;
+
+ iter_res = map_iterator_next(&iter, &v_key, &v_val);
+ if (iter_res == I_ITEM) {
+ if (second) {
+ if (_PyUnicodeWriter_WriteASCIIString(&writer, ", ", 2) < 0) {
+ goto error;
+ }
+ }
+
+ PyObject *s = PyObject_Repr(v_key);
+ if (s == NULL) {
+ goto error;
+ }
+ if (_PyUnicodeWriter_WriteStr(&writer, s) < 0) {
+ Py_DECREF(s);
+ goto error;
+ }
+ Py_DECREF(s);
+
+ if (_PyUnicodeWriter_WriteASCIIString(&writer, ": ", 2) < 0) {
+ goto error;
+ }
+
+ s = PyObject_Repr(v_val);
+ if (s == NULL) {
+ goto error;
+ }
+ if (_PyUnicodeWriter_WriteStr(&writer, s) < 0) {
+ Py_DECREF(s);
+ goto error;
+ }
+ Py_DECREF(s);
+ }
+
+ second = 1;
+ } while (iter_res != I_END);
+
+ if (_PyUnicodeWriter_WriteASCIIString(&writer, "})", 2) < 0) {
+ goto error;
+ }
+
+ PyObject *addr = PyUnicode_FromFormat(" at %p>", self);
+ if (addr == NULL) {
+ goto error;
+ }
+ if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) {
+ Py_DECREF(addr);
+ goto error;
+ }
+ Py_DECREF(addr);
+
+ Py_ReprLeave((PyObject *)self);
+ return _PyUnicodeWriter_Finish(&writer);
+
+error:
+ _PyUnicodeWriter_Dealloc(&writer);
+ Py_ReprLeave((PyObject *)self);
+ return NULL;
+}
+
+
static PyMethodDef Map_methods[] = {
{"set", (PyCFunction)map_py_set, METH_VARARGS, NULL},
{"get", (PyCFunction)map_py_get, METH_VARARGS, NULL},
@@ -2900,6 +2988,7 @@ PyTypeObject _Map_Type = {
.tp_new = map_tp_new,
.tp_weaklistoffset = offsetof(MapObject, h_weakreflist),
.tp_hash = PyObject_HashNotImplemented,
+ .tp_repr = (reprfunc)map_py_repr,
};
diff --git a/tests/test_map.py b/tests/test_map.py
index a17e005..06c82b9 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -16,6 +16,8 @@ class HashKey:
self.error_on_eq_to = error_on_eq_to
def __repr__(self):
+ if self._crasher is not None and self._crasher.error_on_repr:
+ raise ReprError
return '<Key name:{} hash:{}>'.format(self.name, self.hash)
def __hash__(self):
@@ -51,12 +53,19 @@ class KeyStr(str):
raise EqError
return super().__eq__(other)
+ def __repr__(self, other):
+ if HashKey._crasher is not None and HashKey._crasher.error_on_repr:
+ raise ReprError
+ return super().__eq__(other)
+
class HaskKeyCrasher:
- def __init__(self, *, error_on_hash=False, error_on_eq=False):
+ def __init__(self, *, error_on_hash=False, error_on_eq=False,
+ error_on_repr=False):
self.error_on_hash = error_on_hash
self.error_on_eq = error_on_eq
+ self.error_on_repr = error_on_repr
def __enter__(self):
if HashKey._crasher is not None:
@@ -75,6 +84,10 @@ class EqError(Exception):
pass
+class ReprError(Exception):
+ pass
+
+
class MapTest(unittest.TestCase):
def test_hashkey_helper_1(self):
@@ -705,6 +718,45 @@ class MapTest(unittest.TestCase):
with HaskKeyCrasher(error_on_hash=True):
h[AA]
+ def test_repr_1(self):
+ h = Map()
+ self.assertTrue(repr(h).startswith('<immutables.Map({}) at 0x'))
+
+ h = h.set(1, 2).set(2, 3).set(3, 4)
+ self.assertTrue(repr(h).startswith(
+ '<immutables.Map({1: 2, 2: 3, 3: 4}) at 0x'))
+
+ def test_repr_2(self):
+ h = Map()
+ A = HashKey(100, 'A')
+
+ with self.assertRaises(ReprError):
+ with HaskKeyCrasher(error_on_repr=True):
+ repr(h.set(1, 2).set(A, 3).set(3, 4))
+
+ with self.assertRaises(ReprError):
+ with HaskKeyCrasher(error_on_repr=True):
+ repr(h.set(1, 2).set(2, A).set(3, 4))
+
+ def test_repr_3(self):
+ class Key:
+ def __init__(self):
+ self.val = None
+
+ def __hash__(self):
+ return 123
+
+ def __repr__(self):
+ return repr(self.val)
+
+ h = Map()
+ k = Key()
+ h = h.set(k, 1)
+ k.val = h
+
+ self.assertTrue(repr(h).startswith(
+ '<immutables.Map({{...}: 1}) at 0x'))
+
if __name__ == "__main__":
unittest.main()