1  """ Test suite for the code in fixer_util """
       2  
       3  # Testing imports
       4  from . import support
       5  
       6  # Local imports
       7  from lib2to3.pytree import Node, Leaf
       8  from lib2to3 import fixer_util
       9  from lib2to3.fixer_util import Attr, Name, Call, Comma
      10  from lib2to3.pgen2 import token
      11  
      12  def parse(code, strip_levels=0):
      13      # The topmost node is file_input, which we don't care about.
      14      # The next-topmost node is a *_stmt node, which we also don't care about
      15      tree = support.parse_string(code)
      16      for i in range(strip_levels):
      17          tree = tree.children[0]
      18      tree.parent = None
      19      return tree
      20  
      21  class ESC[4;38;5;81mMacroTestCase(ESC[4;38;5;149msupportESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      22      def assertStr(self, node, string):
      23          if isinstance(node, (tuple, list)):
      24              node = Node(fixer_util.syms.simple_stmt, node)
      25          self.assertEqual(str(node), string)
      26  
      27  
      28  class ESC[4;38;5;81mTest_is_tuple(ESC[4;38;5;149msupportESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      29      def is_tuple(self, string):
      30          return fixer_util.is_tuple(parse(string, strip_levels=2))
      31  
      32      def test_valid(self):
      33          self.assertTrue(self.is_tuple("(a, b)"))
      34          self.assertTrue(self.is_tuple("(a, (b, c))"))
      35          self.assertTrue(self.is_tuple("((a, (b, c)),)"))
      36          self.assertTrue(self.is_tuple("(a,)"))
      37          self.assertTrue(self.is_tuple("()"))
      38  
      39      def test_invalid(self):
      40          self.assertFalse(self.is_tuple("(a)"))
      41          self.assertFalse(self.is_tuple("('foo') % (b, c)"))
      42  
      43  
      44  class ESC[4;38;5;81mTest_is_list(ESC[4;38;5;149msupportESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      45      def is_list(self, string):
      46          return fixer_util.is_list(parse(string, strip_levels=2))
      47  
      48      def test_valid(self):
      49          self.assertTrue(self.is_list("[]"))
      50          self.assertTrue(self.is_list("[a]"))
      51          self.assertTrue(self.is_list("[a, b]"))
      52          self.assertTrue(self.is_list("[a, [b, c]]"))
      53          self.assertTrue(self.is_list("[[a, [b, c]],]"))
      54  
      55      def test_invalid(self):
      56          self.assertFalse(self.is_list("[]+[]"))
      57  
      58  
      59  class ESC[4;38;5;81mTest_Attr(ESC[4;38;5;149mMacroTestCase):
      60      def test(self):
      61          call = parse("foo()", strip_levels=2)
      62  
      63          self.assertStr(Attr(Name("a"), Name("b")), "a.b")
      64          self.assertStr(Attr(call, Name("b")), "foo().b")
      65  
      66      def test_returns(self):
      67          attr = Attr(Name("a"), Name("b"))
      68          self.assertEqual(type(attr), list)
      69  
      70  
      71  class ESC[4;38;5;81mTest_Name(ESC[4;38;5;149mMacroTestCase):
      72      def test(self):
      73          self.assertStr(Name("a"), "a")
      74          self.assertStr(Name("foo.foo().bar"), "foo.foo().bar")
      75          self.assertStr(Name("a", prefix="b"), "ba")
      76  
      77  
      78  class ESC[4;38;5;81mTest_Call(ESC[4;38;5;149mMacroTestCase):
      79      def _Call(self, name, args=None, prefix=None):
      80          """Help the next test"""
      81          children = []
      82          if isinstance(args, list):
      83              for arg in args:
      84                  children.append(arg)
      85                  children.append(Comma())
      86              children.pop()
      87          return Call(Name(name), children, prefix)
      88  
      89      def test(self):
      90          kids = [None,
      91                  [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 2),
      92                   Leaf(token.NUMBER, 3)],
      93                  [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 3),
      94                   Leaf(token.NUMBER, 2), Leaf(token.NUMBER, 4)],
      95                  [Leaf(token.STRING, "b"), Leaf(token.STRING, "j", prefix=" ")]
      96                  ]
      97          self.assertStr(self._Call("A"), "A()")
      98          self.assertStr(self._Call("b", kids[1]), "b(1,2,3)")
      99          self.assertStr(self._Call("a.b().c", kids[2]), "a.b().c(1,3,2,4)")
     100          self.assertStr(self._Call("d", kids[3], prefix=" "), " d(b, j)")
     101  
     102  
     103  class ESC[4;38;5;81mTest_does_tree_import(ESC[4;38;5;149msupportESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     104      def _find_bind_rec(self, name, node):
     105          # Search a tree for a binding -- used to find the starting
     106          # point for these tests.
     107          c = fixer_util.find_binding(name, node)
     108          if c: return c
     109          for child in node.children:
     110              c = self._find_bind_rec(name, child)
     111              if c: return c
     112  
     113      def does_tree_import(self, package, name, string):
     114          node = parse(string)
     115          # Find the binding of start -- that's what we'll go from
     116          node = self._find_bind_rec('start', node)
     117          return fixer_util.does_tree_import(package, name, node)
     118  
     119      def try_with(self, string):
     120          failing_tests = (("a", "a", "from a import b"),
     121                           ("a.d", "a", "from a.d import b"),
     122                           ("d.a", "a", "from d.a import b"),
     123                           (None, "a", "import b"),
     124                           (None, "a", "import b, c, d"))
     125          for package, name, import_ in failing_tests:
     126              n = self.does_tree_import(package, name, import_ + "\n" + string)
     127              self.assertFalse(n)
     128              n = self.does_tree_import(package, name, string + "\n" + import_)
     129              self.assertFalse(n)
     130  
     131          passing_tests = (("a", "a", "from a import a"),
     132                           ("x", "a", "from x import a"),
     133                           ("x", "a", "from x import b, c, a, d"),
     134                           ("x.b", "a", "from x.b import a"),
     135                           ("x.b", "a", "from x.b import b, c, a, d"),
     136                           (None, "a", "import a"),
     137                           (None, "a", "import b, c, a, d"))
     138          for package, name, import_ in passing_tests:
     139              n = self.does_tree_import(package, name, import_ + "\n" + string)
     140              self.assertTrue(n)
     141              n = self.does_tree_import(package, name, string + "\n" + import_)
     142              self.assertTrue(n)
     143  
     144      def test_in_function(self):
     145          self.try_with("def foo():\n\tbar.baz()\n\tstart=3")
     146  
     147  class ESC[4;38;5;81mTest_find_binding(ESC[4;38;5;149msupportESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     148      def find_binding(self, name, string, package=None):
     149          return fixer_util.find_binding(name, parse(string), package)
     150  
     151      def test_simple_assignment(self):
     152          self.assertTrue(self.find_binding("a", "a = b"))
     153          self.assertTrue(self.find_binding("a", "a = [b, c, d]"))
     154          self.assertTrue(self.find_binding("a", "a = foo()"))
     155          self.assertTrue(self.find_binding("a", "a = foo().foo.foo[6][foo]"))
     156          self.assertFalse(self.find_binding("a", "foo = a"))
     157          self.assertFalse(self.find_binding("a", "foo = (a, b, c)"))
     158  
     159      def test_tuple_assignment(self):
     160          self.assertTrue(self.find_binding("a", "(a,) = b"))
     161          self.assertTrue(self.find_binding("a", "(a, b, c) = [b, c, d]"))
     162          self.assertTrue(self.find_binding("a", "(c, (d, a), b) = foo()"))
     163          self.assertTrue(self.find_binding("a", "(a, b) = foo().foo[6][foo]"))
     164          self.assertFalse(self.find_binding("a", "(foo, b) = (b, a)"))
     165          self.assertFalse(self.find_binding("a", "(foo, (b, c)) = (a, b, c)"))
     166  
     167      def test_list_assignment(self):
     168          self.assertTrue(self.find_binding("a", "[a] = b"))
     169          self.assertTrue(self.find_binding("a", "[a, b, c] = [b, c, d]"))
     170          self.assertTrue(self.find_binding("a", "[c, [d, a], b] = foo()"))
     171          self.assertTrue(self.find_binding("a", "[a, b] = foo().foo[a][foo]"))
     172          self.assertFalse(self.find_binding("a", "[foo, b] = (b, a)"))
     173          self.assertFalse(self.find_binding("a", "[foo, [b, c]] = (a, b, c)"))
     174  
     175      def test_invalid_assignments(self):
     176          self.assertFalse(self.find_binding("a", "foo.a = 5"))
     177          self.assertFalse(self.find_binding("a", "foo[a] = 5"))
     178          self.assertFalse(self.find_binding("a", "foo(a) = 5"))
     179          self.assertFalse(self.find_binding("a", "foo(a, b) = 5"))
     180  
     181      def test_simple_import(self):
     182          self.assertTrue(self.find_binding("a", "import a"))
     183          self.assertTrue(self.find_binding("a", "import b, c, a, d"))
     184          self.assertFalse(self.find_binding("a", "import b"))
     185          self.assertFalse(self.find_binding("a", "import b, c, d"))
     186  
     187      def test_from_import(self):
     188          self.assertTrue(self.find_binding("a", "from x import a"))
     189          self.assertTrue(self.find_binding("a", "from a import a"))
     190          self.assertTrue(self.find_binding("a", "from x import b, c, a, d"))
     191          self.assertTrue(self.find_binding("a", "from x.b import a"))
     192          self.assertTrue(self.find_binding("a", "from x.b import b, c, a, d"))
     193          self.assertFalse(self.find_binding("a", "from a import b"))
     194          self.assertFalse(self.find_binding("a", "from a.d import b"))
     195          self.assertFalse(self.find_binding("a", "from d.a import b"))
     196  
     197      def test_import_as(self):
     198          self.assertTrue(self.find_binding("a", "import b as a"))
     199          self.assertTrue(self.find_binding("a", "import b as a, c, a as f, d"))
     200          self.assertFalse(self.find_binding("a", "import a as f"))
     201          self.assertFalse(self.find_binding("a", "import b, c as f, d as e"))
     202  
     203      def test_from_import_as(self):
     204          self.assertTrue(self.find_binding("a", "from x import b as a"))
     205          self.assertTrue(self.find_binding("a", "from x import g as a, d as b"))
     206          self.assertTrue(self.find_binding("a", "from x.b import t as a"))
     207          self.assertTrue(self.find_binding("a", "from x.b import g as a, d"))
     208          self.assertFalse(self.find_binding("a", "from a import b as t"))
     209          self.assertFalse(self.find_binding("a", "from a.d import b as t"))
     210          self.assertFalse(self.find_binding("a", "from d.a import b as t"))
     211  
     212      def test_simple_import_with_package(self):
     213          self.assertTrue(self.find_binding("b", "import b"))
     214          self.assertTrue(self.find_binding("b", "import b, c, d"))
     215          self.assertFalse(self.find_binding("b", "import b", "b"))
     216          self.assertFalse(self.find_binding("b", "import b, c, d", "c"))
     217  
     218      def test_from_import_with_package(self):
     219          self.assertTrue(self.find_binding("a", "from x import a", "x"))
     220          self.assertTrue(self.find_binding("a", "from a import a", "a"))
     221          self.assertTrue(self.find_binding("a", "from x import *", "x"))
     222          self.assertTrue(self.find_binding("a", "from x import b, c, a, d", "x"))
     223          self.assertTrue(self.find_binding("a", "from x.b import a", "x.b"))
     224          self.assertTrue(self.find_binding("a", "from x.b import *", "x.b"))
     225          self.assertTrue(self.find_binding("a", "from x.b import b, c, a, d", "x.b"))
     226          self.assertFalse(self.find_binding("a", "from a import b", "a"))
     227          self.assertFalse(self.find_binding("a", "from a.d import b", "a.d"))
     228          self.assertFalse(self.find_binding("a", "from d.a import b", "a.d"))
     229          self.assertFalse(self.find_binding("a", "from x.y import *", "a.b"))
     230  
     231      def test_import_as_with_package(self):
     232          self.assertFalse(self.find_binding("a", "import b.c as a", "b.c"))
     233          self.assertFalse(self.find_binding("a", "import a as f", "f"))
     234          self.assertFalse(self.find_binding("a", "import a as f", "a"))
     235  
     236      def test_from_import_as_with_package(self):
     237          # Because it would take a lot of special-case code in the fixers
     238          # to deal with from foo import bar as baz, we'll simply always
     239          # fail if there is an "from ... import ... as ..."
     240          self.assertFalse(self.find_binding("a", "from x import b as a", "x"))
     241          self.assertFalse(self.find_binding("a", "from x import g as a, d as b", "x"))
     242          self.assertFalse(self.find_binding("a", "from x.b import t as a", "x.b"))
     243          self.assertFalse(self.find_binding("a", "from x.b import g as a, d", "x.b"))
     244          self.assertFalse(self.find_binding("a", "from a import b as t", "a"))
     245          self.assertFalse(self.find_binding("a", "from a import b as t", "b"))
     246          self.assertFalse(self.find_binding("a", "from a import b as t", "t"))
     247  
     248      def test_function_def(self):
     249          self.assertTrue(self.find_binding("a", "def a(): pass"))
     250          self.assertTrue(self.find_binding("a", "def a(b, c, d): pass"))
     251          self.assertTrue(self.find_binding("a", "def a(): b = 7"))
     252          self.assertFalse(self.find_binding("a", "def d(b, (c, a), e): pass"))
     253          self.assertFalse(self.find_binding("a", "def d(a=7): pass"))
     254          self.assertFalse(self.find_binding("a", "def d(a): pass"))
     255          self.assertFalse(self.find_binding("a", "def d(): a = 7"))
     256  
     257          s = """
     258              def d():
     259                  def a():
     260                      pass"""
     261          self.assertFalse(self.find_binding("a", s))
     262  
     263      def test_class_def(self):
     264          self.assertTrue(self.find_binding("a", "class a: pass"))
     265          self.assertTrue(self.find_binding("a", "class a(): pass"))
     266          self.assertTrue(self.find_binding("a", "class a(b): pass"))
     267          self.assertTrue(self.find_binding("a", "class a(b, c=8): pass"))
     268          self.assertFalse(self.find_binding("a", "class d: pass"))
     269          self.assertFalse(self.find_binding("a", "class d(a): pass"))
     270          self.assertFalse(self.find_binding("a", "class d(b, a=7): pass"))
     271          self.assertFalse(self.find_binding("a", "class d(b, *a): pass"))
     272          self.assertFalse(self.find_binding("a", "class d(b, **a): pass"))
     273          self.assertFalse(self.find_binding("a", "class d: a = 7"))
     274  
     275          s = """
     276              class d():
     277                  class a():
     278                      pass"""
     279          self.assertFalse(self.find_binding("a", s))
     280  
     281      def test_for(self):
     282          self.assertTrue(self.find_binding("a", "for a in r: pass"))
     283          self.assertTrue(self.find_binding("a", "for a, b in r: pass"))
     284          self.assertTrue(self.find_binding("a", "for (a, b) in r: pass"))
     285          self.assertTrue(self.find_binding("a", "for c, (a,) in r: pass"))
     286          self.assertTrue(self.find_binding("a", "for c, (a, b) in r: pass"))
     287          self.assertTrue(self.find_binding("a", "for c in r: a = c"))
     288          self.assertFalse(self.find_binding("a", "for c in a: pass"))
     289  
     290      def test_for_nested(self):
     291          s = """
     292              for b in r:
     293                  for a in b:
     294                      pass"""
     295          self.assertTrue(self.find_binding("a", s))
     296  
     297          s = """
     298              for b in r:
     299                  for a, c in b:
     300                      pass"""
     301          self.assertTrue(self.find_binding("a", s))
     302  
     303          s = """
     304              for b in r:
     305                  for (a, c) in b:
     306                      pass"""
     307          self.assertTrue(self.find_binding("a", s))
     308  
     309          s = """
     310              for b in r:
     311                  for (a,) in b:
     312                      pass"""
     313          self.assertTrue(self.find_binding("a", s))
     314  
     315          s = """
     316              for b in r:
     317                  for c, (a, d) in b:
     318                      pass"""
     319          self.assertTrue(self.find_binding("a", s))
     320  
     321          s = """
     322              for b in r:
     323                  for c in b:
     324                      a = 7"""
     325          self.assertTrue(self.find_binding("a", s))
     326  
     327          s = """
     328              for b in r:
     329                  for c in b:
     330                      d = a"""
     331          self.assertFalse(self.find_binding("a", s))
     332  
     333          s = """
     334              for b in r:
     335                  for c in a:
     336                      d = 7"""
     337          self.assertFalse(self.find_binding("a", s))
     338  
     339      def test_if(self):
     340          self.assertTrue(self.find_binding("a", "if b in r: a = c"))
     341          self.assertFalse(self.find_binding("a", "if a in r: d = e"))
     342  
     343      def test_if_nested(self):
     344          s = """
     345              if b in r:
     346                  if c in d:
     347                      a = c"""
     348          self.assertTrue(self.find_binding("a", s))
     349  
     350          s = """
     351              if b in r:
     352                  if c in d:
     353                      c = a"""
     354          self.assertFalse(self.find_binding("a", s))
     355  
     356      def test_while(self):
     357          self.assertTrue(self.find_binding("a", "while b in r: a = c"))
     358          self.assertFalse(self.find_binding("a", "while a in r: d = e"))
     359  
     360      def test_while_nested(self):
     361          s = """
     362              while b in r:
     363                  while c in d:
     364                      a = c"""
     365          self.assertTrue(self.find_binding("a", s))
     366  
     367          s = """
     368              while b in r:
     369                  while c in d:
     370                      c = a"""
     371          self.assertFalse(self.find_binding("a", s))
     372  
     373      def test_try_except(self):
     374          s = """
     375              try:
     376                  a = 6
     377              except:
     378                  b = 8"""
     379          self.assertTrue(self.find_binding("a", s))
     380  
     381          s = """
     382              try:
     383                  b = 8
     384              except:
     385                  a = 6"""
     386          self.assertTrue(self.find_binding("a", s))
     387  
     388          s = """
     389              try:
     390                  b = 8
     391              except KeyError:
     392                  pass
     393              except:
     394                  a = 6"""
     395          self.assertTrue(self.find_binding("a", s))
     396  
     397          s = """
     398              try:
     399                  b = 8
     400              except:
     401                  b = 6"""
     402          self.assertFalse(self.find_binding("a", s))
     403  
     404      def test_try_except_nested(self):
     405          s = """
     406              try:
     407                  try:
     408                      a = 6
     409                  except:
     410                      pass
     411              except:
     412                  b = 8"""
     413          self.assertTrue(self.find_binding("a", s))
     414  
     415          s = """
     416              try:
     417                  b = 8
     418              except:
     419                  try:
     420                      a = 6
     421                  except:
     422                      pass"""
     423          self.assertTrue(self.find_binding("a", s))
     424  
     425          s = """
     426              try:
     427                  b = 8
     428              except:
     429                  try:
     430                      pass
     431                  except:
     432                      a = 6"""
     433          self.assertTrue(self.find_binding("a", s))
     434  
     435          s = """
     436              try:
     437                  try:
     438                      b = 8
     439                  except KeyError:
     440                      pass
     441                  except:
     442                      a = 6
     443              except:
     444                  pass"""
     445          self.assertTrue(self.find_binding("a", s))
     446  
     447          s = """
     448              try:
     449                  pass
     450              except:
     451                  try:
     452                      b = 8
     453                  except KeyError:
     454                      pass
     455                  except:
     456                      a = 6"""
     457          self.assertTrue(self.find_binding("a", s))
     458  
     459          s = """
     460              try:
     461                  b = 8
     462              except:
     463                  b = 6"""
     464          self.assertFalse(self.find_binding("a", s))
     465  
     466          s = """
     467              try:
     468                  try:
     469                      b = 8
     470                  except:
     471                      c = d
     472              except:
     473                  try:
     474                      b = 6
     475                  except:
     476                      t = 8
     477                  except:
     478                      o = y"""
     479          self.assertFalse(self.find_binding("a", s))
     480  
     481      def test_try_except_finally(self):
     482          s = """
     483              try:
     484                  c = 6
     485              except:
     486                  b = 8
     487              finally:
     488                  a = 9"""
     489          self.assertTrue(self.find_binding("a", s))
     490  
     491          s = """
     492              try:
     493                  b = 8
     494              finally:
     495                  a = 6"""
     496          self.assertTrue(self.find_binding("a", s))
     497  
     498          s = """
     499              try:
     500                  b = 8
     501              finally:
     502                  b = 6"""
     503          self.assertFalse(self.find_binding("a", s))
     504  
     505          s = """
     506              try:
     507                  b = 8
     508              except:
     509                  b = 9
     510              finally:
     511                  b = 6"""
     512          self.assertFalse(self.find_binding("a", s))
     513  
     514      def test_try_except_finally_nested(self):
     515          s = """
     516              try:
     517                  c = 6
     518              except:
     519                  b = 8
     520              finally:
     521                  try:
     522                      a = 9
     523                  except:
     524                      b = 9
     525                  finally:
     526                      c = 9"""
     527          self.assertTrue(self.find_binding("a", s))
     528  
     529          s = """
     530              try:
     531                  b = 8
     532              finally:
     533                  try:
     534                      pass
     535                  finally:
     536                      a = 6"""
     537          self.assertTrue(self.find_binding("a", s))
     538  
     539          s = """
     540              try:
     541                  b = 8
     542              finally:
     543                  try:
     544                      b = 6
     545                  finally:
     546                      b = 7"""
     547          self.assertFalse(self.find_binding("a", s))
     548  
     549  class ESC[4;38;5;81mTest_touch_import(ESC[4;38;5;149msupportESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     550  
     551      def test_after_docstring(self):
     552          node = parse('"""foo"""\nbar()')
     553          fixer_util.touch_import(None, "foo", node)
     554          self.assertEqual(str(node), '"""foo"""\nimport foo\nbar()\n\n')
     555  
     556      def test_after_imports(self):
     557          node = parse('"""foo"""\nimport bar\nbar()')
     558          fixer_util.touch_import(None, "foo", node)
     559          self.assertEqual(str(node), '"""foo"""\nimport bar\nimport foo\nbar()\n\n')
     560  
     561      def test_beginning(self):
     562          node = parse('bar()')
     563          fixer_util.touch_import(None, "foo", node)
     564          self.assertEqual(str(node), 'import foo\nbar()\n\n')
     565  
     566      def test_from_import(self):
     567          node = parse('bar()')
     568          fixer_util.touch_import("html", "escape", node)
     569          self.assertEqual(str(node), 'from html import escape\nbar()\n\n')
     570  
     571      def test_name_import(self):
     572          node = parse('bar()')
     573          fixer_util.touch_import(None, "cgi", node)
     574          self.assertEqual(str(node), 'import cgi\nbar()\n\n')
     575  
     576  class ESC[4;38;5;81mTest_find_indentation(ESC[4;38;5;149msupportESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     577  
     578      def test_nothing(self):
     579          fi = fixer_util.find_indentation
     580          node = parse("node()")
     581          self.assertEqual(fi(node), "")
     582          node = parse("")
     583          self.assertEqual(fi(node), "")
     584  
     585      def test_simple(self):
     586          fi = fixer_util.find_indentation
     587          node = parse("def f():\n    x()")
     588          self.assertEqual(fi(node), "")
     589          self.assertEqual(fi(node.children[0].children[4].children[2]), "    ")
     590          node = parse("def f():\n    x()\n    y()")
     591          self.assertEqual(fi(node.children[0].children[4].children[4]), "    ")