(root)/
Python-3.12.0/
Lib/
lib2to3/
fixer_util.py
       1  """Utility functions, node construction macros, etc."""
       2  # Author: Collin Winter
       3  
       4  # Local imports
       5  from .pgen2 import token
       6  from .pytree import Leaf, Node
       7  from .pygram import python_symbols as syms
       8  from . import patcomp
       9  
      10  
      11  ###########################################################
      12  ### Common node-construction "macros"
      13  ###########################################################
      14  
      15  def KeywordArg(keyword, value):
      16      return Node(syms.argument,
      17                  [keyword, Leaf(token.EQUAL, "="), value])
      18  
      19  def LParen():
      20      return Leaf(token.LPAR, "(")
      21  
      22  def RParen():
      23      return Leaf(token.RPAR, ")")
      24  
      25  def Assign(target, source):
      26      """Build an assignment statement"""
      27      if not isinstance(target, list):
      28          target = [target]
      29      if not isinstance(source, list):
      30          source.prefix = " "
      31          source = [source]
      32  
      33      return Node(syms.atom,
      34                  target + [Leaf(token.EQUAL, "=", prefix=" ")] + source)
      35  
      36  def Name(name, prefix=None):
      37      """Return a NAME leaf"""
      38      return Leaf(token.NAME, name, prefix=prefix)
      39  
      40  def Attr(obj, attr):
      41      """A node tuple for obj.attr"""
      42      return [obj, Node(syms.trailer, [Dot(), attr])]
      43  
      44  def Comma():
      45      """A comma leaf"""
      46      return Leaf(token.COMMA, ",")
      47  
      48  def Dot():
      49      """A period (.) leaf"""
      50      return Leaf(token.DOT, ".")
      51  
      52  def ArgList(args, lparen=LParen(), rparen=RParen()):
      53      """A parenthesised argument list, used by Call()"""
      54      node = Node(syms.trailer, [lparen.clone(), rparen.clone()])
      55      if args:
      56          node.insert_child(1, Node(syms.arglist, args))
      57      return node
      58  
      59  def Call(func_name, args=None, prefix=None):
      60      """A function call"""
      61      node = Node(syms.power, [func_name, ArgList(args)])
      62      if prefix is not None:
      63          node.prefix = prefix
      64      return node
      65  
      66  def Newline():
      67      """A newline literal"""
      68      return Leaf(token.NEWLINE, "\n")
      69  
      70  def BlankLine():
      71      """A blank line"""
      72      return Leaf(token.NEWLINE, "")
      73  
      74  def Number(n, prefix=None):
      75      return Leaf(token.NUMBER, n, prefix=prefix)
      76  
      77  def Subscript(index_node):
      78      """A numeric or string subscript"""
      79      return Node(syms.trailer, [Leaf(token.LBRACE, "["),
      80                                 index_node,
      81                                 Leaf(token.RBRACE, "]")])
      82  
      83  def String(string, prefix=None):
      84      """A string leaf"""
      85      return Leaf(token.STRING, string, prefix=prefix)
      86  
      87  def ListComp(xp, fp, it, test=None):
      88      """A list comprehension of the form [xp for fp in it if test].
      89  
      90      If test is None, the "if test" part is omitted.
      91      """
      92      xp.prefix = ""
      93      fp.prefix = " "
      94      it.prefix = " "
      95      for_leaf = Leaf(token.NAME, "for")
      96      for_leaf.prefix = " "
      97      in_leaf = Leaf(token.NAME, "in")
      98      in_leaf.prefix = " "
      99      inner_args = [for_leaf, fp, in_leaf, it]
     100      if test:
     101          test.prefix = " "
     102          if_leaf = Leaf(token.NAME, "if")
     103          if_leaf.prefix = " "
     104          inner_args.append(Node(syms.comp_if, [if_leaf, test]))
     105      inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
     106      return Node(syms.atom,
     107                         [Leaf(token.LBRACE, "["),
     108                          inner,
     109                          Leaf(token.RBRACE, "]")])
     110  
     111  def FromImport(package_name, name_leafs):
     112      """ Return an import statement in the form:
     113          from package import name_leafs"""
     114      # XXX: May not handle dotted imports properly (eg, package_name='foo.bar')
     115      #assert package_name == '.' or '.' not in package_name, "FromImport has "\
     116      #       "not been tested with dotted package names -- use at your own "\
     117      #       "peril!"
     118  
     119      for leaf in name_leafs:
     120          # Pull the leaves out of their old tree
     121          leaf.remove()
     122  
     123      children = [Leaf(token.NAME, "from"),
     124                  Leaf(token.NAME, package_name, prefix=" "),
     125                  Leaf(token.NAME, "import", prefix=" "),
     126                  Node(syms.import_as_names, name_leafs)]
     127      imp = Node(syms.import_from, children)
     128      return imp
     129  
     130  def ImportAndCall(node, results, names):
     131      """Returns an import statement and calls a method
     132      of the module:
     133  
     134      import module
     135      module.name()"""
     136      obj = results["obj"].clone()
     137      if obj.type == syms.arglist:
     138          newarglist = obj.clone()
     139      else:
     140          newarglist = Node(syms.arglist, [obj.clone()])
     141      after = results["after"]
     142      if after:
     143          after = [n.clone() for n in after]
     144      new = Node(syms.power,
     145                 Attr(Name(names[0]), Name(names[1])) +
     146                 [Node(syms.trailer,
     147                       [results["lpar"].clone(),
     148                        newarglist,
     149                        results["rpar"].clone()])] + after)
     150      new.prefix = node.prefix
     151      return new
     152  
     153  
     154  ###########################################################
     155  ### Determine whether a node represents a given literal
     156  ###########################################################
     157  
     158  def is_tuple(node):
     159      """Does the node represent a tuple literal?"""
     160      if isinstance(node, Node) and node.children == [LParen(), RParen()]:
     161          return True
     162      return (isinstance(node, Node)
     163              and len(node.children) == 3
     164              and isinstance(node.children[0], Leaf)
     165              and isinstance(node.children[1], Node)
     166              and isinstance(node.children[2], Leaf)
     167              and node.children[0].value == "("
     168              and node.children[2].value == ")")
     169  
     170  def is_list(node):
     171      """Does the node represent a list literal?"""
     172      return (isinstance(node, Node)
     173              and len(node.children) > 1
     174              and isinstance(node.children[0], Leaf)
     175              and isinstance(node.children[-1], Leaf)
     176              and node.children[0].value == "["
     177              and node.children[-1].value == "]")
     178  
     179  
     180  ###########################################################
     181  ### Misc
     182  ###########################################################
     183  
     184  def parenthesize(node):
     185      return Node(syms.atom, [LParen(), node, RParen()])
     186  
     187  
     188  consuming_calls = {"sorted", "list", "set", "any", "all", "tuple", "sum",
     189                     "min", "max", "enumerate"}
     190  
     191  def attr_chain(obj, attr):
     192      """Follow an attribute chain.
     193  
     194      If you have a chain of objects where a.foo -> b, b.foo-> c, etc,
     195      use this to iterate over all objects in the chain. Iteration is
     196      terminated by getattr(x, attr) is None.
     197  
     198      Args:
     199          obj: the starting object
     200          attr: the name of the chaining attribute
     201  
     202      Yields:
     203          Each successive object in the chain.
     204      """
     205      next = getattr(obj, attr)
     206      while next:
     207          yield next
     208          next = getattr(next, attr)
     209  
# Pattern sources used by in_special_context().  They start out as strings;
# on first use they are compiled and these module globals are rebound to the
# compiled pattern objects (pats_built records that this has happened).
# p0: node is the iterable of a for-loop or comprehension ("for x in node").
p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
# p1: node is the sole argument of iter()/list()/tuple()/sorted()/set()/
# sum()/any()/all()/enumerate(), or of a "<something>.join(...)" call.
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | 'enumerate' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
# p2: node is the first argument of a multi-argument sorted()/enumerate() call.
p2 = """
power<
    ( 'sorted' | 'enumerate' )
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
pats_built = False
def in_special_context(node):
    """ Returns true if node is in an environment where all that is required
        of it is being iterable (ie, it doesn't matter if it returns a list
        or an iterator).
        See test_map_nochange in test_fixers.py for some examples and tests.

        Pattern p0 is tried against node's parent, p1 against its
        grandparent and p2 against its great-grandparent.  A pattern counts
        only if its ``node`` capture is *node* itself.
        """
    global p0, p1, p2, pats_built
    # Lazily compile the patterns once, replacing the module-level strings
    # with compiled pattern objects.
    if not pats_built:
        p0 = patcomp.compile_pattern(p0)
        p1 = patcomp.compile_pattern(p1)
        p2 = patcomp.compile_pattern(p2)
        pats_built = True
    patterns = [p0, p1, p2]
    # zip() pairs the i-th pattern with the i-th ancestor and stops early if
    # the ancestor chain is shorter than three.
    for pattern, parent in zip(patterns, attr_chain(node, "parent")):
        results = {}
        # The match must capture *node* itself (identity), not just any node.
        if pattern.match(parent, results) and results["node"] is node:
            return True
    return False
     247  
     248  def is_probably_builtin(node):
     249      """
     250      Check that something isn't an attribute or function name etc.
     251      """
     252      prev = node.prev_sibling
     253      if prev is not None and prev.type == token.DOT:
     254          # Attribute lookup.
     255          return False
     256      parent = node.parent
     257      if parent.type in (syms.funcdef, syms.classdef):
     258          return False
     259      if parent.type == syms.expr_stmt and parent.children[0] is node:
     260          # Assignment.
     261          return False
     262      if parent.type == syms.parameters or \
     263              (parent.type == syms.typedargslist and (
     264              (prev is not None and prev.type == token.COMMA) or
     265              parent.children[0] is node
     266              )):
     267          # The name of an argument.
     268          return False
     269      return True
     270  
     271  def find_indentation(node):
     272      """Find the indentation of *node*."""
     273      while node is not None:
     274          if node.type == syms.suite and len(node.children) > 2:
     275              indent = node.children[1]
     276              if indent.type == token.INDENT:
     277                  return indent.value
     278          node = node.parent
     279      return ""
     280  
     281  ###########################################################
     282  ### The following functions are to find bindings in a suite
     283  ###########################################################
     284  
     285  def make_suite(node):
     286      if node.type == syms.suite:
     287          return node
     288      node = node.clone()
     289      parent, node.parent = node.parent, None
     290      suite = Node(syms.suite, [node])
     291      suite.parent = parent
     292      return suite
     293  
     294  def find_root(node):
     295      """Find the top level namespace."""
     296      # Scamper up to the top level namespace
     297      while node.type != syms.file_input:
     298          node = node.parent
     299          if not node:
     300              raise ValueError("root found before file_input node was found.")
     301      return node
     302  
     303  def does_tree_import(package, name, node):
     304      """ Returns true if name is imported from package at the
     305          top level of the tree which node belongs to.
     306          To cover the case of an import like 'import foo', use
     307          None for the package and 'foo' for the name. """
     308      binding = find_binding(name, find_root(node), package)
     309      return bool(binding)
     310  
     311  def is_import(node):
     312      """Returns true if the node is an import statement."""
     313      return node.type in (syms.import_name, syms.import_from)
     314  
     315  def touch_import(package, name, node):
     316      """ Works like `does_tree_import` but adds an import statement
     317          if it was not imported. """
     318      def is_import_stmt(node):
     319          return (node.type == syms.simple_stmt and node.children and
     320                  is_import(node.children[0]))
     321  
     322      root = find_root(node)
     323  
     324      if does_tree_import(package, name, root):
     325          return
     326  
     327      # figure out where to insert the new import.  First try to find
     328      # the first import and then skip to the last one.
     329      insert_pos = offset = 0
     330      for idx, node in enumerate(root.children):
     331          if not is_import_stmt(node):
     332              continue
     333          for offset, node2 in enumerate(root.children[idx:]):
     334              if not is_import_stmt(node2):
     335                  break
     336          insert_pos = idx + offset
     337          break
     338  
     339      # if there are no imports where we can insert, find the docstring.
     340      # if that also fails, we stick to the beginning of the file
     341      if insert_pos == 0:
     342          for idx, node in enumerate(root.children):
     343              if (node.type == syms.simple_stmt and node.children and
     344                 node.children[0].type == token.STRING):
     345                  insert_pos = idx + 1
     346                  break
     347  
     348      if package is None:
     349          import_ = Node(syms.import_name, [
     350              Leaf(token.NAME, "import"),
     351              Leaf(token.NAME, name, prefix=" ")
     352          ])
     353      else:
     354          import_ = FromImport(package, [Leaf(token.NAME, name, prefix=" ")])
     355  
     356      children = [import_, Newline()]
     357      root.insert_child(insert_pos, Node(syms.simple_stmt, children))
     358  
     359  
# Node types whose second child is the NAME being defined ("def f", "class C").
_def_syms = {syms.classdef, syms.funcdef}
def find_binding(name, node, package=None):
    """ Returns the node which binds variable name, otherwise None.
        If optional argument package is supplied, only imports will
        be returned.
        See test cases for examples.

        Only the direct children of *node* are scanned.  The search also
        recurses into simple statements and into the bodies of
        for/if/while/try statements.  Other compound statements (e.g. with)
        and the bodies of def/class are not searched.  The first qualifying
        child is returned.
    """
    for child in node.children:
        ret = None
        if child.type == syms.for_stmt:
            # children[1] is the loop target ("for <target> in ...").
            # NOTE(review): this returns immediately even when *package* is
            # given, so a non-import node can be returned despite the
            # docstring; confirm whether that is intended.
            if _find(name, child.children[1]):
                return child
            # children[-1] is the loop body, or the else-suite if present.
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type in (syms.if_stmt, syms.while_stmt):
            # NOTE(review): only the last suite is searched (the else/last
            # elif branch when present); earlier branches are skipped.
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type == syms.try_stmt:
            # children[2] is the body of the "try:" clause.
            n = find_binding(name, make_suite(child.children[2]), package)
            if n:
                ret = n
            else:
                # Otherwise search the suite following each later ':'
                # (except/else/finally clauses); a later match overwrites an
                # earlier one.
                for i, kid in enumerate(child.children[3:]):
                    if kid.type == token.COLON and kid.value == ":":
                        # i+3 is the colon, i+4 is the suite
                        n = find_binding(name, make_suite(child.children[i+4]), package)
                        if n: ret = n
        elif child.type in _def_syms and child.children[1].value == name:
            ret = child
        elif _is_import_binding(child, name, package):
            ret = child
        elif child.type == syms.simple_stmt:
            # Recurse to reach the small statements this line contains.
            ret = find_binding(name, child, package)
        elif child.type == syms.expr_stmt:
            # Assignment: children[0] is the leftmost target.
            if _find(name, child.children[0]):
                ret = child

        if ret:
            # Without a package any binding wins; with one, only imports do
            # and the scan continues past non-import bindings.
            if not package:
                return ret
            if is_import(ret):
                return ret
    return None
     402  
     403  _block_syms = {syms.funcdef, syms.classdef, syms.trailer}
     404  def _find(name, node):
     405      nodes = [node]
     406      while nodes:
     407          node = nodes.pop()
     408          if node.type > 256 and node.type not in _block_syms:
     409              nodes.extend(node.children)
     410          elif node.type == token.NAME and node.value == name:
     411              return node
     412      return None
     413  
     414  def _is_import_binding(node, name, package=None):
     415      """ Will return node if node will import name, or node
     416          will import * from package.  None is returned otherwise.
     417          See test cases for examples. """
     418  
     419      if node.type == syms.import_name and not package:
     420          imp = node.children[1]
     421          if imp.type == syms.dotted_as_names:
     422              for child in imp.children:
     423                  if child.type == syms.dotted_as_name:
     424                      if child.children[2].value == name:
     425                          return node
     426                  elif child.type == token.NAME and child.value == name:
     427                      return node
     428          elif imp.type == syms.dotted_as_name:
     429              last = imp.children[-1]
     430              if last.type == token.NAME and last.value == name:
     431                  return node
     432          elif imp.type == token.NAME and imp.value == name:
     433              return node
     434      elif node.type == syms.import_from:
     435          # str(...) is used to make life easier here, because
     436          # from a.b import parses to ['import', ['a', '.', 'b'], ...]
     437          if package and str(node.children[1]).strip() != package:
     438              return None
     439          n = node.children[3]
     440          if package and _find("as", n):
     441              # See test_from_import_as for explanation
     442              return None
     443          elif n.type == syms.import_as_names and _find(name, n):
     444              return node
     445          elif n.type == syms.import_as_name:
     446              child = n.children[2]
     447              if child.type == token.NAME and child.value == name:
     448                  return node
     449          elif n.type == token.NAME and n.value == name:
     450              return node
     451          elif package and n.type == token.STAR:
     452              return node
     453      return None