python (3.11.7)
1 """Utility functions, node construction macros, etc."""
2 # Author: Collin Winter
3
4 # Local imports
5 from .pgen2 import token
6 from .pytree import Leaf, Node
7 from .pygram import python_symbols as syms
8 from . import patcomp
9
10
11 ###########################################################
12 ### Common node-construction "macros"
13 ###########################################################
14
def KeywordArg(keyword, value):
    """Build an ``argument`` node spelling ``keyword=value``.

    *keyword* and *value* are used as-is as the first and last children;
    an ``=`` leaf is created to sit between them.
    """
    equals = Leaf(token.EQUAL, "=")
    return Node(syms.argument, [keyword, equals, value])
18
def LParen():
    """Return a new "(" leaf."""
    paren = Leaf(token.LPAR, "(")
    return paren
21
def RParen():
    """Return a new ")" leaf."""
    paren = Leaf(token.RPAR, ")")
    return paren
24
def Assign(target, source):
    """Build an assignment statement.

    *target* and *source* may each be a single node or a list of nodes.
    A single *source* node has its prefix reset to one space (the caller's
    node is modified in place); a list is used as given.  The result is an
    ``atom`` node holding the targets, an ``=`` leaf and the sources.
    """
    targets = target if isinstance(target, list) else [target]
    if isinstance(source, list):
        sources = source
    else:
        source.prefix = " "
        sources = [source]

    equals = Leaf(token.EQUAL, "=", prefix=" ")
    return Node(syms.atom, targets + [equals] + sources)
35
def Name(name, prefix=None):
    """Return a NAME leaf whose value is *name*.

    *prefix* is forwarded to the Leaf constructor unchanged.
    """
    leaf = Leaf(token.NAME, name, prefix=prefix)
    return leaf
39
def Attr(obj, attr):
    """Return the two-element node list for ``obj.attr``.

    The first element is *obj* itself; the second is a ``trailer`` node
    made of a "." leaf followed by *attr*.
    """
    trailer = Node(syms.trailer, [Dot(), attr])
    return [obj, trailer]
43
def Comma():
    """Return a new "," leaf."""
    comma = Leaf(token.COMMA, ",")
    return comma
47
def Dot():
    """Return a new "." leaf."""
    dot = Leaf(token.DOT, ".")
    return dot
51
def ArgList(args, lparen=LParen(), rparen=RParen()):
    """Build the parenthesised argument-list trailer used by Call().

    The parenthesis leaves are cloned before use, so the default leaves
    (created once, when the function is defined) are never placed in a
    tree themselves.  When *args* is non-empty it is wrapped in an
    ``arglist`` node and inserted between the parentheses.
    """
    opening = lparen.clone()
    closing = rparen.clone()
    trailer = Node(syms.trailer, [opening, closing])
    if args:
        arglist = Node(syms.arglist, args)
        trailer.insert_child(1, arglist)
    return trailer
58
def Call(func_name, args=None, prefix=None):
    """Build a ``power`` node representing the call ``func_name(args)``.

    *func_name* is the callee node, *args* an optional list of argument
    nodes.  The node's prefix is only assigned when *prefix* is not None.
    """
    call = Node(syms.power, [func_name, ArgList(args)])
    if prefix is None:
        return call
    call.prefix = prefix
    return call
65
def Newline():
    """Return a NEWLINE leaf whose value is a line feed."""
    leaf = Leaf(token.NEWLINE, "\n")
    return leaf
69
def BlankLine():
    """Return a NEWLINE leaf whose value is the empty string."""
    leaf = Leaf(token.NEWLINE, "")
    return leaf
73
def Number(n, prefix=None):
    """Return a NUMBER leaf whose value is *n*.

    *prefix* is forwarded to the Leaf constructor unchanged.
    """
    leaf = Leaf(token.NUMBER, n, prefix=prefix)
    return leaf
76
def Subscript(index_node):
    """Build the ``trailer`` node ``[index_node]`` for a subscript.

    *index_node* is the node placed between the brackets (for example a
    numeric or string literal).

    The bracket leaves use the square-bracket token types LSQB/RSQB.
    LBRACE/RBRACE are the token types of "{" and "}", so leaves typed
    that way would not compare equal to the "[" / "]" leaves the parser
    produces.
    """
    opening = Leaf(token.LSQB, "[")
    closing = Leaf(token.RSQB, "]")
    return Node(syms.trailer, [opening, index_node, closing])
82
def String(string, prefix=None):
    """Return a STRING leaf whose value is *string*.

    *prefix* is forwarded to the Leaf constructor unchanged.
    """
    leaf = Leaf(token.STRING, string, prefix=prefix)
    return leaf
86
def ListComp(xp, fp, it, test=None):
    """A list comprehension of the form [xp for fp in it if test].

    If test is None, the "if test" part is omitted.

    The prefixes of *xp*, *fp*, *it* (and *test*, when given) are
    overwritten in place so the pieces are separated by single spaces.

    The bracket leaves use the square-bracket token types LSQB/RSQB.
    LBRACE/RBRACE are the token types of "{" and "}", so leaves typed
    that way would not compare equal to the "[" / "]" leaves the parser
    produces.
    """
    xp.prefix = ""
    fp.prefix = " "
    it.prefix = " "
    for_leaf = Leaf(token.NAME, "for", prefix=" ")
    in_leaf = Leaf(token.NAME, "in", prefix=" ")
    inner_args = [for_leaf, fp, in_leaf, it]
    if test:
        test.prefix = " "
        if_leaf = Leaf(token.NAME, "if", prefix=" ")
        inner_args.append(Node(syms.comp_if, [if_leaf, test]))
    inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
    return Node(syms.atom,
                [Leaf(token.LSQB, "["),
                 inner,
                 Leaf(token.RSQB, "]")])
110
def FromImport(package_name, name_leafs):
    """Return an import statement in the form:
    from package import name_leafs

    Each leaf in *name_leafs* is first removed from whatever tree it
    currently belongs to; the same list then becomes the children of the
    new ``import_as_names`` node.
    """
    # XXX: dotted package names (eg, package_name='foo.bar') are untested;
    # the whole name is emitted as a single NAME leaf.

    # Pull the leaves out of their old tree.
    for name_leaf in name_leafs:
        name_leaf.remove()

    from_kw = Leaf(token.NAME, "from")
    package = Leaf(token.NAME, package_name, prefix=" ")
    import_kw = Leaf(token.NAME, "import", prefix=" ")
    imported = Node(syms.import_as_names, name_leafs)
    return Node(syms.import_from, [from_kw, package, import_kw, imported])
129
def ImportAndCall(node, results, names):
    """Build the call ``module.name(obj)`` from a matched call's pieces.

    Only the call expression is built here; no import statement is
    created by this function.

    Args:
        node: the node being replaced; its prefix is copied to the result.
        results: match results providing "obj" (the argument or arglist),
            "lpar" and "rpar" (the parenthesis leaves) and optionally
            "after" (nodes that follow the call).
        names: a sequence whose first two items are the module name and
            the attribute name to call.

    Returns:
        A new ``power`` node; every piece taken from *results* is cloned.
    """
    # The clone is detached, so it can be placed in the new tree directly.
    obj = results["obj"].clone()
    if obj.type == syms.arglist:
        newarglist = obj
    else:
        newarglist = Node(syms.arglist, [obj])
    # A missing or None "after" entry is treated as "nothing follows".
    after = [n.clone() for n in (results.get("after") or ())]
    call_trailer = Node(syms.trailer,
                        [results["lpar"].clone(),
                         newarglist,
                         results["rpar"].clone()])
    new = Node(syms.power,
               Attr(Name(names[0]), Name(names[1])) + [call_trailer] + after)
    new.prefix = node.prefix
    return new
152
153
154 ###########################################################
155 ### Determine whether a node represents a given literal
156 ###########################################################
157
def is_tuple(node):
    """Does the node represent a tuple literal?

    True for a Node whose children are exactly "(" and ")", or a Node
    with three children: a "(" leaf, an inner Node and a ")" leaf.
    """
    if not isinstance(node, Node):
        return False
    kids = node.children
    # The empty tuple: just an opening and a closing parenthesis.
    if kids == [LParen(), RParen()]:
        return True
    if len(kids) != 3:
        return False
    opener, middle, closer = kids
    return (isinstance(opener, Leaf)
            and isinstance(middle, Node)
            and isinstance(closer, Leaf)
            and opener.value == "("
            and closer.value == ")")
169
def is_list(node):
    """Does the node represent a list literal?

    True for a Node with at least two children whose first and last
    children are leaves with the values "[" and "]".
    """
    if not isinstance(node, Node):
        return False
    kids = node.children
    if len(kids) <= 1:
        return False
    opener, closer = kids[0], kids[-1]
    if not (isinstance(opener, Leaf) and isinstance(closer, Leaf)):
        return False
    return opener.value == "[" and closer.value == "]"
178
179
180 ###########################################################
181 ### Misc
182 ###########################################################
183
def parenthesize(node):
    """Return an ``atom`` node wrapping *node* in parentheses."""
    opening = LParen()
    closing = RParen()
    return Node(syms.atom, [opening, node, closing])
186
187
# Names of builtin callables.  Nothing in this part of the file reads the
# set; presumably callers use it to recognise calls that consume their
# iterable argument -- confirm against the modules that import it.
consuming_calls = {"sorted", "list", "set", "any", "all", "tuple", "sum",
                   "min", "max", "enumerate"}
190
def attr_chain(obj, attr):
    """Follow an attribute chain.

    If you have a chain of objects where a.foo -> b, b.foo-> c, etc,
    use this to iterate over all objects in the chain.  Iteration is
    terminated when getattr(x, attr) is None; objects that merely
    evaluate as false (for example, ones defining ``__len__``) are
    still yielded and followed.

    Args:
        obj: the starting object
        attr: the name of the chaining attribute

    Yields:
        Each successive object in the chain.
    """
    link = getattr(obj, attr)
    # Compare against None explicitly so a falsy link does not end the
    # walk early.
    while link is not None:
        yield link
        link = getattr(link, attr)
209
# Patterns used by in_special_context().  Each name starts out bound to
# pattern source text and is rebound to its compiled form (via
# patcomp.compile_pattern) the first time in_special_context() runs.

# p0: a for statement or a comprehension's for clause; `node` captures
# the expression that follows 'in'.
p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
# p1: a call whose callee is one of the listed names, or ends in a
# `.join` trailer, and whose parentheses hold a single child, `node`.
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | 'enumerate' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
# p2: a call to sorted or enumerate with an arglist; `node` captures the
# arglist's first child.
p2 = """
power<
    ( 'sorted' | 'enumerate' )
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
# Set to True by in_special_context() once p0, p1 and p2 are compiled.
pats_built = False
def in_special_context(node):
    """Return True if node sits where only iterability is required of it.

    That is, where it does not matter whether the expression produces a
    list or an iterator.  The check pairs p0 with the node's parent, p1
    with its grandparent and p2 with its great-grandparent; a pattern
    counts only when what it captured as "node" is *node* itself.
    See test_map_nochange in test_fixers.py for some examples and tests.
    """
    global p0, p1, p2, pats_built
    if not pats_built:
        # First call: swap the pattern source strings for compiled patterns.
        p0 = patcomp.compile_pattern(p0)
        p1 = patcomp.compile_pattern(p1)
        p2 = patcomp.compile_pattern(p2)
        pats_built = True
    compiled = [p0, p1, p2]
    for pattern, ancestor in zip(compiled, attr_chain(node, "parent")):
        captured = {}
        if not pattern.match(ancestor, captured):
            continue
        if captured["node"] is node:
            return True
    return False
247
def is_probably_builtin(node):
    """Check that something isn't an attribute or function name etc.

    Returns False when *node* follows a "." leaf, is a direct child of a
    funcdef or classdef, is the first child of an expr_stmt, or is in a
    position treated as a parameter name; True otherwise.
    """
    before = node.prev_sibling
    if before is not None and before.type == token.DOT:
        # Attribute lookup.
        return False

    owner = node.parent
    owner_type = owner.type
    if owner_type == syms.funcdef or owner_type == syms.classdef:
        return False
    if owner_type == syms.expr_stmt and owner.children[0] is node:
        # Assignment.
        return False
    if owner_type == syms.parameters:
        # The name of an argument.
        return False
    if owner_type == syms.typedargslist:
        after_comma = before is not None and before.type == token.COMMA
        if after_comma or owner.children[0] is node:
            # The name of an argument.
            return False
    return True
270
def find_indentation(node):
    """Find the indentation of *node*.

    Walks from *node* up through its parents and returns the value of
    the INDENT leaf of the first suite (with more than two children)
    whose second child is an INDENT.  Returns "" if none is found.
    """
    current = node
    while current is not None:
        if current.type == syms.suite:
            kids = current.children
            if len(kids) > 2 and kids[1].type == token.INDENT:
                return kids[1].value
        current = current.parent
    return ""
280
281 ###########################################################
282 ### The following functions are to find bindings in a suite
283 ###########################################################
284
def make_suite(node):
    """Return *node* if it is a suite, else a new suite wrapping a clone.

    The original *node* is left untouched in the second case: it is
    cloned and the clone becomes the only child of a fresh ``suite``
    node.
    """
    if node.type == syms.suite:
        return node
    copy = node.clone()
    # Remember the parent the clone reports, then detach the clone so the
    # new suite can adopt it.
    former_parent = copy.parent
    copy.parent = None
    wrapper = Node(syms.suite, [copy])
    wrapper.parent = former_parent
    return wrapper
293
def find_root(node):
    """Find the top level namespace.

    Climbs the parent links from *node* until a ``file_input`` node is
    reached and returns it.  Raises ValueError if the chain of parents
    ends first.
    """
    current = node
    while True:
        if current.type == syms.file_input:
            return current
        # Scamper one level up towards the top level namespace.
        current = current.parent
        if not current:
            raise ValueError("root found before file_input node was found.")
302
def does_tree_import(package, name, node):
    """Return True if name is imported from package at the top level
    of the tree which node belongs to.

    To cover the case of an import like 'import foo', use
    None for the package and 'foo' for the name.
    """
    root = find_root(node)
    return bool(find_binding(name, root, package))
310
def is_import(node):
    """Return True if the node is an ``import`` or ``from ... import`` statement."""
    kind = node.type
    return kind == syms.import_name or kind == syms.import_from
314
def touch_import(package, name, node):
    """Works like `does_tree_import` but adds an import statement
    if it was not imported.

    With *package* set to None the statement added is ``import name``;
    otherwise it is ``from package import name``.  The statement is
    inserted as a new child of the tree's root.
    """
    def is_import_stmt(stmt):
        if stmt.type != syms.simple_stmt:
            return False
        return stmt.children and is_import(stmt.children[0])

    root = find_root(node)
    if does_tree_import(package, name, root):
        return

    top_level = root.children

    # Figure out where to insert the new import: locate the first import
    # statement, then walk to the end of the run of imports it starts.
    insert_pos = 0
    for first, stmt in enumerate(top_level):
        if not is_import_stmt(stmt):
            continue
        run_offset = 0
        for run_offset, later in enumerate(top_level[first:]):
            if not is_import_stmt(later):
                break
        # If the run reaches the end of the file the inner loop never
        # breaks, so this is the index of the final import itself.
        insert_pos = first + run_offset
        break

    # If there are no imports where we can insert, find the docstring.
    # If that also fails, we stick to the beginning of the file.
    if insert_pos == 0:
        for idx, stmt in enumerate(top_level):
            if (stmt.type == syms.simple_stmt and stmt.children and
                    stmt.children[0].type == token.STRING):
                insert_pos = idx + 1
                break

    if package is None:
        import_ = Node(syms.import_name, [
            Leaf(token.NAME, "import"),
            Leaf(token.NAME, name, prefix=" "),
        ])
    else:
        import_ = FromImport(package, [Leaf(token.NAME, name, prefix=" ")])

    statement = Node(syms.simple_stmt, [import_, Newline()])
    root.insert_child(insert_pos, statement)
358
359
# Node types whose second child (children[1]) is compared against the
# name being looked up in find_binding().
_def_syms = {syms.classdef, syms.funcdef}
def find_binding(name, node, package=None):
    """ Returns the node which binds variable name, otherwise None.
        If optional argument package is supplied, only imports will
        be returned.
        See test cases for examples.

        The direct children of node are examined in order.  Compound
        statements (for/if/while/try) are searched recursively through
        their suites; definitions, imports and assignments are checked
        against name directly.

        NOTE: a match on a for-loop's children[1] returns immediately
        and is therefore not subject to the package filter applied at
        the bottom of the loop."""
    for child in node.children:
        # Candidate binding found for this child, if any.
        ret = None
        if child.type == syms.for_stmt:
            # children[1] is the node right after the 'for' keyword.
            if _find(name, child.children[1]):
                return child
            # Otherwise look inside the statement's last child.
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type in (syms.if_stmt, syms.while_stmt):
            # Only the last child of the statement is searched.
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type == syms.try_stmt:
            # children[2] is searched first; the remaining suites are only
            # considered when it yields nothing.
            n = find_binding(name, make_suite(child.children[2]), package)
            if n:
                ret = n
            else:
                # There is no break, so when several later suites bind the
                # name the last one found wins.
                for i, kid in enumerate(child.children[3:]):
                    if kid.type == token.COLON and kid.value == ":":
                        # i+3 is the colon, i+4 is the suite
                        n = find_binding(name, make_suite(child.children[i+4]), package)
                        if n: ret = n
        elif child.type in _def_syms and child.children[1].value == name:
            ret = child
        elif _is_import_binding(child, name, package):
            ret = child
        elif child.type == syms.simple_stmt:
            # Descend into the statement's own children.
            ret = find_binding(name, child, package)
        elif child.type == syms.expr_stmt:
            # Only the first child of the expression statement is searched.
            if _find(name, child.children[0]):
                ret = child

        if ret:
            # Without a package any binding will do; with one, only an
            # import statement is accepted and the search continues
            # otherwise.
            if not package:
                return ret
            if is_import(ret):
                return ret
    return None
402
# Node types that _find() does not descend into.
_block_syms = {syms.funcdef, syms.classdef, syms.trailer}
def _find(name, node):
    """Return a NAME leaf with value *name* found under *node*, or None.

    The search is depth-first using an explicit stack, so the most
    recently pushed (right-most) children are examined first.  Nodes
    whose type is in _block_syms are not entered.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        kind = current.type
        if kind == token.NAME:
            if current.value == name:
                return current
        elif kind > 256 and kind not in _block_syms:
            pending.extend(current.children)
    return None
413
def _is_import_binding(node, name, package=None):
    """Return node if node will import name, or node will import *
    from package.  None is returned otherwise.
    See test cases for examples.
    """
    kind = node.type

    if kind == syms.import_name and not package:
        target = node.children[1]
        if target.type == syms.dotted_as_names:
            for entry in target.children:
                if entry.type == syms.dotted_as_name:
                    matched = entry.children[2].value == name
                else:
                    matched = entry.type == token.NAME and entry.value == name
                if matched:
                    return node
            return None
        if target.type == syms.dotted_as_name:
            # For "import a as b" only the last child is compared.
            target = target.children[-1]
        if target.type == token.NAME and target.value == name:
            return node
        return None

    if kind != syms.import_from:
        return None

    # str(...) is used to make life easier here, because
    # from a.b import parses to ['import', ['a', '.', 'b'], ...]
    if package and str(node.children[1]).strip() != package:
        return None
    imported = node.children[3]
    if package and _find("as", imported):
        # See test_from_import_as for explanation
        return None
    if imported.type == syms.import_as_names and _find(name, imported):
        return node
    if imported.type == syms.import_as_name:
        alias = imported.children[2]
        if alias.type == token.NAME and alias.value == name:
            return node
        return None
    if imported.type == token.NAME and imported.value == name:
        return node
    if package and imported.type == token.STAR:
        return node
    return None