(root)/
Python-3.11.7/
Lib/
test/
test_heapq.py
       1  """Unittests for heapq."""
       2  
       3  import random
       4  import unittest
       5  import doctest
       6  
       7  from test import support
       8  from test.support import import_helper
       9  from unittest import TestCase, skipUnless
      10  from operator import itemgetter
      11  
      12  py_heapq = import_helper.import_fresh_module('heapq', blocked=['_heapq'])
      13  c_heapq = import_helper.import_fresh_module('heapq', fresh=['_heapq'])
      14  
      15  # _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
      16  # _heapq is imported, so check them there
      17  func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace',
      18                '_heappop_max', '_heapreplace_max', '_heapify_max']
      19  
      20  class ESC[4;38;5;81mTestModules(ESC[4;38;5;149mTestCase):
      21      def test_py_functions(self):
      22          for fname in func_names:
      23              self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
      24  
      25      @skipUnless(c_heapq, 'requires _heapq')
      26      def test_c_functions(self):
      27          for fname in func_names:
      28              self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
      29  
      30  
      31  def load_tests(loader, tests, ignore):
      32      # The 'merge' function has examples in its docstring which we should test
      33      # with 'doctest'.
      34      #
      35      # However, doctest can't easily find all docstrings in the module (loading
      36      # it through import_fresh_module seems to confuse it), so we specifically
      37      # create a finder which returns the doctests from the merge method.
      38  
      39      class ESC[4;38;5;81mHeapqMergeDocTestFinder:
      40          def find(self, *args, **kwargs):
      41              dtf = doctest.DocTestFinder()
      42              return dtf.find(py_heapq.merge)
      43  
      44      tests.addTests(doctest.DocTestSuite(py_heapq,
      45                                          test_finder=HeapqMergeDocTestFinder()))
      46      return tests
      47  
      48  class ESC[4;38;5;81mTestHeap:
      49  
      50      def test_push_pop(self):
      51          # 1) Push 256 random numbers and pop them off, verifying all's OK.
      52          heap = []
      53          data = []
      54          self.check_invariant(heap)
      55          for i in range(256):
      56              item = random.random()
      57              data.append(item)
      58              self.module.heappush(heap, item)
      59              self.check_invariant(heap)
      60          results = []
      61          while heap:
      62              item = self.module.heappop(heap)
      63              self.check_invariant(heap)
      64              results.append(item)
      65          data_sorted = data[:]
      66          data_sorted.sort()
      67          self.assertEqual(data_sorted, results)
      68          # 2) Check that the invariant holds for a sorted array
      69          self.check_invariant(results)
      70  
      71          self.assertRaises(TypeError, self.module.heappush, [])
      72          try:
      73              self.assertRaises(TypeError, self.module.heappush, None, None)
      74              self.assertRaises(TypeError, self.module.heappop, None)
      75          except AttributeError:
      76              pass
      77  
      78      def check_invariant(self, heap):
      79          # Check the heap invariant.
      80          for pos, item in enumerate(heap):
      81              if pos: # pos 0 has no parent
      82                  parentpos = (pos-1) >> 1
      83                  self.assertTrue(heap[parentpos] <= item)
      84  
      85      def test_heapify(self):
      86          for size in list(range(30)) + [20000]:
      87              heap = [random.random() for dummy in range(size)]
      88              self.module.heapify(heap)
      89              self.check_invariant(heap)
      90  
      91          self.assertRaises(TypeError, self.module.heapify, None)
      92  
      93      def test_naive_nbest(self):
      94          data = [random.randrange(2000) for i in range(1000)]
      95          heap = []
      96          for item in data:
      97              self.module.heappush(heap, item)
      98              if len(heap) > 10:
      99                  self.module.heappop(heap)
     100          heap.sort()
     101          self.assertEqual(heap, sorted(data)[-10:])
     102  
     103      def heapiter(self, heap):
     104          # An iterator returning a heap's elements, smallest-first.
     105          try:
     106              while 1:
     107                  yield self.module.heappop(heap)
     108          except IndexError:
     109              pass
     110  
     111      def test_nbest(self):
     112          # Less-naive "N-best" algorithm, much faster (if len(data) is big
     113          # enough <wink>) than sorting all of data.  However, if we had a max
     114          # heap instead of a min heap, it could go faster still via
     115          # heapify'ing all of data (linear time), then doing 10 heappops
     116          # (10 log-time steps).
     117          data = [random.randrange(2000) for i in range(1000)]
     118          heap = data[:10]
     119          self.module.heapify(heap)
     120          for item in data[10:]:
     121              if item > heap[0]:  # this gets rarer the longer we run
     122                  self.module.heapreplace(heap, item)
     123          self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
     124  
     125          self.assertRaises(TypeError, self.module.heapreplace, None)
     126          self.assertRaises(TypeError, self.module.heapreplace, None, None)
     127          self.assertRaises(IndexError, self.module.heapreplace, [], None)
     128  
     129      def test_nbest_with_pushpop(self):
     130          data = [random.randrange(2000) for i in range(1000)]
     131          heap = data[:10]
     132          self.module.heapify(heap)
     133          for item in data[10:]:
     134              self.module.heappushpop(heap, item)
     135          self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
     136          self.assertEqual(self.module.heappushpop([], 'x'), 'x')
     137  
     138      def test_heappushpop(self):
     139          h = []
     140          x = self.module.heappushpop(h, 10)
     141          self.assertEqual((h, x), ([], 10))
     142  
     143          h = [10]
     144          x = self.module.heappushpop(h, 10.0)
     145          self.assertEqual((h, x), ([10], 10.0))
     146          self.assertEqual(type(h[0]), int)
     147          self.assertEqual(type(x), float)
     148  
     149          h = [10]
     150          x = self.module.heappushpop(h, 9)
     151          self.assertEqual((h, x), ([10], 9))
     152  
     153          h = [10]
     154          x = self.module.heappushpop(h, 11)
     155          self.assertEqual((h, x), ([11], 10))
     156  
     157      def test_heappop_max(self):
     158          # _heapop_max has an optimization for one-item lists which isn't
     159          # covered in other tests, so test that case explicitly here
     160          h = [3, 2]
     161          self.assertEqual(self.module._heappop_max(h), 3)
     162          self.assertEqual(self.module._heappop_max(h), 2)
     163  
     164      def test_heapsort(self):
     165          # Exercise everything with repeated heapsort checks
     166          for trial in range(100):
     167              size = random.randrange(50)
     168              data = [random.randrange(25) for i in range(size)]
     169              if trial & 1:     # Half of the time, use heapify
     170                  heap = data[:]
     171                  self.module.heapify(heap)
     172              else:             # The rest of the time, use heappush
     173                  heap = []
     174                  for item in data:
     175                      self.module.heappush(heap, item)
     176              heap_sorted = [self.module.heappop(heap) for i in range(size)]
     177              self.assertEqual(heap_sorted, sorted(data))
     178  
     179      def test_merge(self):
     180          inputs = []
     181          for i in range(random.randrange(25)):
     182              row = []
     183              for j in range(random.randrange(100)):
     184                  tup = random.choice('ABC'), random.randrange(-500, 500)
     185                  row.append(tup)
     186              inputs.append(row)
     187  
     188          for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
     189              for reverse in [False, True]:
     190                  seqs = []
     191                  for seq in inputs:
     192                      seqs.append(sorted(seq, key=key, reverse=reverse))
     193                  self.assertEqual(sorted(chain(*inputs), key=key, reverse=reverse),
     194                                   list(self.module.merge(*seqs, key=key, reverse=reverse)))
     195                  self.assertEqual(list(self.module.merge()), [])
     196  
     197      def test_empty_merges(self):
     198          # Merging two empty lists (with or without a key) should produce
     199          # another empty list.
     200          self.assertEqual(list(self.module.merge([], [])), [])
     201          self.assertEqual(list(self.module.merge([], [], key=lambda: 6)), [])
     202  
     203      def test_merge_does_not_suppress_index_error(self):
     204          # Issue 19018: Heapq.merge suppresses IndexError from user generator
     205          def iterable():
     206              s = list(range(10))
     207              for i in range(20):
     208                  yield s[i]       # IndexError when i > 10
     209          with self.assertRaises(IndexError):
     210              list(self.module.merge(iterable(), iterable()))
     211  
     212      def test_merge_stability(self):
     213          class ESC[4;38;5;81mInt(ESC[4;38;5;149mint):
     214              pass
     215          inputs = [[], [], [], []]
     216          for i in range(20000):
     217              stream = random.randrange(4)
     218              x = random.randrange(500)
     219              obj = Int(x)
     220              obj.pair = (x, stream)
     221              inputs[stream].append(obj)
     222          for stream in inputs:
     223              stream.sort()
     224          result = [i.pair for i in self.module.merge(*inputs)]
     225          self.assertEqual(result, sorted(result))
     226  
     227      def test_nsmallest(self):
     228          data = [(random.randrange(2000), i) for i in range(1000)]
     229          for f in (None, lambda x:  x[0] * 547 % 2000):
     230              for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
     231                  self.assertEqual(list(self.module.nsmallest(n, data)),
     232                                   sorted(data)[:n])
     233                  self.assertEqual(list(self.module.nsmallest(n, data, key=f)),
     234                                   sorted(data, key=f)[:n])
     235  
     236      def test_nlargest(self):
     237          data = [(random.randrange(2000), i) for i in range(1000)]
     238          for f in (None, lambda x:  x[0] * 547 % 2000):
     239              for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
     240                  self.assertEqual(list(self.module.nlargest(n, data)),
     241                                   sorted(data, reverse=True)[:n])
     242                  self.assertEqual(list(self.module.nlargest(n, data, key=f)),
     243                                   sorted(data, key=f, reverse=True)[:n])
     244  
     245      def test_comparison_operator(self):
     246          # Issue 3051: Make sure heapq works with both __lt__
     247          # For python 3.0, __le__ alone is not enough
     248          def hsort(data, comp):
     249              data = [comp(x) for x in data]
     250              self.module.heapify(data)
     251              return [self.module.heappop(data).x for i in range(len(data))]
     252          class ESC[4;38;5;81mLT:
     253              def __init__(self, x):
     254                  self.x = x
     255              def __lt__(self, other):
     256                  return self.x > other.x
     257          class ESC[4;38;5;81mLE:
     258              def __init__(self, x):
     259                  self.x = x
     260              def __le__(self, other):
     261                  return self.x >= other.x
     262          data = [random.random() for i in range(100)]
     263          target = sorted(data, reverse=True)
     264          self.assertEqual(hsort(data, LT), target)
     265          self.assertRaises(TypeError, data, LE)
     266  
     267  
     268  class ESC[4;38;5;81mTestHeapPython(ESC[4;38;5;149mTestHeap, ESC[4;38;5;149mTestCase):
     269      module = py_heapq
     270  
     271  
     272  @skipUnless(c_heapq, 'requires _heapq')
     273  class ESC[4;38;5;81mTestHeapC(ESC[4;38;5;149mTestHeap, ESC[4;38;5;149mTestCase):
     274      module = c_heapq
     275  
     276  
     277  #==============================================================================
     278  
     279  class ESC[4;38;5;81mLenOnly:
     280      "Dummy sequence class defining __len__ but not __getitem__."
     281      def __len__(self):
     282          return 10
     283  
     284  class ESC[4;38;5;81mCmpErr:
     285      "Dummy element that always raises an error during comparison"
     286      def __eq__(self, other):
     287          raise ZeroDivisionError
     288      __ne__ = __lt__ = __le__ = __gt__ = __ge__ = __eq__
     289  
     290  def R(seqn):
     291      'Regular generator'
     292      for i in seqn:
     293          yield i
     294  
     295  class ESC[4;38;5;81mG:
     296      'Sequence using __getitem__'
     297      def __init__(self, seqn):
     298          self.seqn = seqn
     299      def __getitem__(self, i):
     300          return self.seqn[i]
     301  
     302  class ESC[4;38;5;81mI:
     303      'Sequence using iterator protocol'
     304      def __init__(self, seqn):
     305          self.seqn = seqn
     306          self.i = 0
     307      def __iter__(self):
     308          return self
     309      def __next__(self):
     310          if self.i >= len(self.seqn): raise StopIteration
     311          v = self.seqn[self.i]
     312          self.i += 1
     313          return v
     314  
     315  class ESC[4;38;5;81mIg:
     316      'Sequence using iterator protocol defined with a generator'
     317      def __init__(self, seqn):
     318          self.seqn = seqn
     319          self.i = 0
     320      def __iter__(self):
     321          for val in self.seqn:
     322              yield val
     323  
     324  class ESC[4;38;5;81mX:
     325      'Missing __getitem__ and __iter__'
     326      def __init__(self, seqn):
     327          self.seqn = seqn
     328          self.i = 0
     329      def __next__(self):
     330          if self.i >= len(self.seqn): raise StopIteration
     331          v = self.seqn[self.i]
     332          self.i += 1
     333          return v
     334  
     335  class ESC[4;38;5;81mN:
     336      'Iterator missing __next__()'
     337      def __init__(self, seqn):
     338          self.seqn = seqn
     339          self.i = 0
     340      def __iter__(self):
     341          return self
     342  
     343  class ESC[4;38;5;81mE:
     344      'Test propagation of exceptions'
     345      def __init__(self, seqn):
     346          self.seqn = seqn
     347          self.i = 0
     348      def __iter__(self):
     349          return self
     350      def __next__(self):
     351          3 // 0
     352  
     353  class ESC[4;38;5;81mS:
     354      'Test immediate stop'
     355      def __init__(self, seqn):
     356          pass
     357      def __iter__(self):
     358          return self
     359      def __next__(self):
     360          raise StopIteration
     361  
     362  from itertools import chain
     363  def L(seqn):
     364      'Test multiple tiers of iterators'
     365      return chain(map(lambda x:x, R(Ig(G(seqn)))))
     366  
     367  
     368  class ESC[4;38;5;81mSideEffectLT:
     369      def __init__(self, value, heap):
     370          self.value = value
     371          self.heap = heap
     372  
     373      def __lt__(self, other):
     374          self.heap[:] = []
     375          return self.value < other.value
     376  
     377  
     378  class ESC[4;38;5;81mTestErrorHandling:
     379  
     380      def test_non_sequence(self):
     381          for f in (self.module.heapify, self.module.heappop):
     382              self.assertRaises((TypeError, AttributeError), f, 10)
     383          for f in (self.module.heappush, self.module.heapreplace,
     384                    self.module.nlargest, self.module.nsmallest):
     385              self.assertRaises((TypeError, AttributeError), f, 10, 10)
     386  
     387      def test_len_only(self):
     388          for f in (self.module.heapify, self.module.heappop):
     389              self.assertRaises((TypeError, AttributeError), f, LenOnly())
     390          for f in (self.module.heappush, self.module.heapreplace):
     391              self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
     392          for f in (self.module.nlargest, self.module.nsmallest):
     393              self.assertRaises(TypeError, f, 2, LenOnly())
     394  
     395      def test_cmp_err(self):
     396          seq = [CmpErr(), CmpErr(), CmpErr()]
     397          for f in (self.module.heapify, self.module.heappop):
     398              self.assertRaises(ZeroDivisionError, f, seq)
     399          for f in (self.module.heappush, self.module.heapreplace):
     400              self.assertRaises(ZeroDivisionError, f, seq, 10)
     401          for f in (self.module.nlargest, self.module.nsmallest):
     402              self.assertRaises(ZeroDivisionError, f, 2, seq)
     403  
     404      def test_arg_parsing(self):
     405          for f in (self.module.heapify, self.module.heappop,
     406                    self.module.heappush, self.module.heapreplace,
     407                    self.module.nlargest, self.module.nsmallest):
     408              self.assertRaises((TypeError, AttributeError), f, 10)
     409  
     410      def test_iterable_args(self):
     411          for f in (self.module.nlargest, self.module.nsmallest):
     412              for s in ("123", "", range(1000), (1, 1.2), range(2000,2200,5)):
     413                  for g in (G, I, Ig, L, R):
     414                      self.assertEqual(list(f(2, g(s))), list(f(2,s)))
     415                  self.assertEqual(list(f(2, S(s))), [])
     416                  self.assertRaises(TypeError, f, 2, X(s))
     417                  self.assertRaises(TypeError, f, 2, N(s))
     418                  self.assertRaises(ZeroDivisionError, f, 2, E(s))
     419  
     420      # Issue #17278: the heap may change size while it's being walked.
     421  
     422      def test_heappush_mutating_heap(self):
     423          heap = []
     424          heap.extend(SideEffectLT(i, heap) for i in range(200))
     425          # Python version raises IndexError, C version RuntimeError
     426          with self.assertRaises((IndexError, RuntimeError)):
     427              self.module.heappush(heap, SideEffectLT(5, heap))
     428  
     429      def test_heappop_mutating_heap(self):
     430          heap = []
     431          heap.extend(SideEffectLT(i, heap) for i in range(200))
     432          # Python version raises IndexError, C version RuntimeError
     433          with self.assertRaises((IndexError, RuntimeError)):
     434              self.module.heappop(heap)
     435  
     436      def test_comparison_operator_modifiying_heap(self):
     437          # See bpo-39421: Strong references need to be taken
     438          # when comparing objects as they can alter the heap
     439          class ESC[4;38;5;81mEvilClass(ESC[4;38;5;149mint):
     440              def __lt__(self, o):
     441                  heap.clear()
     442                  return NotImplemented
     443  
     444          heap = []
     445          self.module.heappush(heap, EvilClass(0))
     446          self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
     447  
     448      def test_comparison_operator_modifiying_heap_two_heaps(self):
     449  
     450          class ESC[4;38;5;81mh(ESC[4;38;5;149mint):
     451              def __lt__(self, o):
     452                  list2.clear()
     453                  return NotImplemented
     454  
     455          class ESC[4;38;5;81mg(ESC[4;38;5;149mint):
     456              def __lt__(self, o):
     457                  list1.clear()
     458                  return NotImplemented
     459  
     460          list1, list2 = [], []
     461  
     462          self.module.heappush(list1, h(0))
     463          self.module.heappush(list2, g(0))
     464  
     465          self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
     466          self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
     467  
     468  class ESC[4;38;5;81mTestErrorHandlingPython(ESC[4;38;5;149mTestErrorHandling, ESC[4;38;5;149mTestCase):
     469      module = py_heapq
     470  
     471  @skipUnless(c_heapq, 'requires _heapq')
     472  class ESC[4;38;5;81mTestErrorHandlingC(ESC[4;38;5;149mTestErrorHandling, ESC[4;38;5;149mTestCase):
     473      module = c_heapq
     474  
     475  
     476  if __name__ == "__main__":
     477      unittest.main()