1 import concurrent.futures
2 import contextvars
3 import functools
4 import gc
5 import random
6 import time
7 import unittest
8 import weakref
9 from test import support
10 from test.support import threading_helper
11
12 try:
13 from _testcapi import hamt
14 except ImportError:
15 hamt = None
16
17
def isolated_context(func):
    """Decorate *func* so each call runs inside a brand-new Context.

    Needed to make reftracking test mode work.
    """
    @functools.wraps(func)
    def run_in_fresh_context(*args, **kwargs):
        # Variables set by the wrapped callable land in this throwaway
        # Context instead of the caller's one.
        return contextvars.Context().run(func, *args, **kwargs)

    return run_in_fresh_context
25
26
class ContextTest(unittest.TestCase):
    # Tests for contextvars.ContextVar, contextvars.Context and
    # contextvars.Token.  Comments are used instead of docstrings so that
    # unittest's verbose output keeps showing the method names.

    def test_context_var_new_1(self):
        # The name argument is mandatory and has to be a str.
        with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
            contextvars.ContextVar()

        with self.assertRaisesRegex(TypeError, 'must be a str'):
            contextvars.ContextVar(1)

        var = contextvars.ContextVar('aaa')
        self.assertEqual(var.name, 'aaa')

        # The name cannot be reassigned afterwards.
        with self.assertRaises(AttributeError):
            var.name = 'bbb'

        self.assertNotEqual(hash(var), hash('aaa'))

    @isolated_context
    def test_context_var_repr_1(self):
        plain = contextvars.ContextVar('a')
        self.assertIn('a', repr(plain))

        with_default = contextvars.ContextVar('a', default=123)
        self.assertIn('123', repr(with_default))

        # The default refers back to the variable; both reprs are expected
        # to contain '...'.
        cycle = []
        recursive = contextvars.ContextVar('a', default=cycle)
        cycle.append(recursive)
        self.assertIn('...', repr(recursive))
        self.assertIn('...', repr(cycle))

        # A token's repr embeds the variable's repr and gains ' used '
        # once the token has been passed to reset().
        token = recursive.set(1)
        self.assertIn(repr(recursive), repr(token))
        self.assertNotIn(' used ', repr(token))
        recursive.reset(token)
        self.assertIn(' used ', repr(token))

    def test_context_subclassing_1(self):
        # None of the three types may be used as a base class.
        not_subclassable = 'not an acceptable base type'

        with self.assertRaisesRegex(TypeError, not_subclassable):
            class MyContextVar(contextvars.ContextVar):
                # Potentially we might want ContextVars to be subclassable.
                pass

        with self.assertRaisesRegex(TypeError, not_subclassable):
            class MyContext(contextvars.Context):
                pass

        with self.assertRaisesRegex(TypeError, not_subclassable):
            class MyToken(contextvars.Token):
                pass

    def test_context_new_1(self):
        # Context() rejects positional and keyword arguments alike.
        no_arguments = 'any arguments'

        with self.assertRaisesRegex(TypeError, no_arguments):
            contextvars.Context(1)
        with self.assertRaisesRegex(TypeError, no_arguments):
            contextvars.Context(1, a=1)
        with self.assertRaisesRegex(TypeError, no_arguments):
            contextvars.Context(a=1)

        # Unpacking an empty mapping passes nothing and must be accepted.
        contextvars.Context(**{})

    def test_context_typerrors_1(self):
        # Subscript, containment and get() all insist on ContextVar keys.
        ctx = contextvars.Context()
        bad_key = 'ContextVar key was expected'

        with self.assertRaisesRegex(TypeError, bad_key):
            ctx[1]
        with self.assertRaisesRegex(TypeError, bad_key):
            1 in ctx
        with self.assertRaisesRegex(TypeError, bad_key):
            ctx.get(1)

    def test_context_get_context_1(self):
        snapshot = contextvars.copy_context()
        self.assertIsInstance(snapshot, contextvars.Context)

    def test_context_run_1(self):
        # run() needs at least the callable.
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'missing 1 required'):
            ctx.run()

    def test_context_run_2(self):
        ctx = contextvars.Context()

        def echo(*args, **kwargs):
            # kwargs is mutated here; the loop below verifies that the
            # caller's own dict is left untouched.
            kwargs['spam'] = 'foo'
            args += ('bar',)
            return args, kwargs

        for target in (echo, functools.partial(echo)):
            # partial doesn't support FASTCALL

            self.assertEqual(ctx.run(target), (('bar',), {'spam': 'foo'}))
            self.assertEqual(
                ctx.run(target, 1),
                ((1, 'bar'), {'spam': 'foo'}))

            self.assertEqual(
                ctx.run(target, a=2),
                (('bar',), {'a': 2, 'spam': 'foo'}))

            self.assertEqual(
                ctx.run(target, 11, a=2),
                ((11, 'bar'), {'a': 2, 'spam': 'foo'}))

            caller_kwargs = {}
            self.assertEqual(
                ctx.run(target, 11, **caller_kwargs),
                ((11, 'bar'), {'spam': 'foo'}))
            self.assertEqual(caller_kwargs, {})

    def test_context_run_3(self):
        # Exceptions raised by the callable propagate out of run().
        ctx = contextvars.Context()

        def divide_by_zero(*args, **kwargs):
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(divide_by_zero)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(divide_by_zero, 1, 2)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(divide_by_zero, 1, 2, a=123)

    @isolated_context
    def test_context_run_4(self):
        outer_ctx = contextvars.Context()
        inner_ctx = contextvars.Context()
        var = contextvars.ContextVar('var')

        def in_inner():
            # The value set in outer_ctx is not visible from inner_ctx.
            self.assertIsNone(var.get(None))

        def in_outer():
            self.assertIsNone(var.get(None))
            var.set('spam')
            inner_ctx.run(in_inner)
            self.assertEqual(var.get(None), 'spam')

            snapshot = contextvars.copy_context()
            self.assertEqual(len(snapshot), 1)
            self.assertEqual(snapshot[var], 'spam')
            return snapshot

        returned_ctx = outer_ctx.run(in_outer)
        self.assertEqual(outer_ctx, returned_ctx)
        self.assertEqual(returned_ctx[var], 'spam')
        self.assertIn(var, returned_ctx)

    def test_context_run_5(self):
        ctx = contextvars.Context()
        var = contextvars.ContextVar('var')

        def set_then_fail():
            self.assertIsNone(var.get(None))
            var.set('spam')
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(set_then_fail)

        # The value set inside ctx is not visible out here.
        self.assertIsNone(var.get(None))

    def test_context_run_6(self):
        ctx = contextvars.Context()
        var = contextvars.ContextVar('a', default=0)

        def check():
            # Before set(): var.get() yields the default while the Context
            # mapping has no entry for the variable.
            self.assertEqual(var.get(), 0)
            self.assertIsNone(ctx.get(var))

            var.set(42)
            self.assertEqual(var.get(), 42)
            self.assertEqual(ctx.get(var), 42)

        ctx.run(check)

    def test_context_run_7(self):
        ctx = contextvars.Context()

        def reenter():
            # ctx is currently entered, so running it again must fail.
            with self.assertRaisesRegex(RuntimeError, 'is already entered'):
                ctx.run(reenter)

        ctx.run(reenter)

    @isolated_context
    def test_context_getset_1(self):
        var = contextvars.ContextVar('c')
        with self.assertRaises(LookupError):
            var.get()

        self.assertIsNone(var.get(None))

        first_token = var.set(42)
        self.assertEqual(var.get(), 42)
        self.assertEqual(var.get(None), 42)
        self.assertIs(first_token.old_value, first_token.MISSING)
        self.assertIs(first_token.old_value, contextvars.Token.MISSING)
        self.assertIs(first_token.var, var)

        second_token = var.set('spam')
        self.assertEqual(var.get(), 'spam')
        self.assertEqual(var.get(None), 'spam')
        self.assertEqual(second_token.old_value, 42)
        var.reset(second_token)

        self.assertEqual(var.get(), 42)
        self.assertEqual(var.get(None), 42)

        # A token can be consumed only once.
        var.set('spam2')
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            var.reset(second_token)
        self.assertEqual(var.get(), 'spam2')

        with_value = contextvars.copy_context()
        self.assertIn(var, with_value)

        var.reset(first_token)
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            var.reset(first_token)
        self.assertIsNone(var.get(None))

        # The copy taken earlier still holds the value it was taken with.
        self.assertIn(var, with_value)
        self.assertEqual(with_value[var], 'spam2')
        self.assertEqual(with_value.get(var, 'aa'), 'spam2')
        self.assertEqual(len(with_value), 1)
        self.assertEqual(list(with_value.items()), [(var, 'spam2')])
        self.assertEqual(list(with_value.values()), ['spam2'])
        self.assertEqual(list(with_value.keys()), [var])
        self.assertEqual(list(with_value), [var])

        without_value = contextvars.copy_context()
        self.assertNotIn(var, without_value)
        with self.assertRaises(KeyError):
            without_value[var]
        self.assertEqual(without_value.get(var, 'aa'), 'aa')
        self.assertEqual(len(without_value), 0)
        self.assertEqual(list(without_value), [])

    @isolated_context
    def test_context_getset_2(self):
        # A token can only reset the variable that produced it.
        producer = contextvars.ContextVar('v1')
        other = contextvars.ContextVar('v2')

        token = producer.set(42)
        with self.assertRaisesRegex(ValueError, 'by a different'):
            other.reset(token)

    @isolated_context
    def test_context_getset_3(self):
        var = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        def check():
            # The default is reported by var.get() but is not stored in
            # the Context mapping.
            self.assertEqual(var.get(), 42)
            with self.assertRaises(KeyError):
                ctx[var]
            self.assertIsNone(ctx.get(var))
            self.assertEqual(ctx.get(var, 'spam'), 'spam')
            self.assertNotIn(var, ctx)
            self.assertEqual(list(ctx.keys()), [])

            token = var.set(1)
            self.assertEqual(list(ctx.keys()), [var])
            self.assertEqual(ctx[var], 1)

            var.reset(token)
            self.assertEqual(list(ctx.keys()), [])
            with self.assertRaises(KeyError):
                ctx[var]

        ctx.run(check)

    @isolated_context
    def test_context_getset_4(self):
        var = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        # The token is created inside ctx but used outside of it.
        foreign_token = ctx.run(var.set, 1)

        with self.assertRaisesRegex(ValueError, 'different Context'):
            var.reset(foreign_token)

    @isolated_context
    def test_context_getset_5(self):
        var = contextvars.ContextVar('c', default=42)
        var.set([])

        def mutate_in_copy():
            var.set([])
            var.get().append(42)
            self.assertEqual(var.get(), [42])

        contextvars.copy_context().run(mutate_in_copy)
        # The list set in this context was not the one appended to.
        self.assertEqual(var.get(), [])

    def test_context_copy_1(self):
        ctx1 = contextvars.Context()
        var = contextvars.ContextVar('c', default=42)

        def ctx1_fun():
            var.set(10)

            ctx2 = ctx1.copy()
            self.assertEqual(ctx2[var], 10)

            # Changes made in ctx1 after the copy do not reach ctx2 ...
            var.set(20)
            self.assertEqual(ctx1[var], 20)
            self.assertEqual(ctx2[var], 10)

            # ... and changes made inside ctx2 do not reach ctx1.
            ctx2.run(ctx2_fun)
            self.assertEqual(ctx1[var], 20)
            self.assertEqual(ctx2[var], 30)

        def ctx2_fun():
            self.assertEqual(var.get(), 10)
            var.set(30)
            self.assertEqual(var.get(), 30)

        ctx1.run(ctx1_fun)

    @isolated_context
    @threading_helper.requires_working_threading()
    def test_context_threads_1(self):
        cvar = contextvars.ContextVar('cvar')

        def worker(num):
            for step in range(10):
                expected = num + step
                cvar.set(expected)
                time.sleep(random.uniform(0.001, 0.05))
                # After the sleep the value must still be the one this
                # worker set.
                self.assertEqual(cvar.get(), expected)
            return num

        # Leaving the with-block shuts the pool down, waiting for workers.
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
            results = list(pool.map(worker, range(10)))
        self.assertEqual(results, list(range(10)))
363
364
365 # HAMT Tests
366
367
class HashKey:
    """Key object whose hash value is chosen by the caller.

    Two keys compare equal when both ``name`` and ``hash`` match.  While a
    crasher object is installed in ``_crasher``, hashing or comparing may
    raise HashingError / EqError depending on the crasher's flags.
    ``error_on_eq_to`` names one specific key that this key refuses to be
    compared with (ValueError, in either operand order).
    """

    # Currently installed crasher object, or None.
    _crasher = None

    def __init__(self, hash, name, *, error_on_eq_to=None):
        assert hash != -1
        self.hash = hash
        self.name = name
        self.error_on_eq_to = error_on_eq_to

    def __repr__(self):
        return f'<Key name:{self.name} hash:{self.hash}>'

    def __hash__(self):
        crasher = self._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError

        return self.hash

    def __eq__(self, other):
        if not isinstance(other, HashKey):
            return NotImplemented

        crasher = self._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError

        # Check the "refuses comparison" marker in both directions; the
        # key that carries the marker is always named first in the message.
        for owner, peer in ((self, other), (other, self)):
            forbidden = owner.error_on_eq_to
            if forbidden is not None and forbidden is peer:
                raise ValueError(f'cannot compare {owner!r} to {peer!r}')

        return (self.name, self.hash) == (other.name, other.hash)
399
400
class KeyStr(str):
    """str subclass whose hashing and equality honour HashKey._crasher.

    With no crasher installed it hashes and compares like a plain str.
    """

    def __hash__(self):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return super().__hash__()

    def __eq__(self, other):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError
        return super().__eq__(other)
411
412
class HaskKeyCrasher:
    """Context manager that installs itself as ``HashKey._crasher``.

    ``error_on_hash`` / ``error_on_eq`` are the flags HashKey and KeyStr
    consult in ``__hash__`` / ``__eq__``.  Crashers cannot be nested:
    entering while another one is installed raises RuntimeError.
    """

    def __init__(self, *, error_on_hash=False, error_on_eq=False):
        self.error_on_eq = error_on_eq
        self.error_on_hash = error_on_hash

    def __enter__(self):
        another_is_active = HashKey._crasher is not None
        if another_is_active:
            raise RuntimeError('cannot nest crashers')
        HashKey._crasher = self

    def __exit__(self, *exc):
        # Always uninstall; returns None so exceptions propagate.
        HashKey._crasher = None
425
426
class HashingError(Exception):
    """Raised from __hash__ of HashKey/KeyStr while a crasher with
    error_on_hash set is installed."""
429
430
class EqError(Exception):
    """Raised from __eq__ of HashKey/KeyStr while a crasher with
    error_on_eq set is installed."""
433
434
@unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function')
class HamtTest(unittest.TestCase):
    # Tests for the HAMT mapping exposed through _testcapi.hamt().
    #
    # The tests rely on HashKey (caller-chosen hash values) to steer keys
    # into particular tree shapes, and on HaskKeyCrasher to make __hash__ /
    # __eq__ fail on demand.  The tree dumps in the comments are kept from
    # the original file; their nesting was reconstructed from the node
    # sizes they show.

    def test_hashkey_helper_1(self):
        # Sanity check of the helper: equal hashes, unequal keys.
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')

        self.assertNotEqual(k1, k2)
        self.assertEqual(hash(k1), hash(k2))

        d = dict()
        d[k1] = 'a'
        d[k2] = 'b'

        self.assertEqual(d[k1], 'a')
        self.assertEqual(d[k2], 'b')

    def test_hamt_basics_1(self):
        # Create and drop an empty HAMT.
        h = hamt()
        h = None  # NoQA

    def test_hamt_basics_2(self):
        # set() returns a new mapping and leaves the old one untouched.
        h = hamt()
        self.assertEqual(len(h), 0)

        h2 = h.set('a', 'b')
        self.assertIsNot(h, h2)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)

        self.assertIsNone(h.get('a'))
        self.assertEqual(h.get('a', 42), 42)

        self.assertEqual(h2.get('a'), 'b')

        h3 = h2.set('b', 10)
        self.assertIsNot(h2, h3)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(h3.get('a'), 'b')
        self.assertEqual(h3.get('b'), 10)

        self.assertIsNone(h.get('b'))
        self.assertIsNone(h2.get('b'))

        self.assertIsNone(h.get('a'))
        self.assertEqual(h2.get('a'), 'b')

        h = h2 = h3 = None

    def test_hamt_basics_3(self):
        # Setting a key to the very same object returns the same mapping.
        h = hamt()
        o = object()
        h1 = h.set('1', o)
        h2 = h1.set('1', o)
        self.assertIs(h1, h2)

    def test_hamt_basics_4(self):
        # Setting a key to a different (even if equal) object yields a new
        # mapping holding the new object.
        h = hamt()
        h1 = h.set('key', [])
        h2 = h1.set('key', [])
        self.assertIsNot(h1, h2)
        self.assertEqual(len(h1), 1)
        self.assertEqual(len(h2), 1)
        self.assertIsNot(h1.get('key'), h2.get('key'))

    def test_hamt_collision_1(self):
        # All three keys share hash 10.
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')
        k3 = HashKey(10, 'ccc')

        h = hamt()
        h2 = h.set(k1, 'a')
        h3 = h2.set(k2, 'b')

        self.assertEqual(h.get(k1), None)
        self.assertEqual(h.get(k2), None)

        self.assertEqual(h2.get(k1), 'a')
        self.assertEqual(h2.get(k2), None)

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')

        h4 = h3.set(k2, 'cc')
        h5 = h4.set(k3, 'aa')

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')
        self.assertEqual(h4.get(k1), 'a')
        self.assertEqual(h4.get(k2), 'cc')
        self.assertEqual(h4.get(k3), None)
        self.assertEqual(h5.get(k1), 'a')
        self.assertEqual(h5.get(k2), 'cc')
        self.assertEqual(h5.get(k3), 'aa')

        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(len(h4), 2)
        self.assertEqual(len(h5), 3)

    def test_hamt_collision_3(self):
        # Test that iteration works with the deepest tree possible.
        # https://github.com/python/cpython/issues/93065

        C = HashKey(0b10000000_00000000_00000000_00000000, 'C')
        D = HashKey(0b10000000_00000000_00000000_00000000, 'D')

        E = HashKey(0b00000000_00000000_00000000_00000000, 'E')

        h = hamt()
        h = h.set(C, 'C')
        h = h.set(D, 'D')
        h = h.set(E, 'E')

        # BitmapNode(size=2 count=1 bitmap=0b1):
        #   NULL:
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #       NULL:
        #         BitmapNode(size=2 count=1 bitmap=0b1):
        #           NULL:
        #             BitmapNode(size=2 count=1 bitmap=0b1):
        #               NULL:
        #                 BitmapNode(size=2 count=1 bitmap=0b1):
        #                   NULL:
        #                     BitmapNode(size=2 count=1 bitmap=0b1):
        #                       NULL:
        #                         BitmapNode(size=4 count=2 bitmap=0b101):
        #                           <Key name:E hash:0>: 'E'
        #                           NULL:
        #                             CollisionNode(size=4 id=0x107a24520):
        #                               <Key name:C hash:2147483648>: 'C'
        #                               <Key name:D hash:2147483648>: 'D'

        self.assertEqual({k.name for k in h.keys()}, {'C', 'D', 'E'})

    @support.requires_resource('cpu')
    def test_hamt_stress(self):
        COLLECTION_SIZE = 7000
        TEST_ITERS_EVERY = 647
        CRASH_HASH_EVERY = 97
        CRASH_EQ_EVERY = 11
        RUN_XTIMES = 3

        for _ in range(RUN_XTIMES):
            h = hamt()
            d = dict()

            # Phase 1: fill the HAMT and a reference dict in lockstep,
            # periodically injecting __hash__ / __eq__ failures.
            for i in range(COLLECTION_SIZE):
                key = KeyStr(i)

                if not (i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.set(key, i)

                h = h.set(key, i)

                if not (i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.get(KeyStr(i))  # really trigger __eq__

                d[key] = i
                self.assertEqual(len(d), len(h))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.items()), set(d.items()))
                    self.assertEqual(len(h.items()), len(d.items()))

            self.assertEqual(len(h), COLLECTION_SIZE)

            for key in range(COLLECTION_SIZE):
                self.assertEqual(h.get(KeyStr(key), 'not found'), key)

            # Phase 2: delete everything in random order, again with
            # injected failures, remembering the state at the midpoint.
            keys_to_delete = list(range(COLLECTION_SIZE))
            random.shuffle(keys_to_delete)
            for iter_i, i in enumerate(keys_to_delete):
                key = KeyStr(i)

                if not (iter_i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.delete(key)

                if not (iter_i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.delete(KeyStr(i))

                h = h.delete(key)
                self.assertEqual(h.get(key, 'not found'), 'not found')
                del d[key]
                self.assertEqual(len(d), len(h))

                if iter_i == COLLECTION_SIZE // 2:
                    hm = h
                    dm = d.copy()

                if not (iter_i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.keys()), set(d.keys()))
                    self.assertEqual(len(h.keys()), len(d.keys()))

            self.assertEqual(len(d), 0)
            self.assertEqual(len(h), 0)

            # ============

            # Phase 3: the midpoint snapshot must be unaffected by the
            # later deletions; drain it with plain-str keys.  The checks
            # below look at hm/dm (the snapshot being drained); checking
            # h/d here would be vacuous because both are already empty.
            for key in dm:
                self.assertEqual(hm.get(str(key)), dm[key])
            self.assertEqual(len(dm), len(hm))

            for i, key in enumerate(keys_to_delete):
                hm = hm.delete(str(key))
                self.assertEqual(hm.get(str(key), 'not found'), 'not found')
                dm.pop(str(key), None)
                self.assertEqual(len(dm), len(hm))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(hm.values()), set(dm.values()))
                    self.assertEqual(len(hm.values()), len(dm.values()))

            self.assertEqual(len(dm), 0)
            self.assertEqual(len(hm), 0)
            self.assertEqual(list(hm.items()), [])

    def test_hamt_delete_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(102, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(103, 'Er', error_on_eq_to=D)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:C hash:102>: 'c'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        # Er shares D's hash and refuses to be compared with D.
        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 2)

        # Deleting a missing key returns the same mapping.
        h2 = h.delete(Z)
        self.assertIs(h2, h)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(A, 42), 42)
        self.assertEqual(h.get(B), 'b')
        self.assertEqual(h.get(E), 'e')

    def test_hamt_delete_2(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(201001, 'Er', error_on_eq_to=B)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=8 bitmap=0b1110010000):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b100000000001000000000):
        #             <Key name:B hash:201001>: 'b'
        #             <Key name:C hash:101001>: 'c'

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(Z)
        self.assertEqual(len(h), orig_len)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(B)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(D), 'd')
        self.assertEqual(h.get(E), 'e')

        # A and B are already gone at this point; deleting them again is
        # part of what is being exercised.
        h = h.delete(A)
        h = h.delete(B)
        h = h.delete(D)
        h = h.delete(E)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_3(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(104, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=6 bitmap=0b100110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=4 id=0x108572410):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        self.assertEqual(h.get(C), 'c')
        self.assertEqual(h.get(B), 'b')

    def test_hamt_delete_4(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=4 bitmap=0b110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=6 id=0x10515ef30):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #                     <Key name:E hash:100100>: 'e'
        #     <Key name:B hash:101>: 'b'

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 3)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 4)

        h = h.delete(B)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_5(self):
        h = hamt()

        keys = []
        for i in range(17):
            key = HashKey(i, str(i))
            keys.append(key)
            h = h.set(key, f'val-{i}')

        # Same hash as keys[16] but a different name.
        collision_key16 = HashKey(16, '18')
        h = h.set(collision_key16, 'collision')

        # ArrayNode(id=0x10f8b9318):
        #     0::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:0 hash:0>: 'val-0'
        #
        # ... 14 more BitmapNodes ...
        #
        #     15::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:15 hash:15>: 'val-15'
        #
        #     16::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         NULL:
        #             CollisionNode(size=4 id=0x10f2f5af8):
        #                 <Key name:16 hash:16>: 'val-16'
        #                 <Key name:18 hash:16>: 'collision'

        self.assertEqual(len(h), 18)

        h = h.delete(keys[2])
        self.assertEqual(len(h), 17)

        h = h.delete(collision_key16)
        self.assertEqual(len(h), 16)
        h = h.delete(keys[16])
        self.assertEqual(len(h), 15)

        # The second delete of keys[1] must be a no-op.
        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)
        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)

        for key in keys:
            h = h.delete(key)
        self.assertEqual(len(h), 0)

    def test_hamt_items_1(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_items_2(self):
        # Same as test_hamt_items_1 but with three keys sharing one hash.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_keys_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        # keys() and plain iteration must yield the same key set.
        self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
        self.assertEqual(set(list(h)), {A, B, C, D, E, F})

    def test_hamt_items_3(self):
        h = hamt()
        self.assertEqual(len(h.items()), 0)
        self.assertEqual(list(h.items()), [])

    def test_hamt_eq_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(120, 'E')

        h1 = hamt()
        h1 = h1.set(A, 'a')
        h1 = h1.set(B, 'b')
        h1 = h1.set(C, 'c')
        h1 = h1.set(D, 'd')

        # h2 is grown step by step; it equals h1 only at the one point
        # where it holds exactly the same four items.
        h2 = hamt()
        h2 = h2.set(A, 'a')

        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(B, 'b')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(C, 'c')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd2')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd')
        self.assertTrue(h1 == h2)
        self.assertFalse(h1 != h2)

        h2 = h2.set(E, 'e')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.delete(D)
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(E, 'd')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

    def test_hamt_eq_2(self):
        # An exception raised while comparing keys propagates out of
        # both == and !=.
        A = HashKey(100, 'A')
        Er = HashKey(100, 'Er', error_on_eq_to=A)

        h1 = hamt()
        h1 = h1.set(A, 'a')

        h2 = hamt()
        h2 = h2.set(Er, 'a')

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 == h2

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 != h2

    def test_hamt_gc_1(self):
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(0, 0)  # empty HAMT node is memoized in hamt.c
        ref = weakref.ref(h)

        # Build reference cycles that include the HAMT.
        a = []
        a.append(a)
        a.append(h)
        b = []
        a.append(b)
        b.append(a)
        h = h.set(A, b)

        del h, a, b

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_gc_2(self):
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(A, h)

        ref = weakref.ref(h)
        # A partially consumed items() iterator also references the HAMT.
        hi = h.items()
        next(hi)

        del h, hi

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_in_1(self):
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertTrue(A in h)
        self.assertFalse(B in h)

        # AA equals A but is a distinct object, so the lookup has to call
        # __eq__ (and __hash__), which the crasher makes fail.
        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                AA in h

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                AA in h

    def test_hamt_getitem_1(self):
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertEqual(h[A], 1)
        self.assertEqual(h[AA], 1)

        with self.assertRaises(KeyError):
            h[B]

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                h[AA]

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                h[AA]
1103
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()