"""
Unit tests for lib2to3's refactor.py.
"""
       4  
import sys
import os
import codecs
import io
import re
import tempfile
import shutil
import unittest
from unittest import mock

from lib2to3 import refactor, pygram, fixer_base
from lib2to3.pgen2 import token
      16  
      17  
      18  TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
      19  FIXER_DIR = os.path.join(TEST_DATA_DIR, "fixers")
      20  
      21  sys.path.append(FIXER_DIR)
      22  try:
      23      _DEFAULT_FIXERS = refactor.get_fixers_from_package("myfixes")
      24  finally:
      25      sys.path.pop()
      26  
      27  _2TO3_FIXERS = refactor.get_fixers_from_package("lib2to3.fixes")
      28  
      29  class ESC[4;38;5;81mTestRefactoringTool(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      30  
      31      def setUp(self):
      32          sys.path.append(FIXER_DIR)
      33  
      34      def tearDown(self):
      35          sys.path.pop()
      36  
      37      def check_instances(self, instances, classes):
      38          for inst, cls in zip(instances, classes):
      39              if not isinstance(inst, cls):
      40                  self.fail("%s are not instances of %s" % instances, classes)
      41  
      42      def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None):
      43          return refactor.RefactoringTool(fixers, options, explicit)
      44  
      45      def test_print_function_option(self):
      46          rt = self.rt({"print_function" : True})
      47          self.assertNotIn("print", rt.grammar.keywords)
      48          self.assertNotIn("print", rt.driver.grammar.keywords)
      49  
      50      def test_exec_function_option(self):
      51          rt = self.rt({"exec_function" : True})
      52          self.assertNotIn("exec", rt.grammar.keywords)
      53          self.assertNotIn("exec", rt.driver.grammar.keywords)
      54  
      55      def test_write_unchanged_files_option(self):
      56          rt = self.rt()
      57          self.assertFalse(rt.write_unchanged_files)
      58          rt = self.rt({"write_unchanged_files" : True})
      59          self.assertTrue(rt.write_unchanged_files)
      60  
      61      def test_fixer_loading_helpers(self):
      62          contents = ["explicit", "first", "last", "parrot", "preorder"]
      63          non_prefixed = refactor.get_all_fix_names("myfixes")
      64          prefixed = refactor.get_all_fix_names("myfixes", False)
      65          full_names = refactor.get_fixers_from_package("myfixes")
      66          self.assertEqual(prefixed, ["fix_" + name for name in contents])
      67          self.assertEqual(non_prefixed, contents)
      68          self.assertEqual(full_names,
      69                           ["myfixes.fix_" + name for name in contents])
      70  
      71      def test_detect_future_features(self):
      72          run = refactor._detect_future_features
      73          fs = frozenset
      74          empty = fs()
      75          self.assertEqual(run(""), empty)
      76          self.assertEqual(run("from __future__ import print_function"),
      77                           fs(("print_function",)))
      78          self.assertEqual(run("from __future__ import generators"),
      79                           fs(("generators",)))
      80          self.assertEqual(run("from __future__ import generators, feature"),
      81                           fs(("generators", "feature")))
      82          inp = "from __future__ import generators, print_function"
      83          self.assertEqual(run(inp), fs(("generators", "print_function")))
      84          inp ="from __future__ import print_function, generators"
      85          self.assertEqual(run(inp), fs(("print_function", "generators")))
      86          inp = "from __future__ import (print_function,)"
      87          self.assertEqual(run(inp), fs(("print_function",)))
      88          inp = "from __future__ import (generators, print_function)"
      89          self.assertEqual(run(inp), fs(("generators", "print_function")))
      90          inp = "from __future__ import (generators, nested_scopes)"
      91          self.assertEqual(run(inp), fs(("generators", "nested_scopes")))
      92          inp = """from __future__ import generators
      93  from __future__ import print_function"""
      94          self.assertEqual(run(inp), fs(("generators", "print_function")))
      95          invalid = ("from",
      96                     "from 4",
      97                     "from x",
      98                     "from x 5",
      99                     "from x im",
     100                     "from x import",
     101                     "from x import 4",
     102                     )
     103          for inp in invalid:
     104              self.assertEqual(run(inp), empty)
     105          inp = "'docstring'\nfrom __future__ import print_function"
     106          self.assertEqual(run(inp), fs(("print_function",)))
     107          inp = "'docstring'\n'somng'\nfrom __future__ import print_function"
     108          self.assertEqual(run(inp), empty)
     109          inp = "# comment\nfrom __future__ import print_function"
     110          self.assertEqual(run(inp), fs(("print_function",)))
     111          inp = "# comment\n'doc'\nfrom __future__ import print_function"
     112          self.assertEqual(run(inp), fs(("print_function",)))
     113          inp = "class x: pass\nfrom __future__ import print_function"
     114          self.assertEqual(run(inp), empty)
     115  
     116      def test_get_headnode_dict(self):
     117          class ESC[4;38;5;81mNoneFix(ESC[4;38;5;149mfixer_baseESC[4;38;5;149m.ESC[4;38;5;149mBaseFix):
     118              pass
     119  
     120          class ESC[4;38;5;81mFileInputFix(ESC[4;38;5;149mfixer_baseESC[4;38;5;149m.ESC[4;38;5;149mBaseFix):
     121              PATTERN = "file_input< any * >"
     122  
     123          class ESC[4;38;5;81mSimpleFix(ESC[4;38;5;149mfixer_baseESC[4;38;5;149m.ESC[4;38;5;149mBaseFix):
     124              PATTERN = "'name'"
     125  
     126          no_head = NoneFix({}, [])
     127          with_head = FileInputFix({}, [])
     128          simple = SimpleFix({}, [])
     129          d = refactor._get_headnode_dict([no_head, with_head, simple])
     130          top_fixes = d.pop(pygram.python_symbols.file_input)
     131          self.assertEqual(top_fixes, [with_head, no_head])
     132          name_fixes = d.pop(token.NAME)
     133          self.assertEqual(name_fixes, [simple, no_head])
     134          for fixes in d.values():
     135              self.assertEqual(fixes, [no_head])
     136  
     137      def test_fixer_loading(self):
     138          from myfixes.fix_first import FixFirst
     139          from myfixes.fix_last import FixLast
     140          from myfixes.fix_parrot import FixParrot
     141          from myfixes.fix_preorder import FixPreorder
     142  
     143          rt = self.rt()
     144          pre, post = rt.get_fixers()
     145  
     146          self.check_instances(pre, [FixPreorder])
     147          self.check_instances(post, [FixFirst, FixParrot, FixLast])
     148  
     149      def test_naughty_fixers(self):
     150          self.assertRaises(ImportError, self.rt, fixers=["not_here"])
     151          self.assertRaises(refactor.FixerError, self.rt, fixers=["no_fixer_cls"])
     152          self.assertRaises(refactor.FixerError, self.rt, fixers=["bad_order"])
     153  
     154      def test_refactor_string(self):
     155          rt = self.rt()
     156          input = "def parrot(): pass\n\n"
     157          tree = rt.refactor_string(input, "<test>")
     158          self.assertNotEqual(str(tree), input)
     159  
     160          input = "def f(): pass\n\n"
     161          tree = rt.refactor_string(input, "<test>")
     162          self.assertEqual(str(tree), input)
     163  
     164      def test_refactor_stdin(self):
     165  
     166          class ESC[4;38;5;81mMyRT(ESC[4;38;5;149mrefactorESC[4;38;5;149m.ESC[4;38;5;149mRefactoringTool):
     167  
     168              def print_output(self, old_text, new_text, filename, equal):
     169                  results.extend([old_text, new_text, filename, equal])
     170  
     171          results = []
     172          rt = MyRT(_DEFAULT_FIXERS)
     173          save = sys.stdin
     174          sys.stdin = io.StringIO("def parrot(): pass\n\n")
     175          try:
     176              rt.refactor_stdin()
     177          finally:
     178              sys.stdin = save
     179          expected = ["def parrot(): pass\n\n",
     180                      "def cheese(): pass\n\n",
     181                      "<stdin>", False]
     182          self.assertEqual(results, expected)
     183  
     184      def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS,
     185                                 options=None, mock_log_debug=None,
     186                                 actually_write=True):
     187          test_file = self.init_test_file(test_file)
     188          old_contents = self.read_file(test_file)
     189          rt = self.rt(fixers=fixers, options=options)
     190          if mock_log_debug:
     191              rt.log_debug = mock_log_debug
     192  
     193          rt.refactor_file(test_file)
     194          self.assertEqual(old_contents, self.read_file(test_file))
     195  
     196          if not actually_write:
     197              return
     198          rt.refactor_file(test_file, True)
     199          new_contents = self.read_file(test_file)
     200          self.assertNotEqual(old_contents, new_contents)
     201          return new_contents
     202  
     203      def init_test_file(self, test_file):
     204          tmpdir = tempfile.mkdtemp(prefix="2to3-test_refactor")
     205          self.addCleanup(shutil.rmtree, tmpdir)
     206          shutil.copy(test_file, tmpdir)
     207          test_file = os.path.join(tmpdir, os.path.basename(test_file))
     208          os.chmod(test_file, 0o644)
     209          return test_file
     210  
     211      def read_file(self, test_file):
     212          with open(test_file, "rb") as fp:
     213              return fp.read()
     214  
     215      def refactor_file(self, test_file, fixers=_2TO3_FIXERS):
     216          test_file = self.init_test_file(test_file)
     217          old_contents = self.read_file(test_file)
     218          rt = self.rt(fixers=fixers)
     219          rt.refactor_file(test_file, True)
     220          new_contents = self.read_file(test_file)
     221          return old_contents, new_contents
     222  
     223      def test_refactor_file(self):
     224          test_file = os.path.join(FIXER_DIR, "parrot_example.py")
     225          self.check_file_refactoring(test_file, _DEFAULT_FIXERS)
     226  
     227      def test_refactor_file_write_unchanged_file(self):
     228          test_file = os.path.join(FIXER_DIR, "parrot_example.py")
     229          debug_messages = []
     230          def recording_log_debug(msg, *args):
     231              debug_messages.append(msg % args)
     232          self.check_file_refactoring(test_file, fixers=(),
     233                                      options={"write_unchanged_files": True},
     234                                      mock_log_debug=recording_log_debug,
     235                                      actually_write=False)
     236          # Testing that it logged this message when write=False was passed is
     237          # sufficient to see that it did not bail early after "No changes".
     238          message_regex = r"Not writing changes to .*%s" % \
     239                  re.escape(os.sep + os.path.basename(test_file))
     240          for message in debug_messages:
     241              if "Not writing changes" in message:
     242                  self.assertRegex(message, message_regex)
     243                  break
     244          else:
     245              self.fail("%r not matched in %r" % (message_regex, debug_messages))
     246  
     247      def test_refactor_dir(self):
     248          def check(structure, expected):
     249              def mock_refactor_file(self, f, *args):
     250                  got.append(f)
     251              save_func = refactor.RefactoringTool.refactor_file
     252              refactor.RefactoringTool.refactor_file = mock_refactor_file
     253              rt = self.rt()
     254              got = []
     255              dir = tempfile.mkdtemp(prefix="2to3-test_refactor")
     256              try:
     257                  os.mkdir(os.path.join(dir, "a_dir"))
     258                  for fn in structure:
     259                      open(os.path.join(dir, fn), "wb").close()
     260                  rt.refactor_dir(dir)
     261              finally:
     262                  refactor.RefactoringTool.refactor_file = save_func
     263                  shutil.rmtree(dir)
     264              self.assertEqual(got,
     265                               [os.path.join(dir, path) for path in expected])
     266          check([], [])
     267          tree = ["nothing",
     268                  "hi.py",
     269                  ".dumb",
     270                  ".after.py",
     271                  "notpy.npy",
     272                  "sappy"]
     273          expected = ["hi.py"]
     274          check(tree, expected)
     275          tree = ["hi.py",
     276                  os.path.join("a_dir", "stuff.py")]
     277          check(tree, tree)
     278  
     279      def test_file_encoding(self):
     280          fn = os.path.join(TEST_DATA_DIR, "different_encoding.py")
     281          self.check_file_refactoring(fn)
     282  
     283      def test_false_file_encoding(self):
     284          fn = os.path.join(TEST_DATA_DIR, "false_encoding.py")
     285          data = self.check_file_refactoring(fn)
     286  
     287      def test_bom(self):
     288          fn = os.path.join(TEST_DATA_DIR, "bom.py")
     289          data = self.check_file_refactoring(fn)
     290          self.assertTrue(data.startswith(codecs.BOM_UTF8))
     291  
     292      def test_crlf_newlines(self):
     293          old_sep = os.linesep
     294          os.linesep = "\r\n"
     295          try:
     296              fn = os.path.join(TEST_DATA_DIR, "crlf.py")
     297              fixes = refactor.get_fixers_from_package("lib2to3.fixes")
     298              self.check_file_refactoring(fn, fixes)
     299          finally:
     300              os.linesep = old_sep
     301  
     302      def test_crlf_unchanged(self):
     303          fn = os.path.join(TEST_DATA_DIR, "crlf.py")
     304          old, new = self.refactor_file(fn)
     305          self.assertIn(b"\r\n", old)
     306          self.assertIn(b"\r\n", new)
     307          self.assertNotIn(b"\r\r\n", new)
     308  
     309      def test_refactor_docstring(self):
     310          rt = self.rt()
     311  
     312          doc = """
     313  >>> example()
     314  42
     315  """
     316          out = rt.refactor_docstring(doc, "<test>")
     317          self.assertEqual(out, doc)
     318  
     319          doc = """
     320  >>> def parrot():
     321  ...      return 43
     322  """
     323          out = rt.refactor_docstring(doc, "<test>")
     324          self.assertNotEqual(out, doc)
     325  
     326      def test_explicit(self):
     327          from myfixes.fix_explicit import FixExplicit
     328  
     329          rt = self.rt(fixers=["myfixes.fix_explicit"])
     330          self.assertEqual(len(rt.post_order), 0)
     331  
     332          rt = self.rt(explicit=["myfixes.fix_explicit"])
     333          for fix in rt.post_order:
     334              if isinstance(fix, FixExplicit):
     335                  break
     336          else:
     337              self.fail("explicit fixer not loaded")