(root)/
Python-3.11.7/
Lib/
test/
test_context.py
       1  import concurrent.futures
       2  import contextvars
       3  import functools
       4  import gc
       5  import random
       6  import time
       7  import unittest
       8  import weakref
       9  from test import support
      10  from test.support import threading_helper
      11  
      12  try:
      13      from _testcapi import hamt
      14  except ImportError:
      15      hamt = None
      16  
      17  
      18  def isolated_context(func):
      19      """Needed to make reftracking test mode work."""
      20      @functools.wraps(func)
      21      def wrapper(*args, **kwargs):
      22          ctx = contextvars.Context()
      23          return ctx.run(func, *args, **kwargs)
      24      return wrapper
      25  
      26  
      27  class ESC[4;38;5;81mContextTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      28      def test_context_var_new_1(self):
      29          with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
      30              contextvars.ContextVar()
      31  
      32          with self.assertRaisesRegex(TypeError, 'must be a str'):
      33              contextvars.ContextVar(1)
      34  
      35          c = contextvars.ContextVar('aaa')
      36          self.assertEqual(c.name, 'aaa')
      37  
      38          with self.assertRaises(AttributeError):
      39              c.name = 'bbb'
      40  
      41          self.assertNotEqual(hash(c), hash('aaa'))
      42  
      43      @isolated_context
      44      def test_context_var_repr_1(self):
      45          c = contextvars.ContextVar('a')
      46          self.assertIn('a', repr(c))
      47  
      48          c = contextvars.ContextVar('a', default=123)
      49          self.assertIn('123', repr(c))
      50  
      51          lst = []
      52          c = contextvars.ContextVar('a', default=lst)
      53          lst.append(c)
      54          self.assertIn('...', repr(c))
      55          self.assertIn('...', repr(lst))
      56  
      57          t = c.set(1)
      58          self.assertIn(repr(c), repr(t))
      59          self.assertNotIn(' used ', repr(t))
      60          c.reset(t)
      61          self.assertIn(' used ', repr(t))
      62  
      63      def test_context_subclassing_1(self):
      64          with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
      65              class ESC[4;38;5;81mMyContextVar(ESC[4;38;5;149mcontextvarsESC[4;38;5;149m.ESC[4;38;5;149mContextVar):
      66                  # Potentially we might want ContextVars to be subclassable.
      67                  pass
      68  
      69          with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
      70              class ESC[4;38;5;81mMyContext(ESC[4;38;5;149mcontextvarsESC[4;38;5;149m.ESC[4;38;5;149mContext):
      71                  pass
      72  
      73          with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
      74              class ESC[4;38;5;81mMyToken(ESC[4;38;5;149mcontextvarsESC[4;38;5;149m.ESC[4;38;5;149mToken):
      75                  pass
      76  
      77      def test_context_new_1(self):
      78          with self.assertRaisesRegex(TypeError, 'any arguments'):
      79              contextvars.Context(1)
      80          with self.assertRaisesRegex(TypeError, 'any arguments'):
      81              contextvars.Context(1, a=1)
      82          with self.assertRaisesRegex(TypeError, 'any arguments'):
      83              contextvars.Context(a=1)
      84          contextvars.Context(**{})
      85  
      86      def test_context_typerrors_1(self):
      87          ctx = contextvars.Context()
      88  
      89          with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
      90              ctx[1]
      91          with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
      92              1 in ctx
      93          with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
      94              ctx.get(1)
      95  
      96      def test_context_get_context_1(self):
      97          ctx = contextvars.copy_context()
      98          self.assertIsInstance(ctx, contextvars.Context)
      99  
     100      def test_context_run_1(self):
     101          ctx = contextvars.Context()
     102  
     103          with self.assertRaisesRegex(TypeError, 'missing 1 required'):
     104              ctx.run()
     105  
     106      def test_context_run_2(self):
     107          ctx = contextvars.Context()
     108  
     109          def func(*args, **kwargs):
     110              kwargs['spam'] = 'foo'
     111              args += ('bar',)
     112              return args, kwargs
     113  
     114          for f in (func, functools.partial(func)):
     115              # partial doesn't support FASTCALL
     116  
     117              self.assertEqual(ctx.run(f), (('bar',), {'spam': 'foo'}))
     118              self.assertEqual(ctx.run(f, 1), ((1, 'bar'), {'spam': 'foo'}))
     119  
     120              self.assertEqual(
     121                  ctx.run(f, a=2),
     122                  (('bar',), {'a': 2, 'spam': 'foo'}))
     123  
     124              self.assertEqual(
     125                  ctx.run(f, 11, a=2),
     126                  ((11, 'bar'), {'a': 2, 'spam': 'foo'}))
     127  
     128              a = {}
     129              self.assertEqual(
     130                  ctx.run(f, 11, **a),
     131                  ((11, 'bar'), {'spam': 'foo'}))
     132              self.assertEqual(a, {})
     133  
     134      def test_context_run_3(self):
     135          ctx = contextvars.Context()
     136  
     137          def func(*args, **kwargs):
     138              1 / 0
     139  
     140          with self.assertRaises(ZeroDivisionError):
     141              ctx.run(func)
     142          with self.assertRaises(ZeroDivisionError):
     143              ctx.run(func, 1, 2)
     144          with self.assertRaises(ZeroDivisionError):
     145              ctx.run(func, 1, 2, a=123)
     146  
     147      @isolated_context
     148      def test_context_run_4(self):
     149          ctx1 = contextvars.Context()
     150          ctx2 = contextvars.Context()
     151          var = contextvars.ContextVar('var')
     152  
     153          def func2():
     154              self.assertIsNone(var.get(None))
     155  
     156          def func1():
     157              self.assertIsNone(var.get(None))
     158              var.set('spam')
     159              ctx2.run(func2)
     160              self.assertEqual(var.get(None), 'spam')
     161  
     162              cur = contextvars.copy_context()
     163              self.assertEqual(len(cur), 1)
     164              self.assertEqual(cur[var], 'spam')
     165              return cur
     166  
     167          returned_ctx = ctx1.run(func1)
     168          self.assertEqual(ctx1, returned_ctx)
     169          self.assertEqual(returned_ctx[var], 'spam')
     170          self.assertIn(var, returned_ctx)
     171  
     172      def test_context_run_5(self):
     173          ctx = contextvars.Context()
     174          var = contextvars.ContextVar('var')
     175  
     176          def func():
     177              self.assertIsNone(var.get(None))
     178              var.set('spam')
     179              1 / 0
     180  
     181          with self.assertRaises(ZeroDivisionError):
     182              ctx.run(func)
     183  
     184          self.assertIsNone(var.get(None))
     185  
     186      def test_context_run_6(self):
     187          ctx = contextvars.Context()
     188          c = contextvars.ContextVar('a', default=0)
     189  
     190          def fun():
     191              self.assertEqual(c.get(), 0)
     192              self.assertIsNone(ctx.get(c))
     193  
     194              c.set(42)
     195              self.assertEqual(c.get(), 42)
     196              self.assertEqual(ctx.get(c), 42)
     197  
     198          ctx.run(fun)
     199  
     200      def test_context_run_7(self):
     201          ctx = contextvars.Context()
     202  
     203          def fun():
     204              with self.assertRaisesRegex(RuntimeError, 'is already entered'):
     205                  ctx.run(fun)
     206  
     207          ctx.run(fun)
     208  
     209      @isolated_context
     210      def test_context_getset_1(self):
     211          c = contextvars.ContextVar('c')
     212          with self.assertRaises(LookupError):
     213              c.get()
     214  
     215          self.assertIsNone(c.get(None))
     216  
     217          t0 = c.set(42)
     218          self.assertEqual(c.get(), 42)
     219          self.assertEqual(c.get(None), 42)
     220          self.assertIs(t0.old_value, t0.MISSING)
     221          self.assertIs(t0.old_value, contextvars.Token.MISSING)
     222          self.assertIs(t0.var, c)
     223  
     224          t = c.set('spam')
     225          self.assertEqual(c.get(), 'spam')
     226          self.assertEqual(c.get(None), 'spam')
     227          self.assertEqual(t.old_value, 42)
     228          c.reset(t)
     229  
     230          self.assertEqual(c.get(), 42)
     231          self.assertEqual(c.get(None), 42)
     232  
     233          c.set('spam2')
     234          with self.assertRaisesRegex(RuntimeError, 'has already been used'):
     235              c.reset(t)
     236          self.assertEqual(c.get(), 'spam2')
     237  
     238          ctx1 = contextvars.copy_context()
     239          self.assertIn(c, ctx1)
     240  
     241          c.reset(t0)
     242          with self.assertRaisesRegex(RuntimeError, 'has already been used'):
     243              c.reset(t0)
     244          self.assertIsNone(c.get(None))
     245  
     246          self.assertIn(c, ctx1)
     247          self.assertEqual(ctx1[c], 'spam2')
     248          self.assertEqual(ctx1.get(c, 'aa'), 'spam2')
     249          self.assertEqual(len(ctx1), 1)
     250          self.assertEqual(list(ctx1.items()), [(c, 'spam2')])
     251          self.assertEqual(list(ctx1.values()), ['spam2'])
     252          self.assertEqual(list(ctx1.keys()), [c])
     253          self.assertEqual(list(ctx1), [c])
     254  
     255          ctx2 = contextvars.copy_context()
     256          self.assertNotIn(c, ctx2)
     257          with self.assertRaises(KeyError):
     258              ctx2[c]
     259          self.assertEqual(ctx2.get(c, 'aa'), 'aa')
     260          self.assertEqual(len(ctx2), 0)
     261          self.assertEqual(list(ctx2), [])
     262  
     263      @isolated_context
     264      def test_context_getset_2(self):
     265          v1 = contextvars.ContextVar('v1')
     266          v2 = contextvars.ContextVar('v2')
     267  
     268          t1 = v1.set(42)
     269          with self.assertRaisesRegex(ValueError, 'by a different'):
     270              v2.reset(t1)
     271  
     272      @isolated_context
     273      def test_context_getset_3(self):
     274          c = contextvars.ContextVar('c', default=42)
     275          ctx = contextvars.Context()
     276  
     277          def fun():
     278              self.assertEqual(c.get(), 42)
     279              with self.assertRaises(KeyError):
     280                  ctx[c]
     281              self.assertIsNone(ctx.get(c))
     282              self.assertEqual(ctx.get(c, 'spam'), 'spam')
     283              self.assertNotIn(c, ctx)
     284              self.assertEqual(list(ctx.keys()), [])
     285  
     286              t = c.set(1)
     287              self.assertEqual(list(ctx.keys()), [c])
     288              self.assertEqual(ctx[c], 1)
     289  
     290              c.reset(t)
     291              self.assertEqual(list(ctx.keys()), [])
     292              with self.assertRaises(KeyError):
     293                  ctx[c]
     294  
     295          ctx.run(fun)
     296  
     297      @isolated_context
     298      def test_context_getset_4(self):
     299          c = contextvars.ContextVar('c', default=42)
     300          ctx = contextvars.Context()
     301  
     302          tok = ctx.run(c.set, 1)
     303  
     304          with self.assertRaisesRegex(ValueError, 'different Context'):
     305              c.reset(tok)
     306  
     307      @isolated_context
     308      def test_context_getset_5(self):
     309          c = contextvars.ContextVar('c', default=42)
     310          c.set([])
     311  
     312          def fun():
     313              c.set([])
     314              c.get().append(42)
     315              self.assertEqual(c.get(), [42])
     316  
     317          contextvars.copy_context().run(fun)
     318          self.assertEqual(c.get(), [])
     319  
     320      def test_context_copy_1(self):
     321          ctx1 = contextvars.Context()
     322          c = contextvars.ContextVar('c', default=42)
     323  
     324          def ctx1_fun():
     325              c.set(10)
     326  
     327              ctx2 = ctx1.copy()
     328              self.assertEqual(ctx2[c], 10)
     329  
     330              c.set(20)
     331              self.assertEqual(ctx1[c], 20)
     332              self.assertEqual(ctx2[c], 10)
     333  
     334              ctx2.run(ctx2_fun)
     335              self.assertEqual(ctx1[c], 20)
     336              self.assertEqual(ctx2[c], 30)
     337  
     338          def ctx2_fun():
     339              self.assertEqual(c.get(), 10)
     340              c.set(30)
     341              self.assertEqual(c.get(), 30)
     342  
     343          ctx1.run(ctx1_fun)
     344  
     345      @isolated_context
     346      @threading_helper.requires_working_threading()
     347      def test_context_threads_1(self):
     348          cvar = contextvars.ContextVar('cvar')
     349  
     350          def sub(num):
     351              for i in range(10):
     352                  cvar.set(num + i)
     353                  time.sleep(random.uniform(0.001, 0.05))
     354                  self.assertEqual(cvar.get(), num + i)
     355              return num
     356  
     357          tp = concurrent.futures.ThreadPoolExecutor(max_workers=10)
     358          try:
     359              results = list(tp.map(sub, range(10)))
     360          finally:
     361              tp.shutdown()
     362          self.assertEqual(results, list(range(10)))
     363  
     364  
     365  # HAMT Tests
     366  
     367  
     368  class ESC[4;38;5;81mHashKey:
     369      _crasher = None
     370  
     371      def __init__(self, hash, name, *, error_on_eq_to=None):
     372          assert hash != -1
     373          self.name = name
     374          self.hash = hash
     375          self.error_on_eq_to = error_on_eq_to
     376  
     377      def __repr__(self):
     378          return f'<Key name:{self.name} hash:{self.hash}>'
     379  
     380      def __hash__(self):
     381          if self._crasher is not None and self._crasher.error_on_hash:
     382              raise HashingError
     383  
     384          return self.hash
     385  
     386      def __eq__(self, other):
     387          if not isinstance(other, HashKey):
     388              return NotImplemented
     389  
     390          if self._crasher is not None and self._crasher.error_on_eq:
     391              raise EqError
     392  
     393          if self.error_on_eq_to is not None and self.error_on_eq_to is other:
     394              raise ValueError(f'cannot compare {self!r} to {other!r}')
     395          if other.error_on_eq_to is not None and other.error_on_eq_to is self:
     396              raise ValueError(f'cannot compare {other!r} to {self!r}')
     397  
     398          return (self.name, self.hash) == (other.name, other.hash)
     399  
     400  
     401  class ESC[4;38;5;81mKeyStr(ESC[4;38;5;149mstr):
     402      def __hash__(self):
     403          if HashKey._crasher is not None and HashKey._crasher.error_on_hash:
     404              raise HashingError
     405          return super().__hash__()
     406  
     407      def __eq__(self, other):
     408          if HashKey._crasher is not None and HashKey._crasher.error_on_eq:
     409              raise EqError
     410          return super().__eq__(other)
     411  
     412  
     413  class ESC[4;38;5;81mHaskKeyCrasher:
     414      def __init__(self, *, error_on_hash=False, error_on_eq=False):
     415          self.error_on_hash = error_on_hash
     416          self.error_on_eq = error_on_eq
     417  
     418      def __enter__(self):
     419          if HashKey._crasher is not None:
     420              raise RuntimeError('cannot nest crashers')
     421          HashKey._crasher = self
     422  
     423      def __exit__(self, *exc):
     424          HashKey._crasher = None
     425  
     426  
     427  class ESC[4;38;5;81mHashingError(ESC[4;38;5;149mException):
     428      pass
     429  
     430  
     431  class ESC[4;38;5;81mEqError(ESC[4;38;5;149mException):
     432      pass
     433  
     434  
     435  @unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function')
     436  class ESC[4;38;5;81mHamtTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     437  
     438      def test_hashkey_helper_1(self):
     439          k1 = HashKey(10, 'aaa')
     440          k2 = HashKey(10, 'bbb')
     441  
     442          self.assertNotEqual(k1, k2)
     443          self.assertEqual(hash(k1), hash(k2))
     444  
     445          d = dict()
     446          d[k1] = 'a'
     447          d[k2] = 'b'
     448  
     449          self.assertEqual(d[k1], 'a')
     450          self.assertEqual(d[k2], 'b')
     451  
     452      def test_hamt_basics_1(self):
     453          h = hamt()
     454          h = None  # NoQA
     455  
     456      def test_hamt_basics_2(self):
     457          h = hamt()
     458          self.assertEqual(len(h), 0)
     459  
     460          h2 = h.set('a', 'b')
     461          self.assertIsNot(h, h2)
     462          self.assertEqual(len(h), 0)
     463          self.assertEqual(len(h2), 1)
     464  
     465          self.assertIsNone(h.get('a'))
     466          self.assertEqual(h.get('a', 42), 42)
     467  
     468          self.assertEqual(h2.get('a'), 'b')
     469  
     470          h3 = h2.set('b', 10)
     471          self.assertIsNot(h2, h3)
     472          self.assertEqual(len(h), 0)
     473          self.assertEqual(len(h2), 1)
     474          self.assertEqual(len(h3), 2)
     475          self.assertEqual(h3.get('a'), 'b')
     476          self.assertEqual(h3.get('b'), 10)
     477  
     478          self.assertIsNone(h.get('b'))
     479          self.assertIsNone(h2.get('b'))
     480  
     481          self.assertIsNone(h.get('a'))
     482          self.assertEqual(h2.get('a'), 'b')
     483  
     484          h = h2 = h3 = None
     485  
     486      def test_hamt_basics_3(self):
     487          h = hamt()
     488          o = object()
     489          h1 = h.set('1', o)
     490          h2 = h1.set('1', o)
     491          self.assertIs(h1, h2)
     492  
     493      def test_hamt_basics_4(self):
     494          h = hamt()
     495          h1 = h.set('key', [])
     496          h2 = h1.set('key', [])
     497          self.assertIsNot(h1, h2)
     498          self.assertEqual(len(h1), 1)
     499          self.assertEqual(len(h2), 1)
     500          self.assertIsNot(h1.get('key'), h2.get('key'))
     501  
     502      def test_hamt_collision_1(self):
     503          k1 = HashKey(10, 'aaa')
     504          k2 = HashKey(10, 'bbb')
     505          k3 = HashKey(10, 'ccc')
     506  
     507          h = hamt()
     508          h2 = h.set(k1, 'a')
     509          h3 = h2.set(k2, 'b')
     510  
     511          self.assertEqual(h.get(k1), None)
     512          self.assertEqual(h.get(k2), None)
     513  
     514          self.assertEqual(h2.get(k1), 'a')
     515          self.assertEqual(h2.get(k2), None)
     516  
     517          self.assertEqual(h3.get(k1), 'a')
     518          self.assertEqual(h3.get(k2), 'b')
     519  
     520          h4 = h3.set(k2, 'cc')
     521          h5 = h4.set(k3, 'aa')
     522  
     523          self.assertEqual(h3.get(k1), 'a')
     524          self.assertEqual(h3.get(k2), 'b')
     525          self.assertEqual(h4.get(k1), 'a')
     526          self.assertEqual(h4.get(k2), 'cc')
     527          self.assertEqual(h4.get(k3), None)
     528          self.assertEqual(h5.get(k1), 'a')
     529          self.assertEqual(h5.get(k2), 'cc')
     530          self.assertEqual(h5.get(k2), 'cc')
     531          self.assertEqual(h5.get(k3), 'aa')
     532  
     533          self.assertEqual(len(h), 0)
     534          self.assertEqual(len(h2), 1)
     535          self.assertEqual(len(h3), 2)
     536          self.assertEqual(len(h4), 2)
     537          self.assertEqual(len(h5), 3)
     538  
     539      def test_hamt_collision_3(self):
     540          # Test that iteration works with the deepest tree possible.
     541          # https://github.com/python/cpython/issues/93065
     542  
     543          C = HashKey(0b10000000_00000000_00000000_00000000, 'C')
     544          D = HashKey(0b10000000_00000000_00000000_00000000, 'D')
     545  
     546          E = HashKey(0b00000000_00000000_00000000_00000000, 'E')
     547  
     548          h = hamt()
     549          h = h.set(C, 'C')
     550          h = h.set(D, 'D')
     551          h = h.set(E, 'E')
     552  
     553          # BitmapNode(size=2 count=1 bitmap=0b1):
     554          #   NULL:
     555          #     BitmapNode(size=2 count=1 bitmap=0b1):
     556          #       NULL:
     557          #         BitmapNode(size=2 count=1 bitmap=0b1):
     558          #           NULL:
     559          #             BitmapNode(size=2 count=1 bitmap=0b1):
     560          #               NULL:
     561          #                 BitmapNode(size=2 count=1 bitmap=0b1):
     562          #                   NULL:
     563          #                     BitmapNode(size=2 count=1 bitmap=0b1):
     564          #                       NULL:
     565          #                         BitmapNode(size=4 count=2 bitmap=0b101):
     566          #                           <Key name:E hash:0>: 'E'
     567          #                           NULL:
     568          #                             CollisionNode(size=4 id=0x107a24520):
     569          #                               <Key name:C hash:2147483648>: 'C'
     570          #                               <Key name:D hash:2147483648>: 'D'
     571  
     572          self.assertEqual({k.name for k in h.keys()}, {'C', 'D', 'E'})
     573  
     574      @support.requires_resource('cpu')
     575      def test_hamt_stress(self):
     576          COLLECTION_SIZE = 7000
     577          TEST_ITERS_EVERY = 647
     578          CRASH_HASH_EVERY = 97
     579          CRASH_EQ_EVERY = 11
     580          RUN_XTIMES = 3
     581  
     582          for _ in range(RUN_XTIMES):
     583              h = hamt()
     584              d = dict()
     585  
     586              for i in range(COLLECTION_SIZE):
     587                  key = KeyStr(i)
     588  
     589                  if not (i % CRASH_HASH_EVERY):
     590                      with HaskKeyCrasher(error_on_hash=True):
     591                          with self.assertRaises(HashingError):
     592                              h.set(key, i)
     593  
     594                  h = h.set(key, i)
     595  
     596                  if not (i % CRASH_EQ_EVERY):
     597                      with HaskKeyCrasher(error_on_eq=True):
     598                          with self.assertRaises(EqError):
     599                              h.get(KeyStr(i))  # really trigger __eq__
     600  
     601                  d[key] = i
     602                  self.assertEqual(len(d), len(h))
     603  
     604                  if not (i % TEST_ITERS_EVERY):
     605                      self.assertEqual(set(h.items()), set(d.items()))
     606                      self.assertEqual(len(h.items()), len(d.items()))
     607  
     608              self.assertEqual(len(h), COLLECTION_SIZE)
     609  
     610              for key in range(COLLECTION_SIZE):
     611                  self.assertEqual(h.get(KeyStr(key), 'not found'), key)
     612  
     613              keys_to_delete = list(range(COLLECTION_SIZE))
     614              random.shuffle(keys_to_delete)
     615              for iter_i, i in enumerate(keys_to_delete):
     616                  key = KeyStr(i)
     617  
     618                  if not (iter_i % CRASH_HASH_EVERY):
     619                      with HaskKeyCrasher(error_on_hash=True):
     620                          with self.assertRaises(HashingError):
     621                              h.delete(key)
     622  
     623                  if not (iter_i % CRASH_EQ_EVERY):
     624                      with HaskKeyCrasher(error_on_eq=True):
     625                          with self.assertRaises(EqError):
     626                              h.delete(KeyStr(i))
     627  
     628                  h = h.delete(key)
     629                  self.assertEqual(h.get(key, 'not found'), 'not found')
     630                  del d[key]
     631                  self.assertEqual(len(d), len(h))
     632  
     633                  if iter_i == COLLECTION_SIZE // 2:
     634                      hm = h
     635                      dm = d.copy()
     636  
     637                  if not (iter_i % TEST_ITERS_EVERY):
     638                      self.assertEqual(set(h.keys()), set(d.keys()))
     639                      self.assertEqual(len(h.keys()), len(d.keys()))
     640  
     641              self.assertEqual(len(d), 0)
     642              self.assertEqual(len(h), 0)
     643  
     644              # ============
     645  
     646              for key in dm:
     647                  self.assertEqual(hm.get(str(key)), dm[key])
     648              self.assertEqual(len(dm), len(hm))
     649  
     650              for i, key in enumerate(keys_to_delete):
     651                  hm = hm.delete(str(key))
     652                  self.assertEqual(hm.get(str(key), 'not found'), 'not found')
     653                  dm.pop(str(key), None)
     654                  self.assertEqual(len(d), len(h))
     655  
     656                  if not (i % TEST_ITERS_EVERY):
     657                      self.assertEqual(set(h.values()), set(d.values()))
     658                      self.assertEqual(len(h.values()), len(d.values()))
     659  
     660              self.assertEqual(len(d), 0)
     661              self.assertEqual(len(h), 0)
     662              self.assertEqual(list(h.items()), [])
     663  
     664      def test_hamt_delete_1(self):
     665          A = HashKey(100, 'A')
     666          B = HashKey(101, 'B')
     667          C = HashKey(102, 'C')
     668          D = HashKey(103, 'D')
     669          E = HashKey(104, 'E')
     670          Z = HashKey(-100, 'Z')
     671  
     672          Er = HashKey(103, 'Er', error_on_eq_to=D)
     673  
     674          h = hamt()
     675          h = h.set(A, 'a')
     676          h = h.set(B, 'b')
     677          h = h.set(C, 'c')
     678          h = h.set(D, 'd')
     679          h = h.set(E, 'e')
     680  
     681          orig_len = len(h)
     682  
     683          # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618):
     684          #     <Key name:A hash:100>: 'a'
     685          #     <Key name:B hash:101>: 'b'
     686          #     <Key name:C hash:102>: 'c'
     687          #     <Key name:D hash:103>: 'd'
     688          #     <Key name:E hash:104>: 'e'
     689  
     690          h = h.delete(C)
     691          self.assertEqual(len(h), orig_len - 1)
     692  
     693          with self.assertRaisesRegex(ValueError, 'cannot compare'):
     694              h.delete(Er)
     695  
     696          h = h.delete(D)
     697          self.assertEqual(len(h), orig_len - 2)
     698  
     699          h2 = h.delete(Z)
     700          self.assertIs(h2, h)
     701  
     702          h = h.delete(A)
     703          self.assertEqual(len(h), orig_len - 3)
     704  
     705          self.assertEqual(h.get(A, 42), 42)
     706          self.assertEqual(h.get(B), 'b')
     707          self.assertEqual(h.get(E), 'e')
     708  
     709      def test_hamt_delete_2(self):
     710          A = HashKey(100, 'A')
     711          B = HashKey(201001, 'B')
     712          C = HashKey(101001, 'C')
     713          D = HashKey(103, 'D')
     714          E = HashKey(104, 'E')
     715          Z = HashKey(-100, 'Z')
     716  
     717          Er = HashKey(201001, 'Er', error_on_eq_to=B)
     718  
     719          h = hamt()
     720          h = h.set(A, 'a')
     721          h = h.set(B, 'b')
     722          h = h.set(C, 'c')
     723          h = h.set(D, 'd')
     724          h = h.set(E, 'e')
     725  
     726          orig_len = len(h)
     727  
     728          # BitmapNode(size=8 bitmap=0b1110010000):
     729          #     <Key name:A hash:100>: 'a'
     730          #     <Key name:D hash:103>: 'd'
     731          #     <Key name:E hash:104>: 'e'
     732          #     NULL:
     733          #         BitmapNode(size=4 bitmap=0b100000000001000000000):
     734          #             <Key name:B hash:201001>: 'b'
     735          #             <Key name:C hash:101001>: 'c'
     736  
     737          with self.assertRaisesRegex(ValueError, 'cannot compare'):
     738              h.delete(Er)
     739  
     740          h = h.delete(Z)
     741          self.assertEqual(len(h), orig_len)
     742  
     743          h = h.delete(C)
     744          self.assertEqual(len(h), orig_len - 1)
     745  
     746          h = h.delete(B)
     747          self.assertEqual(len(h), orig_len - 2)
     748  
     749          h = h.delete(A)
     750          self.assertEqual(len(h), orig_len - 3)
     751  
     752          self.assertEqual(h.get(D), 'd')
     753          self.assertEqual(h.get(E), 'e')
     754  
     755          h = h.delete(A)
     756          h = h.delete(B)
     757          h = h.delete(D)
     758          h = h.delete(E)
     759          self.assertEqual(len(h), 0)
     760  
     761      def test_hamt_delete_3(self):
     762          A = HashKey(100, 'A')
     763          B = HashKey(101, 'B')
     764          C = HashKey(100100, 'C')
     765          D = HashKey(100100, 'D')
     766          E = HashKey(104, 'E')
     767  
     768          h = hamt()
     769          h = h.set(A, 'a')
     770          h = h.set(B, 'b')
     771          h = h.set(C, 'c')
     772          h = h.set(D, 'd')
     773          h = h.set(E, 'e')
     774  
     775          orig_len = len(h)
     776  
     777          # BitmapNode(size=6 bitmap=0b100110000):
     778          #     NULL:
     779          #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
     780          #             <Key name:A hash:100>: 'a'
     781          #             NULL:
     782          #                 CollisionNode(size=4 id=0x108572410):
     783          #                     <Key name:C hash:100100>: 'c'
     784          #                     <Key name:D hash:100100>: 'd'
     785          #     <Key name:B hash:101>: 'b'
     786          #     <Key name:E hash:104>: 'e'
     787  
     788          h = h.delete(A)
     789          self.assertEqual(len(h), orig_len - 1)
     790  
     791          h = h.delete(E)
     792          self.assertEqual(len(h), orig_len - 2)
     793  
     794          self.assertEqual(h.get(C), 'c')
     795          self.assertEqual(h.get(B), 'b')
     796  
     797      def test_hamt_delete_4(self):
     798          A = HashKey(100, 'A')
     799          B = HashKey(101, 'B')
     800          C = HashKey(100100, 'C')
     801          D = HashKey(100100, 'D')
     802          E = HashKey(100100, 'E')
     803  
     804          h = hamt()
     805          h = h.set(A, 'a')
     806          h = h.set(B, 'b')
     807          h = h.set(C, 'c')
     808          h = h.set(D, 'd')
     809          h = h.set(E, 'e')
     810  
     811          orig_len = len(h)
     812  
     813          # BitmapNode(size=4 bitmap=0b110000):
     814          #     NULL:
     815          #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
     816          #             <Key name:A hash:100>: 'a'
     817          #             NULL:
     818          #                 CollisionNode(size=6 id=0x10515ef30):
     819          #                     <Key name:C hash:100100>: 'c'
     820          #                     <Key name:D hash:100100>: 'd'
     821          #                     <Key name:E hash:100100>: 'e'
     822          #     <Key name:B hash:101>: 'b'
     823  
     824          h = h.delete(D)
     825          self.assertEqual(len(h), orig_len - 1)
     826  
     827          h = h.delete(E)
     828          self.assertEqual(len(h), orig_len - 2)
     829  
     830          h = h.delete(C)
     831          self.assertEqual(len(h), orig_len - 3)
     832  
     833          h = h.delete(A)
     834          self.assertEqual(len(h), orig_len - 4)
     835  
     836          h = h.delete(B)
     837          self.assertEqual(len(h), 0)
     838  
     839      def test_hamt_delete_5(self):
     840          h = hamt()
     841  
     842          keys = []
     843          for i in range(17):
     844              key = HashKey(i, str(i))
     845              keys.append(key)
     846              h = h.set(key, f'val-{i}')
     847  
     848          collision_key16 = HashKey(16, '18')
     849          h = h.set(collision_key16, 'collision')
     850  
     851          # ArrayNode(id=0x10f8b9318):
     852          #     0::
     853          #     BitmapNode(size=2 count=1 bitmap=0b1):
     854          #         <Key name:0 hash:0>: 'val-0'
     855          #
     856          # ... 14 more BitmapNodes ...
     857          #
     858          #     15::
     859          #     BitmapNode(size=2 count=1 bitmap=0b1):
     860          #         <Key name:15 hash:15>: 'val-15'
     861          #
     862          #     16::
     863          #     BitmapNode(size=2 count=1 bitmap=0b1):
     864          #         NULL:
     865          #             CollisionNode(size=4 id=0x10f2f5af8):
     866          #                 <Key name:16 hash:16>: 'val-16'
     867          #                 <Key name:18 hash:16>: 'collision'
     868  
     869          self.assertEqual(len(h), 18)
     870  
     871          h = h.delete(keys[2])
     872          self.assertEqual(len(h), 17)
     873  
     874          h = h.delete(collision_key16)
     875          self.assertEqual(len(h), 16)
     876          h = h.delete(keys[16])
     877          self.assertEqual(len(h), 15)
     878  
     879          h = h.delete(keys[1])
     880          self.assertEqual(len(h), 14)
     881          h = h.delete(keys[1])
     882          self.assertEqual(len(h), 14)
     883  
     884          for key in keys:
     885              h = h.delete(key)
     886          self.assertEqual(len(h), 0)
     887  
     888      def test_hamt_items_1(self):
     889          A = HashKey(100, 'A')
     890          B = HashKey(201001, 'B')
     891          C = HashKey(101001, 'C')
     892          D = HashKey(103, 'D')
     893          E = HashKey(104, 'E')
     894          F = HashKey(110, 'F')
     895  
     896          h = hamt()
     897          h = h.set(A, 'a')
     898          h = h.set(B, 'b')
     899          h = h.set(C, 'c')
     900          h = h.set(D, 'd')
     901          h = h.set(E, 'e')
     902          h = h.set(F, 'f')
     903  
     904          it = h.items()
     905          self.assertEqual(
     906              set(list(it)),
     907              {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})
     908  
     909      def test_hamt_items_2(self):
     910          A = HashKey(100, 'A')
     911          B = HashKey(101, 'B')
     912          C = HashKey(100100, 'C')
     913          D = HashKey(100100, 'D')
     914          E = HashKey(100100, 'E')
     915          F = HashKey(110, 'F')
     916  
     917          h = hamt()
     918          h = h.set(A, 'a')
     919          h = h.set(B, 'b')
     920          h = h.set(C, 'c')
     921          h = h.set(D, 'd')
     922          h = h.set(E, 'e')
     923          h = h.set(F, 'f')
     924  
     925          it = h.items()
     926          self.assertEqual(
     927              set(list(it)),
     928              {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})
     929  
     930      def test_hamt_keys_1(self):
     931          A = HashKey(100, 'A')
     932          B = HashKey(101, 'B')
     933          C = HashKey(100100, 'C')
     934          D = HashKey(100100, 'D')
     935          E = HashKey(100100, 'E')
     936          F = HashKey(110, 'F')
     937  
     938          h = hamt()
     939          h = h.set(A, 'a')
     940          h = h.set(B, 'b')
     941          h = h.set(C, 'c')
     942          h = h.set(D, 'd')
     943          h = h.set(E, 'e')
     944          h = h.set(F, 'f')
     945  
     946          self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
     947          self.assertEqual(set(list(h)), {A, B, C, D, E, F})
     948  
     949      def test_hamt_items_3(self):
     950          h = hamt()
     951          self.assertEqual(len(h.items()), 0)
     952          self.assertEqual(list(h.items()), [])
     953  
     954      def test_hamt_eq_1(self):
     955          A = HashKey(100, 'A')
     956          B = HashKey(101, 'B')
     957          C = HashKey(100100, 'C')
     958          D = HashKey(100100, 'D')
     959          E = HashKey(120, 'E')
     960  
     961          h1 = hamt()
     962          h1 = h1.set(A, 'a')
     963          h1 = h1.set(B, 'b')
     964          h1 = h1.set(C, 'c')
     965          h1 = h1.set(D, 'd')
     966  
     967          h2 = hamt()
     968          h2 = h2.set(A, 'a')
     969  
     970          self.assertFalse(h1 == h2)
     971          self.assertTrue(h1 != h2)
     972  
     973          h2 = h2.set(B, 'b')
     974          self.assertFalse(h1 == h2)
     975          self.assertTrue(h1 != h2)
     976  
     977          h2 = h2.set(C, 'c')
     978          self.assertFalse(h1 == h2)
     979          self.assertTrue(h1 != h2)
     980  
     981          h2 = h2.set(D, 'd2')
     982          self.assertFalse(h1 == h2)
     983          self.assertTrue(h1 != h2)
     984  
     985          h2 = h2.set(D, 'd')
     986          self.assertTrue(h1 == h2)
     987          self.assertFalse(h1 != h2)
     988  
     989          h2 = h2.set(E, 'e')
     990          self.assertFalse(h1 == h2)
     991          self.assertTrue(h1 != h2)
     992  
     993          h2 = h2.delete(D)
     994          self.assertFalse(h1 == h2)
     995          self.assertTrue(h1 != h2)
     996  
     997          h2 = h2.set(E, 'd')
     998          self.assertFalse(h1 == h2)
     999          self.assertTrue(h1 != h2)
    1000  
    def test_hamt_eq_2(self):
        # Er has A's hash and is built with error_on_eq_to=A, so comparing
        # the two single-entry maps with == or != must raise ValueError.
        A = HashKey(100, 'A')
        Er = HashKey(100, 'Er', error_on_eq_to=A)

        h1 = hamt().set(A, 'a')
        h2 = hamt().set(Er, 'a')

        for compare in (lambda x, y: x == y, lambda x, y: x != y):
            with self.assertRaisesRegex(ValueError, 'cannot compare'):
                compare(h1, h2)
    1016  
    def test_hamt_gc_1(self):
        # A reference cycle that passes through a HAMT (as an element of a
        # list and via a stored value) must be reclaimable by the cyclic GC.
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(0, 0)  # empty HAMT node is memoized in hamt.c
        ref = weakref.ref(h)

        cycle_a = []
        cycle_b = []
        cycle_a.extend((cycle_a, h, cycle_b))
        cycle_b.append(cycle_a)
        h = h.set(A, cycle_b)

        del h, cycle_a, cycle_b

        for _ in range(3):
            gc.collect()

        self.assertIsNone(ref())
    1039  
    def test_hamt_gc_2(self):
        """A HAMT that stores itself as a value, kept alive together with a
        partially consumed items() iterator, must be reclaimed by the GC."""
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(A, 'a')
        # Store the map inside itself, creating a self-referencing cycle.
        h = h.set(A, h)

        ref = weakref.ref(h)
        hi = h.items()
        # Advance the iterator so it holds references into the map.
        next(hi)

        del h, hi

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())
    1059  
    def test_hamt_in_1(self):
        """Check ``key in h`` for present, equal-but-distinct and absent
        keys, and that errors triggered via HaskKeyCrasher (error_on_eq /
        error_on_hash) propagate out of the membership test."""
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertTrue(A in h)
        # AA is a different object built like A; membership must be
        # decided by equality, not identity (mirrors the h[AA] lookup).
        self.assertTrue(AA in h)
        self.assertFalse(B in h)

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                AA in h

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                AA in h
    1079  
    def test_hamt_getitem_1(self):
        # Subscript lookup: A and an equal-but-distinct AA both find the
        # value, a missing key raises KeyError, and errors triggered via
        # HaskKeyCrasher (error_on_eq / error_on_hash) propagate out of h[].
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')
        B = HashKey(101, 'B')

        h = hamt().set(A, 1)

        for probe in (A, AA):
            self.assertEqual(h[probe], 1)

        with self.assertRaises(KeyError):
            h[B]

        crash_cases = (
            ({'error_on_eq': True}, EqError),
            ({'error_on_hash': True}, HashingError),
        )
        for crasher_kwargs, expected_exc in crash_cases:
            with self.assertRaises(expected_exc):
                with HaskKeyCrasher(**crasher_kwargs):
                    h[AA]
    1102  
    1103  
# Run this module's test cases via unittest when executed as a script.
if __name__ == "__main__":
    unittest.main()