(root)/
Python-3.12.0/
Lib/
test/
test_heapq.py
       1  """Unittests for heapq."""
       2  
       3  import random
       4  import unittest
       5  import doctest
       6  
       7  from test.support import import_helper
       8  from unittest import TestCase, skipUnless
       9  from operator import itemgetter
      10  
      11  py_heapq = import_helper.import_fresh_module('heapq', blocked=['_heapq'])
      12  c_heapq = import_helper.import_fresh_module('heapq', fresh=['_heapq'])
      13  
      14  # _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
      15  # _heapq is imported, so check them there
      16  func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace',
      17                '_heappop_max', '_heapreplace_max', '_heapify_max']
      18  
      19  class ESC[4;38;5;81mTestModules(ESC[4;38;5;149mTestCase):
      20      def test_py_functions(self):
      21          for fname in func_names:
      22              self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
      23  
      24      @skipUnless(c_heapq, 'requires _heapq')
      25      def test_c_functions(self):
      26          for fname in func_names:
      27              self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
      28  
      29  
      30  def load_tests(loader, tests, ignore):
      31      # The 'merge' function has examples in its docstring which we should test
      32      # with 'doctest'.
      33      #
      34      # However, doctest can't easily find all docstrings in the module (loading
      35      # it through import_fresh_module seems to confuse it), so we specifically
      36      # create a finder which returns the doctests from the merge method.
      37  
      38      class ESC[4;38;5;81mHeapqMergeDocTestFinder:
      39          def find(self, *args, **kwargs):
      40              dtf = doctest.DocTestFinder()
      41              return dtf.find(py_heapq.merge)
      42  
      43      tests.addTests(doctest.DocTestSuite(py_heapq,
      44                                          test_finder=HeapqMergeDocTestFinder()))
      45      return tests
      46  
      47  class ESC[4;38;5;81mTestHeap:
      48  
      49      def test_push_pop(self):
      50          # 1) Push 256 random numbers and pop them off, verifying all's OK.
      51          heap = []
      52          data = []
      53          self.check_invariant(heap)
      54          for i in range(256):
      55              item = random.random()
      56              data.append(item)
      57              self.module.heappush(heap, item)
      58              self.check_invariant(heap)
      59          results = []
      60          while heap:
      61              item = self.module.heappop(heap)
      62              self.check_invariant(heap)
      63              results.append(item)
      64          data_sorted = data[:]
      65          data_sorted.sort()
      66          self.assertEqual(data_sorted, results)
      67          # 2) Check that the invariant holds for a sorted array
      68          self.check_invariant(results)
      69  
      70          self.assertRaises(TypeError, self.module.heappush, [])
      71          try:
      72              self.assertRaises(TypeError, self.module.heappush, None, None)
      73              self.assertRaises(TypeError, self.module.heappop, None)
      74          except AttributeError:
      75              pass
      76  
      77      def check_invariant(self, heap):
      78          # Check the heap invariant.
      79          for pos, item in enumerate(heap):
      80              if pos: # pos 0 has no parent
      81                  parentpos = (pos-1) >> 1
      82                  self.assertTrue(heap[parentpos] <= item)
      83  
      84      def test_heapify(self):
      85          for size in list(range(30)) + [20000]:
      86              heap = [random.random() for dummy in range(size)]
      87              self.module.heapify(heap)
      88              self.check_invariant(heap)
      89  
      90          self.assertRaises(TypeError, self.module.heapify, None)
      91  
      92      def test_naive_nbest(self):
      93          data = [random.randrange(2000) for i in range(1000)]
      94          heap = []
      95          for item in data:
      96              self.module.heappush(heap, item)
      97              if len(heap) > 10:
      98                  self.module.heappop(heap)
      99          heap.sort()
     100          self.assertEqual(heap, sorted(data)[-10:])
     101  
     102      def heapiter(self, heap):
     103          # An iterator returning a heap's elements, smallest-first.
     104          try:
     105              while 1:
     106                  yield self.module.heappop(heap)
     107          except IndexError:
     108              pass
     109  
     110      def test_nbest(self):
     111          # Less-naive "N-best" algorithm, much faster (if len(data) is big
     112          # enough <wink>) than sorting all of data.  However, if we had a max
     113          # heap instead of a min heap, it could go faster still via
     114          # heapify'ing all of data (linear time), then doing 10 heappops
     115          # (10 log-time steps).
     116          data = [random.randrange(2000) for i in range(1000)]
     117          heap = data[:10]
     118          self.module.heapify(heap)
     119          for item in data[10:]:
     120              if item > heap[0]:  # this gets rarer the longer we run
     121                  self.module.heapreplace(heap, item)
     122          self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
     123  
     124          self.assertRaises(TypeError, self.module.heapreplace, None)
     125          self.assertRaises(TypeError, self.module.heapreplace, None, None)
     126          self.assertRaises(IndexError, self.module.heapreplace, [], None)
     127  
     128      def test_nbest_with_pushpop(self):
     129          data = [random.randrange(2000) for i in range(1000)]
     130          heap = data[:10]
     131          self.module.heapify(heap)
     132          for item in data[10:]:
     133              self.module.heappushpop(heap, item)
     134          self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
     135          self.assertEqual(self.module.heappushpop([], 'x'), 'x')
     136  
     137      def test_heappushpop(self):
     138          h = []
     139          x = self.module.heappushpop(h, 10)
     140          self.assertEqual((h, x), ([], 10))
     141  
     142          h = [10]
     143          x = self.module.heappushpop(h, 10.0)
     144          self.assertEqual((h, x), ([10], 10.0))
     145          self.assertEqual(type(h[0]), int)
     146          self.assertEqual(type(x), float)
     147  
     148          h = [10]
     149          x = self.module.heappushpop(h, 9)
     150          self.assertEqual((h, x), ([10], 9))
     151  
     152          h = [10]
     153          x = self.module.heappushpop(h, 11)
     154          self.assertEqual((h, x), ([11], 10))
     155  
     156      def test_heappop_max(self):
     157          # _heapop_max has an optimization for one-item lists which isn't
     158          # covered in other tests, so test that case explicitly here
     159          h = [3, 2]
     160          self.assertEqual(self.module._heappop_max(h), 3)
     161          self.assertEqual(self.module._heappop_max(h), 2)
     162  
     163      def test_heapsort(self):
     164          # Exercise everything with repeated heapsort checks
     165          for trial in range(100):
     166              size = random.randrange(50)
     167              data = [random.randrange(25) for i in range(size)]
     168              if trial & 1:     # Half of the time, use heapify
     169                  heap = data[:]
     170                  self.module.heapify(heap)
     171              else:             # The rest of the time, use heappush
     172                  heap = []
     173                  for item in data:
     174                      self.module.heappush(heap, item)
     175              heap_sorted = [self.module.heappop(heap) for i in range(size)]
     176              self.assertEqual(heap_sorted, sorted(data))
     177  
     178      def test_merge(self):
     179          inputs = []
     180          for i in range(random.randrange(25)):
     181              row = []
     182              for j in range(random.randrange(100)):
     183                  tup = random.choice('ABC'), random.randrange(-500, 500)
     184                  row.append(tup)
     185              inputs.append(row)
     186  
     187          for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
     188              for reverse in [False, True]:
     189                  seqs = []
     190                  for seq in inputs:
     191                      seqs.append(sorted(seq, key=key, reverse=reverse))
     192                  self.assertEqual(sorted(chain(*inputs), key=key, reverse=reverse),
     193                                   list(self.module.merge(*seqs, key=key, reverse=reverse)))
     194                  self.assertEqual(list(self.module.merge()), [])
     195  
     196      def test_empty_merges(self):
     197          # Merging two empty lists (with or without a key) should produce
     198          # another empty list.
     199          self.assertEqual(list(self.module.merge([], [])), [])
     200          self.assertEqual(list(self.module.merge([], [], key=lambda: 6)), [])
     201  
     202      def test_merge_does_not_suppress_index_error(self):
     203          # Issue 19018: Heapq.merge suppresses IndexError from user generator
     204          def iterable():
     205              s = list(range(10))
     206              for i in range(20):
     207                  yield s[i]       # IndexError when i > 10
     208          with self.assertRaises(IndexError):
     209              list(self.module.merge(iterable(), iterable()))
     210  
     211      def test_merge_stability(self):
     212          class ESC[4;38;5;81mInt(ESC[4;38;5;149mint):
     213              pass
     214          inputs = [[], [], [], []]
     215          for i in range(20000):
     216              stream = random.randrange(4)
     217              x = random.randrange(500)
     218              obj = Int(x)
     219              obj.pair = (x, stream)
     220              inputs[stream].append(obj)
     221          for stream in inputs:
     222              stream.sort()
     223          result = [i.pair for i in self.module.merge(*inputs)]
     224          self.assertEqual(result, sorted(result))
     225  
     226      def test_nsmallest(self):
     227          data = [(random.randrange(2000), i) for i in range(1000)]
     228          for f in (None, lambda x:  x[0] * 547 % 2000):
     229              for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
     230                  self.assertEqual(list(self.module.nsmallest(n, data)),
     231                                   sorted(data)[:n])
     232                  self.assertEqual(list(self.module.nsmallest(n, data, key=f)),
     233                                   sorted(data, key=f)[:n])
     234  
     235      def test_nlargest(self):
     236          data = [(random.randrange(2000), i) for i in range(1000)]
     237          for f in (None, lambda x:  x[0] * 547 % 2000):
     238              for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
     239                  self.assertEqual(list(self.module.nlargest(n, data)),
     240                                   sorted(data, reverse=True)[:n])
     241                  self.assertEqual(list(self.module.nlargest(n, data, key=f)),
     242                                   sorted(data, key=f, reverse=True)[:n])
     243  
     244      def test_comparison_operator(self):
     245          # Issue 3051: Make sure heapq works with both __lt__
     246          # For python 3.0, __le__ alone is not enough
     247          def hsort(data, comp):
     248              data = [comp(x) for x in data]
     249              self.module.heapify(data)
     250              return [self.module.heappop(data).x for i in range(len(data))]
     251          class ESC[4;38;5;81mLT:
     252              def __init__(self, x):
     253                  self.x = x
     254              def __lt__(self, other):
     255                  return self.x > other.x
     256          class ESC[4;38;5;81mLE:
     257              def __init__(self, x):
     258                  self.x = x
     259              def __le__(self, other):
     260                  return self.x >= other.x
     261          data = [random.random() for i in range(100)]
     262          target = sorted(data, reverse=True)
     263          self.assertEqual(hsort(data, LT), target)
     264          self.assertRaises(TypeError, data, LE)
     265  
     266  
     267  class ESC[4;38;5;81mTestHeapPython(ESC[4;38;5;149mTestHeap, ESC[4;38;5;149mTestCase):
     268      module = py_heapq
     269  
     270  
     271  @skipUnless(c_heapq, 'requires _heapq')
     272  class ESC[4;38;5;81mTestHeapC(ESC[4;38;5;149mTestHeap, ESC[4;38;5;149mTestCase):
     273      module = c_heapq
     274  
     275  
     276  #==============================================================================
     277  
     278  class ESC[4;38;5;81mLenOnly:
     279      "Dummy sequence class defining __len__ but not __getitem__."
     280      def __len__(self):
     281          return 10
     282  
     283  class ESC[4;38;5;81mCmpErr:
     284      "Dummy element that always raises an error during comparison"
     285      def __eq__(self, other):
     286          raise ZeroDivisionError
     287      __ne__ = __lt__ = __le__ = __gt__ = __ge__ = __eq__
     288  
     289  def R(seqn):
     290      'Regular generator'
     291      for i in seqn:
     292          yield i
     293  
     294  class ESC[4;38;5;81mG:
     295      'Sequence using __getitem__'
     296      def __init__(self, seqn):
     297          self.seqn = seqn
     298      def __getitem__(self, i):
     299          return self.seqn[i]
     300  
     301  class ESC[4;38;5;81mI:
     302      'Sequence using iterator protocol'
     303      def __init__(self, seqn):
     304          self.seqn = seqn
     305          self.i = 0
     306      def __iter__(self):
     307          return self
     308      def __next__(self):
     309          if self.i >= len(self.seqn): raise StopIteration
     310          v = self.seqn[self.i]
     311          self.i += 1
     312          return v
     313  
     314  class ESC[4;38;5;81mIg:
     315      'Sequence using iterator protocol defined with a generator'
     316      def __init__(self, seqn):
     317          self.seqn = seqn
     318          self.i = 0
     319      def __iter__(self):
     320          for val in self.seqn:
     321              yield val
     322  
     323  class ESC[4;38;5;81mX:
     324      'Missing __getitem__ and __iter__'
     325      def __init__(self, seqn):
     326          self.seqn = seqn
     327          self.i = 0
     328      def __next__(self):
     329          if self.i >= len(self.seqn): raise StopIteration
     330          v = self.seqn[self.i]
     331          self.i += 1
     332          return v
     333  
     334  class ESC[4;38;5;81mN:
     335      'Iterator missing __next__()'
     336      def __init__(self, seqn):
     337          self.seqn = seqn
     338          self.i = 0
     339      def __iter__(self):
     340          return self
     341  
     342  class ESC[4;38;5;81mE:
     343      'Test propagation of exceptions'
     344      def __init__(self, seqn):
     345          self.seqn = seqn
     346          self.i = 0
     347      def __iter__(self):
     348          return self
     349      def __next__(self):
     350          3 // 0
     351  
     352  class ESC[4;38;5;81mS:
     353      'Test immediate stop'
     354      def __init__(self, seqn):
     355          pass
     356      def __iter__(self):
     357          return self
     358      def __next__(self):
     359          raise StopIteration
     360  
     361  from itertools import chain
     362  def L(seqn):
     363      'Test multiple tiers of iterators'
     364      return chain(map(lambda x:x, R(Ig(G(seqn)))))
     365  
     366  
     367  class ESC[4;38;5;81mSideEffectLT:
     368      def __init__(self, value, heap):
     369          self.value = value
     370          self.heap = heap
     371  
     372      def __lt__(self, other):
     373          self.heap[:] = []
     374          return self.value < other.value
     375  
     376  
     377  class ESC[4;38;5;81mTestErrorHandling:
     378  
     379      def test_non_sequence(self):
     380          for f in (self.module.heapify, self.module.heappop):
     381              self.assertRaises((TypeError, AttributeError), f, 10)
     382          for f in (self.module.heappush, self.module.heapreplace,
     383                    self.module.nlargest, self.module.nsmallest):
     384              self.assertRaises((TypeError, AttributeError), f, 10, 10)
     385  
     386      def test_len_only(self):
     387          for f in (self.module.heapify, self.module.heappop):
     388              self.assertRaises((TypeError, AttributeError), f, LenOnly())
     389          for f in (self.module.heappush, self.module.heapreplace):
     390              self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
     391          for f in (self.module.nlargest, self.module.nsmallest):
     392              self.assertRaises(TypeError, f, 2, LenOnly())
     393  
     394      def test_cmp_err(self):
     395          seq = [CmpErr(), CmpErr(), CmpErr()]
     396          for f in (self.module.heapify, self.module.heappop):
     397              self.assertRaises(ZeroDivisionError, f, seq)
     398          for f in (self.module.heappush, self.module.heapreplace):
     399              self.assertRaises(ZeroDivisionError, f, seq, 10)
     400          for f in (self.module.nlargest, self.module.nsmallest):
     401              self.assertRaises(ZeroDivisionError, f, 2, seq)
     402  
     403      def test_arg_parsing(self):
     404          for f in (self.module.heapify, self.module.heappop,
     405                    self.module.heappush, self.module.heapreplace,
     406                    self.module.nlargest, self.module.nsmallest):
     407              self.assertRaises((TypeError, AttributeError), f, 10)
     408  
     409      def test_iterable_args(self):
     410          for f in (self.module.nlargest, self.module.nsmallest):
     411              for s in ("123", "", range(1000), (1, 1.2), range(2000,2200,5)):
     412                  for g in (G, I, Ig, L, R):
     413                      self.assertEqual(list(f(2, g(s))), list(f(2,s)))
     414                  self.assertEqual(list(f(2, S(s))), [])
     415                  self.assertRaises(TypeError, f, 2, X(s))
     416                  self.assertRaises(TypeError, f, 2, N(s))
     417                  self.assertRaises(ZeroDivisionError, f, 2, E(s))
     418  
     419      # Issue #17278: the heap may change size while it's being walked.
     420  
     421      def test_heappush_mutating_heap(self):
     422          heap = []
     423          heap.extend(SideEffectLT(i, heap) for i in range(200))
     424          # Python version raises IndexError, C version RuntimeError
     425          with self.assertRaises((IndexError, RuntimeError)):
     426              self.module.heappush(heap, SideEffectLT(5, heap))
     427  
     428      def test_heappop_mutating_heap(self):
     429          heap = []
     430          heap.extend(SideEffectLT(i, heap) for i in range(200))
     431          # Python version raises IndexError, C version RuntimeError
     432          with self.assertRaises((IndexError, RuntimeError)):
     433              self.module.heappop(heap)
     434  
     435      def test_comparison_operator_modifiying_heap(self):
     436          # See bpo-39421: Strong references need to be taken
     437          # when comparing objects as they can alter the heap
     438          class ESC[4;38;5;81mEvilClass(ESC[4;38;5;149mint):
     439              def __lt__(self, o):
     440                  heap.clear()
     441                  return NotImplemented
     442  
     443          heap = []
     444          self.module.heappush(heap, EvilClass(0))
     445          self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
     446  
     447      def test_comparison_operator_modifiying_heap_two_heaps(self):
     448  
     449          class ESC[4;38;5;81mh(ESC[4;38;5;149mint):
     450              def __lt__(self, o):
     451                  list2.clear()
     452                  return NotImplemented
     453  
     454          class ESC[4;38;5;81mg(ESC[4;38;5;149mint):
     455              def __lt__(self, o):
     456                  list1.clear()
     457                  return NotImplemented
     458  
     459          list1, list2 = [], []
     460  
     461          self.module.heappush(list1, h(0))
     462          self.module.heappush(list2, g(0))
     463  
     464          self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
     465          self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
     466  
     467  class ESC[4;38;5;81mTestErrorHandlingPython(ESC[4;38;5;149mTestErrorHandling, ESC[4;38;5;149mTestCase):
     468      module = py_heapq
     469  
     470  @skipUnless(c_heapq, 'requires _heapq')
     471  class ESC[4;38;5;81mTestErrorHandlingC(ESC[4;38;5;149mTestErrorHandling, ESC[4;38;5;149mTestCase):
     472      module = c_heapq
     473  
     474  
     475  if __name__ == "__main__":
     476      unittest.main()