(root)/
Python-3.12.0/
Lib/
test/
test_iter.py
       1  # Test iterators.
       2  
       3  import sys
       4  import unittest
       5  from test.support import cpython_only
       6  from test.support.os_helper import TESTFN, unlink
       7  from test.support import check_free_after_iterating, ALWAYS_EQ, NEVER_EQ
       8  import pickle
       9  import collections.abc
      10  import functools
      11  import contextlib
      12  import builtins
      13  
      14  # Test result of triple loop (too big to inline)
      15  TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2),
      16              (0, 1, 0), (0, 1, 1), (0, 1, 2),
      17              (0, 2, 0), (0, 2, 1), (0, 2, 2),
      18  
      19              (1, 0, 0), (1, 0, 1), (1, 0, 2),
      20              (1, 1, 0), (1, 1, 1), (1, 1, 2),
      21              (1, 2, 0), (1, 2, 1), (1, 2, 2),
      22  
      23              (2, 0, 0), (2, 0, 1), (2, 0, 2),
      24              (2, 1, 0), (2, 1, 1), (2, 1, 2),
      25              (2, 2, 0), (2, 2, 1), (2, 2, 2)]
      26  
      27  # Helper classes
      28  
      29  class ESC[4;38;5;81mBasicIterClass:
      30      def __init__(self, n):
      31          self.n = n
      32          self.i = 0
      33      def __next__(self):
      34          res = self.i
      35          if res >= self.n:
      36              raise StopIteration
      37          self.i = res + 1
      38          return res
      39      def __iter__(self):
      40          return self
      41  
      42  class ESC[4;38;5;81mIteratingSequenceClass:
      43      def __init__(self, n):
      44          self.n = n
      45      def __iter__(self):
      46          return BasicIterClass(self.n)
      47  
      48  class ESC[4;38;5;81mIteratorProxyClass:
      49      def __init__(self, i):
      50          self.i = i
      51      def __next__(self):
      52          return next(self.i)
      53      def __iter__(self):
      54          return self
      55  
      56  class ESC[4;38;5;81mSequenceClass:
      57      def __init__(self, n):
      58          self.n = n
      59      def __getitem__(self, i):
      60          if 0 <= i < self.n:
      61              return i
      62          else:
      63              raise IndexError
      64  
      65  class ESC[4;38;5;81mSequenceProxyClass:
      66      def __init__(self, s):
      67          self.s = s
      68      def __getitem__(self, i):
      69          return self.s[i]
      70  
      71  class ESC[4;38;5;81mUnlimitedSequenceClass:
      72      def __getitem__(self, i):
      73          return i
      74  
      75  class ESC[4;38;5;81mDefaultIterClass:
      76      pass
      77  
      78  class ESC[4;38;5;81mNoIterClass:
      79      def __getitem__(self, i):
      80          return i
      81      __iter__ = None
      82  
      83  class ESC[4;38;5;81mBadIterableClass:
      84      def __iter__(self):
      85          raise ZeroDivisionError
      86  
      87  class ESC[4;38;5;81mCallableIterClass:
      88      def __init__(self):
      89          self.i = 0
      90      def __call__(self):
      91          i = self.i
      92          self.i = i + 1
      93          if i > 100:
      94              raise IndexError # Emergency stop
      95          return i
      96  
      97  class ESC[4;38;5;81mEmptyIterClass:
      98      def __len__(self):
      99          return 0
     100      def __getitem__(self, i):
     101          raise StopIteration
     102  
     103  # Main test suite
     104  
     105  class ESC[4;38;5;81mTestCase(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     106  
     107      # Helper to check that an iterator returns a given sequence
     108      def check_iterator(self, it, seq, pickle=True):
     109          if pickle:
     110              self.check_pickle(it, seq)
     111          res = []
     112          while 1:
     113              try:
     114                  val = next(it)
     115              except StopIteration:
     116                  break
     117              res.append(val)
     118          self.assertEqual(res, seq)
     119  
     120      # Helper to check that a for loop generates a given sequence
     121      def check_for_loop(self, expr, seq, pickle=True):
     122          if pickle:
     123              self.check_pickle(iter(expr), seq)
     124          res = []
     125          for val in expr:
     126              res.append(val)
     127          self.assertEqual(res, seq)
     128  
     129      # Helper to check picklability
     130      def check_pickle(self, itorg, seq):
     131          for proto in range(pickle.HIGHEST_PROTOCOL + 1):
     132              d = pickle.dumps(itorg, proto)
     133              it = pickle.loads(d)
     134              # Cannot assert type equality because dict iterators unpickle as list
     135              # iterators.
     136              # self.assertEqual(type(itorg), type(it))
     137              self.assertTrue(isinstance(it, collections.abc.Iterator))
     138              self.assertEqual(list(it), seq)
     139  
     140              it = pickle.loads(d)
     141              try:
     142                  next(it)
     143              except StopIteration:
     144                  continue
     145              d = pickle.dumps(it, proto)
     146              it = pickle.loads(d)
     147              self.assertEqual(list(it), seq[1:])
     148  
     149      # Test basic use of iter() function
     150      def test_iter_basic(self):
     151          self.check_iterator(iter(range(10)), list(range(10)))
     152  
     153      # Test that iter(iter(x)) is the same as iter(x)
     154      def test_iter_idempotency(self):
     155          seq = list(range(10))
     156          it = iter(seq)
     157          it2 = iter(it)
     158          self.assertTrue(it is it2)
     159  
     160      # Test that for loops over iterators work
     161      def test_iter_for_loop(self):
     162          self.check_for_loop(iter(range(10)), list(range(10)))
     163  
     164      # Test several independent iterators over the same list
     165      def test_iter_independence(self):
     166          seq = range(3)
     167          res = []
     168          for i in iter(seq):
     169              for j in iter(seq):
     170                  for k in iter(seq):
     171                      res.append((i, j, k))
     172          self.assertEqual(res, TRIPLETS)
     173  
     174      # Test triple list comprehension using iterators
     175      def test_nested_comprehensions_iter(self):
     176          seq = range(3)
     177          res = [(i, j, k)
     178                 for i in iter(seq) for j in iter(seq) for k in iter(seq)]
     179          self.assertEqual(res, TRIPLETS)
     180  
     181      # Test triple list comprehension without iterators
     182      def test_nested_comprehensions_for(self):
     183          seq = range(3)
     184          res = [(i, j, k) for i in seq for j in seq for k in seq]
     185          self.assertEqual(res, TRIPLETS)
     186  
     187      # Test a class with __iter__ in a for loop
     188      def test_iter_class_for(self):
     189          self.check_for_loop(IteratingSequenceClass(10), list(range(10)))
     190  
     191      # Test a class with __iter__ with explicit iter()
     192      def test_iter_class_iter(self):
     193          self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10)))
     194  
     195      # Test for loop on a sequence class without __iter__
     196      def test_seq_class_for(self):
     197          self.check_for_loop(SequenceClass(10), list(range(10)))
     198  
     199      # Test iter() on a sequence class without __iter__
     200      def test_seq_class_iter(self):
     201          self.check_iterator(iter(SequenceClass(10)), list(range(10)))
     202  
     203      def test_mutating_seq_class_iter_pickle(self):
     204          orig = SequenceClass(5)
     205          for proto in range(pickle.HIGHEST_PROTOCOL + 1):
     206              # initial iterator
     207              itorig = iter(orig)
     208              d = pickle.dumps((itorig, orig), proto)
     209              it, seq = pickle.loads(d)
     210              seq.n = 7
     211              self.assertIs(type(it), type(itorig))
     212              self.assertEqual(list(it), list(range(7)))
     213  
     214              # running iterator
     215              next(itorig)
     216              d = pickle.dumps((itorig, orig), proto)
     217              it, seq = pickle.loads(d)
     218              seq.n = 7
     219              self.assertIs(type(it), type(itorig))
     220              self.assertEqual(list(it), list(range(1, 7)))
     221  
     222              # empty iterator
     223              for i in range(1, 5):
     224                  next(itorig)
     225              d = pickle.dumps((itorig, orig), proto)
     226              it, seq = pickle.loads(d)
     227              seq.n = 7
     228              self.assertIs(type(it), type(itorig))
     229              self.assertEqual(list(it), list(range(5, 7)))
     230  
     231              # exhausted iterator
     232              self.assertRaises(StopIteration, next, itorig)
     233              d = pickle.dumps((itorig, orig), proto)
     234              it, seq = pickle.loads(d)
     235              seq.n = 7
     236              self.assertTrue(isinstance(it, collections.abc.Iterator))
     237              self.assertEqual(list(it), [])
     238  
     239      def test_mutating_seq_class_exhausted_iter(self):
     240          a = SequenceClass(5)
     241          exhit = iter(a)
     242          empit = iter(a)
     243          for x in exhit:  # exhaust the iterator
     244              next(empit)  # not exhausted
     245          a.n = 7
     246          self.assertEqual(list(exhit), [])
     247          self.assertEqual(list(empit), [5, 6])
     248          self.assertEqual(list(a), [0, 1, 2, 3, 4, 5, 6])
     249  
     250      def test_reduce_mutating_builtins_iter(self):
     251          # This is a reproducer of issue #101765
     252          # where iter `__reduce__` calls could lead to a segfault or SystemError
     253          # depending on the order of C argument evaluation, which is undefined
     254  
     255          # Backup builtins
     256          builtins_dict = builtins.__dict__
     257          orig = {"iter": iter, "reversed": reversed}
     258  
     259          def run(builtin_name, item, sentinel=None):
     260              it = iter(item) if sentinel is None else iter(item, sentinel)
     261  
     262              class ESC[4;38;5;81mCustomStr:
     263                  def __init__(self, name, iterator):
     264                      self.name = name
     265                      self.iterator = iterator
     266                  def __hash__(self):
     267                      return hash(self.name)
     268                  def __eq__(self, other):
     269                      # Here we exhaust our iterator, possibly changing
     270                      # its `it_seq` pointer to NULL
     271                      # The `__reduce__` call should correctly get
     272                      # the pointers after this call
     273                      list(self.iterator)
     274                      return other == self.name
     275  
     276              # del is required here
     277              # to not prematurely call __eq__ from
     278              # the hash collision with the old key
     279              del builtins_dict[builtin_name]
     280              builtins_dict[CustomStr(builtin_name, it)] = orig[builtin_name]
     281  
     282              return it.__reduce__()
     283  
     284          types = [
     285              (EmptyIterClass(),),
     286              (bytes(8),),
     287              (bytearray(8),),
     288              ((1, 2, 3),),
     289              (lambda: 0, 0),
     290              (tuple[int],)  # GenericAlias
     291          ]
     292  
     293          try:
     294              run_iter = functools.partial(run, "iter")
     295              # The returned value of `__reduce__` should not only be valid
     296              # but also *empty*, as `it` was exhausted during `__eq__`
     297              # i.e "xyz" returns (iter, ("",))
     298              self.assertEqual(run_iter("xyz"), (orig["iter"], ("",)))
     299              self.assertEqual(run_iter([1, 2, 3]), (orig["iter"], ([],)))
     300  
     301              # _PyEval_GetBuiltin is also called for `reversed` in a branch of
     302              # listiter_reduce_general
     303              self.assertEqual(
     304                  run("reversed", orig["reversed"](list(range(8)))),
     305                  (iter, ([],))
     306              )
     307  
     308              for case in types:
     309                  self.assertEqual(run_iter(*case), (orig["iter"], ((),)))
     310          finally:
     311              # Restore original builtins
     312              for key, func in orig.items():
     313                  # need to suppress KeyErrors in case
     314                  # a failed test deletes the key without setting anything
     315                  with contextlib.suppress(KeyError):
     316                      # del is required here
     317                      # to not invoke our custom __eq__ from
     318                      # the hash collision with the old key
     319                      del builtins_dict[key]
     320                  builtins_dict[key] = func
     321  
      322      # Test a new-style class with __iter__ but no __next__() method
     323      def test_new_style_iter_class(self):
     324          class ESC[4;38;5;81mIterClass(ESC[4;38;5;149mobject):
     325              def __iter__(self):
     326                  return self
     327          self.assertRaises(TypeError, iter, IterClass())
     328  
     329      # Test two-argument iter() with callable instance
     330      def test_iter_callable(self):
     331          self.check_iterator(iter(CallableIterClass(), 10), list(range(10)), pickle=True)
     332  
     333      # Test two-argument iter() with function
     334      def test_iter_function(self):
     335          def spam(state=[0]):
     336              i = state[0]
     337              state[0] = i+1
     338              return i
     339          self.check_iterator(iter(spam, 10), list(range(10)), pickle=False)
     340  
     341      # Test two-argument iter() with function that raises StopIteration
     342      def test_iter_function_stop(self):
     343          def spam(state=[0]):
     344              i = state[0]
     345              if i == 10:
     346                  raise StopIteration
     347              state[0] = i+1
     348              return i
     349          self.check_iterator(iter(spam, 20), list(range(10)), pickle=False)
     350  
     351      def test_iter_function_concealing_reentrant_exhaustion(self):
     352          # gh-101892: Test two-argument iter() with a function that
     353          # exhausts its associated iterator but forgets to either return
     354          # a sentinel value or raise StopIteration.
     355          HAS_MORE = 1
     356          NO_MORE = 2
     357  
     358          def exhaust(iterator):
     359              """Exhaust an iterator without raising StopIteration."""
     360              list(iterator)
     361  
     362          def spam():
     363              # Touching the iterator with exhaust() below will call
     364              # spam() once again so protect against recursion.
     365              if spam.is_recursive_call:
     366                  return NO_MORE
     367              spam.is_recursive_call = True
     368              exhaust(spam.iterator)
     369              return HAS_MORE
     370  
     371          spam.is_recursive_call = False
     372          spam.iterator = iter(spam, NO_MORE)
     373          with self.assertRaises(StopIteration):
     374              next(spam.iterator)
     375  
     376      # Test exception propagation through function iterator
     377      def test_exception_function(self):
     378          def spam(state=[0]):
     379              i = state[0]
     380              state[0] = i+1
     381              if i == 10:
     382                  raise RuntimeError
     383              return i
     384          res = []
     385          try:
     386              for x in iter(spam, 20):
     387                  res.append(x)
     388          except RuntimeError:
     389              self.assertEqual(res, list(range(10)))
     390          else:
     391              self.fail("should have raised RuntimeError")
     392  
     393      # Test exception propagation through sequence iterator
     394      def test_exception_sequence(self):
     395          class ESC[4;38;5;81mMySequenceClass(ESC[4;38;5;149mSequenceClass):
     396              def __getitem__(self, i):
     397                  if i == 10:
     398                      raise RuntimeError
     399                  return SequenceClass.__getitem__(self, i)
     400          res = []
     401          try:
     402              for x in MySequenceClass(20):
     403                  res.append(x)
     404          except RuntimeError:
     405              self.assertEqual(res, list(range(10)))
     406          else:
     407              self.fail("should have raised RuntimeError")
     408  
     409      # Test for StopIteration from __getitem__
     410      def test_stop_sequence(self):
     411          class ESC[4;38;5;81mMySequenceClass(ESC[4;38;5;149mSequenceClass):
     412              def __getitem__(self, i):
     413                  if i == 10:
     414                      raise StopIteration
     415                  return SequenceClass.__getitem__(self, i)
     416          self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)
     417  
     418      # Test a big range
     419      def test_iter_big_range(self):
     420          self.check_for_loop(iter(range(10000)), list(range(10000)))
     421  
     422      # Test an empty list
     423      def test_iter_empty(self):
     424          self.check_for_loop(iter([]), [])
     425  
     426      # Test a tuple
     427      def test_iter_tuple(self):
     428          self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), list(range(10)))
     429  
     430      # Test a range
     431      def test_iter_range(self):
     432          self.check_for_loop(iter(range(10)), list(range(10)))
     433  
     434      # Test a string
     435      def test_iter_string(self):
     436          self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])
     437  
      438      # Test a dictionary
     439      def test_iter_dict(self):
     440          dict = {}
     441          for i in range(10):
     442              dict[i] = None
     443          self.check_for_loop(dict, list(dict.keys()))
     444  
     445      # Test a file
     446      def test_iter_file(self):
     447          f = open(TESTFN, "w", encoding="utf-8")
     448          try:
     449              for i in range(5):
     450                  f.write("%d\n" % i)
     451          finally:
     452              f.close()
     453          f = open(TESTFN, "r", encoding="utf-8")
     454          try:
     455              self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"], pickle=False)
     456              self.check_for_loop(f, [], pickle=False)
     457          finally:
     458              f.close()
     459              try:
     460                  unlink(TESTFN)
     461              except OSError:
     462                  pass
     463  
     464      # Test list()'s use of iterators.
     465      def test_builtin_list(self):
     466          self.assertEqual(list(SequenceClass(5)), list(range(5)))
     467          self.assertEqual(list(SequenceClass(0)), [])
     468          self.assertEqual(list(()), [])
     469  
     470          d = {"one": 1, "two": 2, "three": 3}
     471          self.assertEqual(list(d), list(d.keys()))
     472  
     473          self.assertRaises(TypeError, list, list)
     474          self.assertRaises(TypeError, list, 42)
     475  
     476          f = open(TESTFN, "w", encoding="utf-8")
     477          try:
     478              for i in range(5):
     479                  f.write("%d\n" % i)
     480          finally:
     481              f.close()
     482          f = open(TESTFN, "r", encoding="utf-8")
     483          try:
     484              self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
     485              f.seek(0, 0)
     486              self.assertEqual(list(f),
     487                               ["0\n", "1\n", "2\n", "3\n", "4\n"])
     488          finally:
     489              f.close()
     490              try:
     491                  unlink(TESTFN)
     492              except OSError:
     493                  pass
     494  
      495      # Test tuple()'s use of iterators.
     496      def test_builtin_tuple(self):
     497          self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
     498          self.assertEqual(tuple(SequenceClass(0)), ())
     499          self.assertEqual(tuple([]), ())
     500          self.assertEqual(tuple(()), ())
     501          self.assertEqual(tuple("abc"), ("a", "b", "c"))
     502  
     503          d = {"one": 1, "two": 2, "three": 3}
     504          self.assertEqual(tuple(d), tuple(d.keys()))
     505  
     506          self.assertRaises(TypeError, tuple, list)
     507          self.assertRaises(TypeError, tuple, 42)
     508  
     509          f = open(TESTFN, "w", encoding="utf-8")
     510          try:
     511              for i in range(5):
     512                  f.write("%d\n" % i)
     513          finally:
     514              f.close()
     515          f = open(TESTFN, "r", encoding="utf-8")
     516          try:
     517              self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
     518              f.seek(0, 0)
     519              self.assertEqual(tuple(f),
     520                               ("0\n", "1\n", "2\n", "3\n", "4\n"))
     521          finally:
     522              f.close()
     523              try:
     524                  unlink(TESTFN)
     525              except OSError:
     526                  pass
     527  
     528      # Test filter()'s use of iterators.
     529      def test_builtin_filter(self):
     530          self.assertEqual(list(filter(None, SequenceClass(5))),
     531                           list(range(1, 5)))
     532          self.assertEqual(list(filter(None, SequenceClass(0))), [])
     533          self.assertEqual(list(filter(None, ())), [])
     534          self.assertEqual(list(filter(None, "abc")), ["a", "b", "c"])
     535  
     536          d = {"one": 1, "two": 2, "three": 3}
     537          self.assertEqual(list(filter(None, d)), list(d.keys()))
     538  
     539          self.assertRaises(TypeError, filter, None, list)
     540          self.assertRaises(TypeError, filter, None, 42)
     541  
     542          class ESC[4;38;5;81mBoolean:
     543              def __init__(self, truth):
     544                  self.truth = truth
     545              def __bool__(self):
     546                  return self.truth
     547          bTrue = Boolean(True)
     548          bFalse = Boolean(False)
     549  
     550          class ESC[4;38;5;81mSeq:
     551              def __init__(self, *args):
     552                  self.vals = args
     553              def __iter__(self):
     554                  class ESC[4;38;5;81mSeqIter:
     555                      def __init__(self, vals):
     556                          self.vals = vals
     557                          self.i = 0
     558                      def __iter__(self):
     559                          return self
     560                      def __next__(self):
     561                          i = self.i
     562                          self.i = i + 1
     563                          if i < len(self.vals):
     564                              return self.vals[i]
     565                          else:
     566                              raise StopIteration
     567                  return SeqIter(self.vals)
     568  
     569          seq = Seq(*([bTrue, bFalse] * 25))
     570          self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
     571          self.assertEqual(list(filter(lambda x: not x, iter(seq))), [bFalse]*25)
     572  
     573      # Test max() and min()'s use of iterators.
     574      def test_builtin_max_min(self):
     575          self.assertEqual(max(SequenceClass(5)), 4)
     576          self.assertEqual(min(SequenceClass(5)), 0)
     577          self.assertEqual(max(8, -1), 8)
     578          self.assertEqual(min(8, -1), -1)
     579  
     580          d = {"one": 1, "two": 2, "three": 3}
     581          self.assertEqual(max(d), "two")
     582          self.assertEqual(min(d), "one")
     583          self.assertEqual(max(d.values()), 3)
     584          self.assertEqual(min(iter(d.values())), 1)
     585  
     586          f = open(TESTFN, "w", encoding="utf-8")
     587          try:
     588              f.write("medium line\n")
     589              f.write("xtra large line\n")
     590              f.write("itty-bitty line\n")
     591          finally:
     592              f.close()
     593          f = open(TESTFN, "r", encoding="utf-8")
     594          try:
     595              self.assertEqual(min(f), "itty-bitty line\n")
     596              f.seek(0, 0)
     597              self.assertEqual(max(f), "xtra large line\n")
     598          finally:
     599              f.close()
     600              try:
     601                  unlink(TESTFN)
     602              except OSError:
     603                  pass
     604  
     605      # Test map()'s use of iterators.
     606      def test_builtin_map(self):
     607          self.assertEqual(list(map(lambda x: x+1, SequenceClass(5))),
     608                           list(range(1, 6)))
     609  
     610          d = {"one": 1, "two": 2, "three": 3}
     611          self.assertEqual(list(map(lambda k, d=d: (k, d[k]), d)),
     612                           list(d.items()))
     613          dkeys = list(d.keys())
     614          expected = [(i < len(d) and dkeys[i] or None,
     615                       i,
     616                       i < len(d) and dkeys[i] or None)
     617                      for i in range(3)]
     618  
     619          f = open(TESTFN, "w", encoding="utf-8")
     620          try:
     621              for i in range(10):
     622                  f.write("xy" * i + "\n") # line i has len 2*i+1
     623          finally:
     624              f.close()
     625          f = open(TESTFN, "r", encoding="utf-8")
     626          try:
     627              self.assertEqual(list(map(len, f)), list(range(1, 21, 2)))
     628          finally:
     629              f.close()
     630              try:
     631                  unlink(TESTFN)
     632              except OSError:
     633                  pass
     634  
     635      # Test zip()'s use of iterators.
     636      def test_builtin_zip(self):
     637          self.assertEqual(list(zip()), [])
     638          self.assertEqual(list(zip(*[])), [])
     639          self.assertEqual(list(zip(*[(1, 2), 'ab'])), [(1, 'a'), (2, 'b')])
     640  
     641          self.assertRaises(TypeError, zip, None)
     642          self.assertRaises(TypeError, zip, range(10), 42)
     643          self.assertRaises(TypeError, zip, range(10), zip)
     644  
     645          self.assertEqual(list(zip(IteratingSequenceClass(3))),
     646                           [(0,), (1,), (2,)])
     647          self.assertEqual(list(zip(SequenceClass(3))),
     648                           [(0,), (1,), (2,)])
     649  
     650          d = {"one": 1, "two": 2, "three": 3}
     651          self.assertEqual(list(d.items()), list(zip(d, d.values())))
     652  
     653          # Generate all ints starting at constructor arg.
     654          class ESC[4;38;5;81mIntsFrom:
     655              def __init__(self, start):
     656                  self.i = start
     657  
     658              def __iter__(self):
     659                  return self
     660  
     661              def __next__(self):
     662                  i = self.i
     663                  self.i = i+1
     664                  return i
     665  
     666          f = open(TESTFN, "w", encoding="utf-8")
     667          try:
     668              f.write("a\n" "bbb\n" "cc\n")
     669          finally:
     670              f.close()
     671          f = open(TESTFN, "r", encoding="utf-8")
     672          try:
     673              self.assertEqual(list(zip(IntsFrom(0), f, IntsFrom(-100))),
     674                               [(0, "a\n", -100),
     675                                (1, "bbb\n", -99),
     676                                (2, "cc\n", -98)])
     677          finally:
     678              f.close()
     679              try:
     680                  unlink(TESTFN)
     681              except OSError:
     682                  pass
     683  
     684          self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
     685  
     686          # Classes that lie about their lengths.
     687          class ESC[4;38;5;81mNoGuessLen5:
     688              def __getitem__(self, i):
     689                  if i >= 5:
     690                      raise IndexError
     691                  return i
     692  
     693          class ESC[4;38;5;81mGuess3Len5(ESC[4;38;5;149mNoGuessLen5):
     694              def __len__(self):
     695                  return 3
     696  
     697          class ESC[4;38;5;81mGuess30Len5(ESC[4;38;5;149mNoGuessLen5):
     698              def __len__(self):
     699                  return 30
     700  
     701          def lzip(*args):
     702              return list(zip(*args))
     703  
     704          self.assertEqual(len(Guess3Len5()), 3)
     705          self.assertEqual(len(Guess30Len5()), 30)
     706          self.assertEqual(lzip(NoGuessLen5()), lzip(range(5)))
     707          self.assertEqual(lzip(Guess3Len5()), lzip(range(5)))
     708          self.assertEqual(lzip(Guess30Len5()), lzip(range(5)))
     709  
     710          expected = [(i, i) for i in range(5)]
     711          for x in NoGuessLen5(), Guess3Len5(), Guess30Len5():
     712              for y in NoGuessLen5(), Guess3Len5(), Guess30Len5():
     713                  self.assertEqual(lzip(x, y), expected)
     714  
     715      def test_unicode_join_endcase(self):
     716  
     717          # This class inserts a Unicode object into its argument's natural
     718          # iteration, in the 3rd position.
     719          class ESC[4;38;5;81mOhPhooey:
     720              def __init__(self, seq):
     721                  self.it = iter(seq)
     722                  self.i = 0
     723  
     724              def __iter__(self):
     725                  return self
     726  
     727              def __next__(self):
     728                  i = self.i
     729                  self.i = i+1
     730                  if i == 2:
     731                      return "fooled you!"
     732                  return next(self.it)
     733  
     734          f = open(TESTFN, "w", encoding="utf-8")
     735          try:
     736              f.write("a\n" + "b\n" + "c\n")
     737          finally:
     738              f.close()
     739  
     740          f = open(TESTFN, "r", encoding="utf-8")
     741          # Nasty:  string.join(s) can't know whether unicode.join() is needed
     742          # until it's seen all of s's elements.  But in this case, f's
     743          # iterator cannot be restarted.  So what we're testing here is
     744          # whether string.join() can manage to remember everything it's seen
     745          # and pass that on to unicode.join().
     746          try:
     747              got = " - ".join(OhPhooey(f))
     748              self.assertEqual(got, "a\n - b\n - fooled you! - c\n")
     749          finally:
     750              f.close()
     751              try:
     752                  unlink(TESTFN)
     753              except OSError:
     754                  pass
     755  
     756      # Test iterators with 'x in y' and 'x not in y'.
     757      def test_in_and_not_in(self):
     758          for sc5 in IteratingSequenceClass(5), SequenceClass(5):
     759              for i in range(5):
     760                  self.assertIn(i, sc5)
     761              for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
     762                  self.assertNotIn(i, sc5)
     763  
     764          self.assertIn(ALWAYS_EQ, IteratorProxyClass(iter([1])))
     765          self.assertIn(ALWAYS_EQ, SequenceProxyClass([1]))
     766          self.assertNotIn(ALWAYS_EQ, IteratorProxyClass(iter([NEVER_EQ])))
     767          self.assertNotIn(ALWAYS_EQ, SequenceProxyClass([NEVER_EQ]))
     768          self.assertIn(NEVER_EQ, IteratorProxyClass(iter([ALWAYS_EQ])))
     769          self.assertIn(NEVER_EQ, SequenceProxyClass([ALWAYS_EQ]))
     770  
     771          self.assertRaises(TypeError, lambda: 3 in 12)
     772          self.assertRaises(TypeError, lambda: 3 not in map)
     773          self.assertRaises(ZeroDivisionError, lambda: 3 in BadIterableClass())
     774  
     775          d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
     776          for k in d:
     777              self.assertIn(k, d)
     778              self.assertNotIn(k, d.values())
     779          for v in d.values():
     780              self.assertIn(v, d.values())
     781              self.assertNotIn(v, d)
     782          for k, v in d.items():
     783              self.assertIn((k, v), d.items())
     784              self.assertNotIn((v, k), d.items())
     785  
     786          f = open(TESTFN, "w", encoding="utf-8")
     787          try:
     788              f.write("a\n" "b\n" "c\n")
     789          finally:
     790              f.close()
     791          f = open(TESTFN, "r", encoding="utf-8")
     792          try:
     793              for chunk in "abc":
     794                  f.seek(0, 0)
     795                  self.assertNotIn(chunk, f)
     796                  f.seek(0, 0)
     797                  self.assertIn((chunk + "\n"), f)
     798          finally:
     799              f.close()
     800              try:
     801                  unlink(TESTFN)
     802              except OSError:
     803                  pass
     804  
     805      # Test iterators with operator.countOf (PySequence_Count).
     806      def test_countOf(self):
     807          from operator import countOf
     808          self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
     809          self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
     810          self.assertEqual(countOf("122325", "2"), 3)
     811          self.assertEqual(countOf("122325", "6"), 0)
     812  
     813          self.assertRaises(TypeError, countOf, 42, 1)
     814          self.assertRaises(TypeError, countOf, countOf, countOf)
     815  
     816          d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
     817          for k in d:
     818              self.assertEqual(countOf(d, k), 1)
     819          self.assertEqual(countOf(d.values(), 3), 3)
     820          self.assertEqual(countOf(d.values(), 2j), 1)
     821          self.assertEqual(countOf(d.values(), 1j), 0)
     822  
     823          f = open(TESTFN, "w", encoding="utf-8")
     824          try:
     825              f.write("a\n" "b\n" "c\n" "b\n")
     826          finally:
     827              f.close()
     828          f = open(TESTFN, "r", encoding="utf-8")
     829          try:
     830              for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
     831                  f.seek(0, 0)
     832                  self.assertEqual(countOf(f, letter + "\n"), count)
     833          finally:
     834              f.close()
     835              try:
     836                  unlink(TESTFN)
     837              except OSError:
     838                  pass
     839  
     840      # Test iterators with operator.indexOf (PySequence_Index).
     841      def test_indexOf(self):
     842          from operator import indexOf
     843          self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
     844          self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
     845          self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
     846          self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
     847          self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
     848          self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)
     849  
     850          self.assertEqual(indexOf("122325", "2"), 1)
     851          self.assertEqual(indexOf("122325", "5"), 5)
     852          self.assertRaises(ValueError, indexOf, "122325", "6")
     853  
     854          self.assertRaises(TypeError, indexOf, 42, 1)
     855          self.assertRaises(TypeError, indexOf, indexOf, indexOf)
     856          self.assertRaises(ZeroDivisionError, indexOf, BadIterableClass(), 1)
     857  
     858          f = open(TESTFN, "w", encoding="utf-8")
     859          try:
     860              f.write("a\n" "b\n" "c\n" "d\n" "e\n")
     861          finally:
     862              f.close()
     863          f = open(TESTFN, "r", encoding="utf-8")
     864          try:
     865              fiter = iter(f)
     866              self.assertEqual(indexOf(fiter, "b\n"), 1)
     867              self.assertEqual(indexOf(fiter, "d\n"), 1)
     868              self.assertEqual(indexOf(fiter, "e\n"), 0)
     869              self.assertRaises(ValueError, indexOf, fiter, "a\n")
     870          finally:
     871              f.close()
     872              try:
     873                  unlink(TESTFN)
     874              except OSError:
     875                  pass
     876  
     877          iclass = IteratingSequenceClass(3)
     878          for i in range(3):
     879              self.assertEqual(indexOf(iclass, i), i)
     880          self.assertRaises(ValueError, indexOf, iclass, -1)
     881  
     882      # Test iterators with file.writelines().
     883      def test_writelines(self):
     884          f = open(TESTFN, "w", encoding="utf-8")
     885  
     886          try:
     887              self.assertRaises(TypeError, f.writelines, None)
     888              self.assertRaises(TypeError, f.writelines, 42)
     889  
     890              f.writelines(["1\n", "2\n"])
     891              f.writelines(("3\n", "4\n"))
     892              f.writelines({'5\n': None})
     893              f.writelines({})
     894  
     895              # Try a big chunk too.
     896              class ESC[4;38;5;81mIterator:
     897                  def __init__(self, start, finish):
     898                      self.start = start
     899                      self.finish = finish
     900                      self.i = self.start
     901  
     902                  def __next__(self):
     903                      if self.i >= self.finish:
     904                          raise StopIteration
     905                      result = str(self.i) + '\n'
     906                      self.i += 1
     907                      return result
     908  
     909                  def __iter__(self):
     910                      return self
     911  
     912              class ESC[4;38;5;81mWhatever:
     913                  def __init__(self, start, finish):
     914                      self.start = start
     915                      self.finish = finish
     916  
     917                  def __iter__(self):
     918                      return Iterator(self.start, self.finish)
     919  
     920              f.writelines(Whatever(6, 6+2000))
     921              f.close()
     922  
     923              f = open(TESTFN, encoding="utf-8")
     924              expected = [str(i) + "\n" for i in range(1, 2006)]
     925              self.assertEqual(list(f), expected)
     926  
     927          finally:
     928              f.close()
     929              try:
     930                  unlink(TESTFN)
     931              except OSError:
     932                  pass
     933  
     934  
     935      # Test iterators on RHS of unpacking assignments.
     936      def test_unpack_iter(self):
     937          a, b = 1, 2
     938          self.assertEqual((a, b), (1, 2))
     939  
     940          a, b, c = IteratingSequenceClass(3)
     941          self.assertEqual((a, b, c), (0, 1, 2))
     942  
     943          try:    # too many values
     944              a, b = IteratingSequenceClass(3)
     945          except ValueError:
     946              pass
     947          else:
     948              self.fail("should have raised ValueError")
     949  
     950          try:    # not enough values
     951              a, b, c = IteratingSequenceClass(2)
     952          except ValueError:
     953              pass
     954          else:
     955              self.fail("should have raised ValueError")
     956  
     957          try:    # not iterable
     958              a, b, c = len
     959          except TypeError:
     960              pass
     961          else:
     962              self.fail("should have raised TypeError")
     963  
     964          a, b, c = {1: 42, 2: 42, 3: 42}.values()
     965          self.assertEqual((a, b, c), (42, 42, 42))
     966  
     967          f = open(TESTFN, "w", encoding="utf-8")
     968          lines = ("a\n", "bb\n", "ccc\n")
     969          try:
     970              for line in lines:
     971                  f.write(line)
     972          finally:
     973              f.close()
     974          f = open(TESTFN, "r", encoding="utf-8")
     975          try:
     976              a, b, c = f
     977              self.assertEqual((a, b, c), lines)
     978          finally:
     979              f.close()
     980              try:
     981                  unlink(TESTFN)
     982              except OSError:
     983                  pass
     984  
     985          (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
     986          self.assertEqual((a, b, c), (0, 1, 42))
     987  
     988  
     989      @cpython_only
     990      def test_ref_counting_behavior(self):
     991          class ESC[4;38;5;81mC(ESC[4;38;5;149mobject):
     992              count = 0
     993              def __new__(cls):
     994                  cls.count += 1
     995                  return object.__new__(cls)
     996              def __del__(self):
     997                  cls = self.__class__
     998                  assert cls.count > 0
     999                  cls.count -= 1
    1000          x = C()
    1001          self.assertEqual(C.count, 1)
    1002          del x
    1003          self.assertEqual(C.count, 0)
    1004          l = [C(), C(), C()]
    1005          self.assertEqual(C.count, 3)
    1006          try:
    1007              a, b = iter(l)
    1008          except ValueError:
    1009              pass
    1010          del l
    1011          self.assertEqual(C.count, 0)
    1012  
    1013  
    1014      # Make sure StopIteration is a "sink state".
    1015      # This tests various things that weren't sink states in Python 2.2.1,
    1016      # plus various things that always were fine.
    1017  
    1018      def test_sinkstate_list(self):
    1019          # This used to fail
    1020          a = list(range(5))
    1021          b = iter(a)
    1022          self.assertEqual(list(b), list(range(5)))
    1023          a.extend(range(5, 10))
    1024          self.assertEqual(list(b), [])
    1025  
    1026      def test_sinkstate_tuple(self):
    1027          a = (0, 1, 2, 3, 4)
    1028          b = iter(a)
    1029          self.assertEqual(list(b), list(range(5)))
    1030          self.assertEqual(list(b), [])
    1031  
    1032      def test_sinkstate_string(self):
    1033          a = "abcde"
    1034          b = iter(a)
    1035          self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
    1036          self.assertEqual(list(b), [])
    1037  
    def test_sinkstate_sequence(self):
        # Changing ``n`` on the sequence after its iterator is exhausted must
        # not revive the iterator (this used to fail).
        seq = SequenceClass(5)
        it = iter(seq)
        self.assertEqual(list(it), [0, 1, 2, 3, 4])
        seq.n = 10
        self.assertEqual(list(it), [])
    1045  
    1046      def test_sinkstate_callable(self):
    1047          # This used to fail
    1048          def spam(state=[0]):
    1049              i = state[0]
    1050              state[0] = i+1
    1051              if i == 10:
    1052                  raise AssertionError("shouldn't have gotten this far")
    1053              return i
    1054          b = iter(spam, 5)
    1055          self.assertEqual(list(b), list(range(5)))
    1056          self.assertEqual(list(b), [])
    1057  
    1058      def test_sinkstate_dict(self):
    1059          # XXX For a more thorough test, see towards the end of:
    1060          # http://mail.python.org/pipermail/python-dev/2002-July/026512.html
    1061          a = {1:1, 2:2, 0:0, 4:4, 3:3}
    1062          for b in iter(a), a.keys(), a.items(), a.values():
    1063              b = iter(a)
    1064              self.assertEqual(len(list(b)), 5)
    1065              self.assertEqual(list(b), [])
    1066  
    1067      def test_sinkstate_yield(self):
    1068          def gen():
    1069              for i in range(5):
    1070                  yield i
    1071          b = gen()
    1072          self.assertEqual(list(b), list(range(5)))
    1073          self.assertEqual(list(b), [])
    1074  
    1075      def test_sinkstate_range(self):
    1076          a = range(5)
    1077          b = iter(a)
    1078          self.assertEqual(list(b), list(range(5)))
    1079          self.assertEqual(list(b), [])
    1080  
    1081      def test_sinkstate_enumerate(self):
    1082          a = range(5)
    1083          e = enumerate(a)
    1084          b = iter(e)
    1085          self.assertEqual(list(b), list(zip(range(5), range(5))))
    1086          self.assertEqual(list(b), [])
    1087  
    1088      def test_3720(self):
    1089          # Avoid a crash, when an iterator deletes its next() method.
    1090          class ESC[4;38;5;81mBadIterator(ESC[4;38;5;149mobject):
    1091              def __iter__(self):
    1092                  return self
    1093              def __next__(self):
    1094                  del BadIterator.__next__
    1095                  return 1
    1096  
    1097          try:
    1098              for i in BadIterator() :
    1099                  pass
    1100          except TypeError:
    1101              pass
    1102  
    1103      def test_extending_list_with_iterator_does_not_segfault(self):
    1104          # The code to extend a list with an iterator has a fair
    1105          # amount of nontrivial logic in terms of guessing how
    1106          # much memory to allocate in advance, "stealing" refs,
    1107          # and then shrinking at the end.  This is a basic smoke
    1108          # test for that scenario.
    1109          def gen():
    1110              for i in range(500):
    1111                  yield i
    1112          lst = [0] * 500
    1113          for i in range(240):
    1114              lst.pop(0)
    1115          lst.extend(gen())
    1116          self.assertEqual(len(lst), 760)
    1117  
    @cpython_only
    def test_iter_overflow(self):
        # Test for issue 22939: once the index reaches sys.maxsize, next()
        # raises OverflowError, and keeps raising it on later calls.
        it = iter(UnlimitedSequenceClass())
        # Jump `it_index` straight to PY_SSIZE_T_MAX-2 instead of looping.
        it.__setstate__(sys.maxsize - 2)
        for expected in (sys.maxsize - 2, sys.maxsize - 1):
            self.assertEqual(next(it), expected)
        for _ in range(2):
            with self.assertRaises(OverflowError):
                next(it)
    1131  
    def test_iter_neg_setstate(self):
        # A negative index given to __setstate__ makes iteration start at 0.
        it = iter(UnlimitedSequenceClass())
        it.__setstate__(-42)
        for expected in range(2):
            self.assertEqual(next(it), expected)
    1137  
    def test_free_after_iterating(self):
        # Delegate to the shared test.support helper, using iter() over
        # SequenceClass(0).
        ctor_args = (0,)
        check_free_after_iterating(self, iter, SequenceClass, ctor_args)
    1140  
    def test_error_iter(self):
        # Objects of these helper classes are rejected by iter() with TypeError...
        self.assertRaises(TypeError, iter, DefaultIterClass())
        self.assertRaises(TypeError, iter, NoIterClass())
        # ...while BadIterableClass's ZeroDivisionError propagates out of iter().
        self.assertRaises(ZeroDivisionError, iter, BadIterableClass())
    1145  
    1146  
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()