1 """
2 ast
3 ~~~
4
5 The `ast` module helps Python applications to process trees of the Python
6 abstract syntax grammar. The abstract syntax itself might change with
7 each Python release; this module helps to find out programmatically what
8 the current grammar looks like and allows modifications of it.
9
10 An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
11 a flag to the `compile()` builtin function or by using the `parse()`
12 function from this module. The result will be a tree of objects whose
13 classes all inherit from `ast.AST`.
14
15 A modified abstract syntax tree can be compiled into a Python code object
16 using the built-in `compile()` function.
17
Additionally, various helper functions are provided that make working with
the trees simpler. The main intention of the helper functions, and of this
module in general, is to provide an easy-to-use interface for libraries
that work closely with the Python syntax (template engines, for example).
22
23
24 :copyright: Copyright 2008 by Armin Ronacher.
25 :license: Python License.
26 """
27 import sys
28 import re
29 from _ast import *
30 from contextlib import contextmanager, nullcontext
31 from enum import IntEnum, auto, _simple_enum
32
33
def parse(source, filename='<unknown>', mode='exec', *,
          type_comments=False, feature_version=None):
    """
    Parse *source* into an AST node.

    Equivalent to ``compile(source, filename, mode, PyCF_ONLY_AST)``.
    Pass ``type_comments=True`` to get back type comments where the syntax
    allows them.  *feature_version* may be ``None`` (current grammar), a
    ``(3, minor)`` tuple, or an int giving the minor version for 3.x.

    Raises ValueError if a tuple *feature_version* names a major version
    other than 3.
    """
    flags = PyCF_ONLY_AST | (PyCF_TYPE_COMMENTS if type_comments else 0)
    if feature_version is None:
        version = -1
    elif isinstance(feature_version, tuple):
        # Expected to be a 2-tuple; only the 3.x series is supported.
        major, minor = feature_version
        if major != 3:
            raise ValueError(f"Unsupported major version: {major}")
        version = minor
    else:
        # Otherwise it should already be the 3.x minor version as an int.
        version = feature_version
    return compile(source, filename, mode, flags,
                   _feature_version=version)
54
55
def literal_eval(node_or_string):
    """
    Evaluate an expression node or a string containing only a Python literal.

    The string or node may only consist of the following Python literal
    structures: strings, bytes, numbers, tuples, lists, dicts, sets,
    booleans, None and Ellipsis, plus the call ``set()`` for an empty set and
    ``<real> +/- <complex>`` sums for complex numbers.  Anything else raises
    ``ValueError``.

    Caution: A complex expression can overflow the C stack and cause a crash.
    """
    if isinstance(node_or_string, str):
        # Strip leading blanks so indented input is not a parse error.
        node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')
    if isinstance(node_or_string, Expression):
        node_or_string = node_or_string.body

    def _raise_malformed_node(node):
        lno = getattr(node, 'lineno', None)
        location = f' on line {lno}' if lno else ''
        raise ValueError(f'malformed node or string{location}: {node!r}')

    def _convert_num(node):
        # Only exact int/float/complex constants count (bool is rejected).
        if isinstance(node, Constant) and type(node.value) in (int, float, complex):
            return node.value
        _raise_malformed_node(node)

    def _convert_signed_num(node):
        if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
            operand = _convert_num(node.operand)
            return +operand if isinstance(node.op, UAdd) else -operand
        return _convert_num(node)

    def _convert(node):
        if isinstance(node, Constant):
            return node.value
        if isinstance(node, Tuple):
            return tuple(_convert(elt) for elt in node.elts)
        if isinstance(node, List):
            return [_convert(elt) for elt in node.elts]
        if isinstance(node, Set):
            return {_convert(elt) for elt in node.elts}
        if (isinstance(node, Call) and isinstance(node.func, Name)
                and node.func.id == 'set' and node.args == node.keywords == []):
            # ``set()`` is the only way to spell an empty set.
            return set()
        if isinstance(node, Dict):
            if len(node.keys) != len(node.values):
                _raise_malformed_node(node)
            return {_convert(key): _convert(value)
                    for key, value in zip(node.keys, node.values)}
        if isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
            # Only ``real +/- complex`` is accepted, to spell complex literals.
            left = _convert_signed_num(node.left)
            right = _convert_num(node.right)
            if isinstance(left, (int, float)) and isinstance(right, complex):
                return left + right if isinstance(node.op, Add) else left - right
        return _convert_signed_num(node)

    return _convert(node_or_string)
113
114
def dump(node, annotate_fields=True, include_attributes=False, *, indent=None):
    """
    Return a formatted dump of the tree in *node*, mainly for debugging.

    If *annotate_fields* is true (the default), field names are shown next
    to their values; if false, unambiguous field names are omitted for a
    more compact result.  Attributes such as line numbers and column offsets
    are only dumped when *include_attributes* is true.  If *indent* is a
    non-negative integer or a string, the tree is pretty-printed with that
    indent level; ``None`` (the default) selects the single-line form.

    Raises TypeError if *node* is not an ``AST`` instance.
    """
    omitted = object()

    def _value_of(obj, cls, name):
        # Absent values, and None values that merely echo a class-level None
        # default (i.e. unset optional fields), are reported as ``omitted``.
        value = getattr(obj, name, omitted)
        if value is None and getattr(cls, name, ...) is None:
            return omitted
        return value

    def _format(node, level=0):
        # Returns (text, simple); "simple" nodes may be packed on one line.
        if indent is None:
            prefix, sep = '', ', '
        else:
            level += 1
            prefix = '\n' + indent * level
            sep = ',' + prefix

        if not isinstance(node, AST):
            if isinstance(node, list):
                if not node:
                    return '[]', True
                body = sep.join(_format(item, level)[0] for item in node)
                return f'[{prefix}{body}]', False
            return repr(node), True

        cls = type(node)
        args = []
        allsimple = True
        keywords = annotate_fields
        for name in node._fields:
            value = _value_of(node, cls, name)
            if value is omitted:
                # Once a field is skipped, positional output would be
                # ambiguous, so every later field is named.
                keywords = True
                continue
            text, simple = _format(value, level)
            allsimple = allsimple and simple
            args.append(f'{name}={text}' if keywords else text)
        if include_attributes and node._attributes:
            for name in node._attributes:
                value = _value_of(node, cls, name)
                if value is omitted:
                    continue
                text, simple = _format(value, level)
                allsimple = allsimple and simple
                args.append(f'{name}={text}')

        class_name = node.__class__.__name__
        if allsimple and len(args) <= 3:
            # Small, flat nodes stay on a single line.
            return f'{class_name}({", ".join(args)})', not args
        return f'{class_name}({prefix}{sep.join(args)})', False

    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    if indent is not None and not isinstance(indent, str):
        indent = ' ' * indent
    return _format(node)[0]
180
181
def copy_location(new_node, old_node):
    """
    Copy source location attributes from *old_node* to *new_node*.

    Only ``lineno``, ``col_offset``, ``end_lineno`` and ``end_col_offset``
    are considered, and each only when both nodes list it in
    ``_attributes``.  Returns *new_node*.
    """
    for attr in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset'):
        if attr not in old_node._attributes or attr not in new_node._attributes:
            continue
        value = getattr(old_node, attr, None)
        # end_lineno / end_col_offset are optional, so an explicit None on
        # old_node is still copied; the start attributes are copied only
        # when they hold a real value.
        if value is None and not (hasattr(old_node, attr) and attr.startswith("end_")):
            continue
        setattr(new_node, attr, value)
    return new_node
197
198
def fix_missing_locations(node):
    """
    Fill in missing location attributes throughout the tree rooted at *node*.

    compile() expects ``lineno`` and ``col_offset`` on every node that
    supports them, which is tedious to set on generated nodes.  Each node
    lacking them (or whose ``end_*`` values are None) inherits the values
    of its parent; the root defaults to line 1, column 0.  Returns *node*.
    """
    def _fix(node, lineno, col_offset, end_lineno, end_col_offset):
        attrs = node._attributes
        if 'lineno' in attrs:
            if hasattr(node, 'lineno'):
                lineno = node.lineno
            else:
                node.lineno = lineno
        if 'end_lineno' in attrs:
            if getattr(node, 'end_lineno', None) is not None:
                end_lineno = node.end_lineno
            else:
                node.end_lineno = end_lineno
        if 'col_offset' in attrs:
            if hasattr(node, 'col_offset'):
                col_offset = node.col_offset
            else:
                node.col_offset = col_offset
        if 'end_col_offset' in attrs:
            if getattr(node, 'end_col_offset', None) is not None:
                end_col_offset = node.end_col_offset
            else:
                node.end_col_offset = end_col_offset
        # Children inherit this node's (possibly just-filled) location.
        for child in iter_child_nodes(node):
            _fix(child, lineno, col_offset, end_lineno, end_col_offset)

    _fix(node, 1, 0, 1, 0)
    return node
232
233
def increment_lineno(node, n=1):
    """
    Shift ``lineno`` and ``end_lineno`` of every node in the tree rooted at
    *node* by *n*, which is useful to "move code" to a different location
    in a file.  Returns *node*.
    """
    for child in walk(node):
        if isinstance(child, TypeIgnore):
            # TypeIgnore stores lineno as a field, not in _attributes.
            child.lineno = getattr(child, 'lineno', 0) + n
            continue
        attrs = child._attributes
        if 'lineno' in attrs:
            child.lineno = getattr(child, 'lineno', 0) + n
        if "end_lineno" in attrs:
            end_lineno = getattr(child, "end_lineno", 0)
            # end_lineno is optional; leave an explicit None untouched.
            if end_lineno is not None:
                child.end_lineno = end_lineno + n
    return node
255
256
def iter_fields(node):
    """
    Yield ``(fieldname, value)`` for each name in ``node._fields`` that is
    actually set on *node*; fields missing from the node are skipped.
    """
    absent = object()
    for field in node._fields:
        value = getattr(node, field, absent)
        if value is not absent:
            yield field, value
267
268
def iter_child_nodes(node):
    """
    Yield the direct child nodes of *node*: every field whose value is an
    AST node, plus every AST item of fields that are lists.
    """
    for _, value in iter_fields(node):
        if isinstance(value, AST):
            yield value
        elif isinstance(value, list):
            yield from (item for item in value if isinstance(item, AST))
281
282
def get_docstring(node, clean=True):
    """
    Return the docstring of *node*, or None if it has none.

    *node* must be a (async) function, class or module node; anything else
    raises TypeError.  If *clean* is true, the text is passed through
    ``inspect.cleandoc`` (tabs expanded, common indentation removed).
    """
    if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    body = node.body
    if not body or not isinstance(body[0], Expr):
        return None
    value = body[0].value
    if not (isinstance(value, Constant) and isinstance(value.value, str)):
        return None
    text = value.value
    if clean:
        import inspect
        text = inspect.cleandoc(text)
    return text
305
306
307 _line_pattern = re.compile(r"(.*?(?:\r\n|\n|\r|$))")
308 def _splitlines_no_ff(source, maxlines=None):
309 """Split a string into lines ignoring form feed and other chars.
310
311 This mimics how the Python parser splits source code.
312 """
313 lines = []
314 for lineno, match in enumerate(_line_pattern.finditer(source), 1):
315 if maxlines is not None and lineno > maxlines:
316 break
317 lines.append(match[0])
318 return lines
319
320
321 def _pad_whitespace(source):
322 r"""Replace all chars except '\f\t' in a line with spaces."""
323 result = ''
324 for c in source:
325 if c in '\f\t':
326 result += c
327 else:
328 result += ' '
329 return result
330
331
332 def get_source_segment(source, node, *, padded=False):
333 """Get source code segment of the *source* that generated *node*.
334
335 If some location information (`lineno`, `end_lineno`, `col_offset`,
336 or `end_col_offset`) is missing, return None.
337
338 If *padded* is `True`, the first line of a multi-line statement will
339 be padded with spaces to match its original position.
340 """
341 try:
342 if node.end_lineno is None or node.end_col_offset is None:
343 return None
344 lineno = node.lineno - 1
345 end_lineno = node.end_lineno - 1
346 col_offset = node.col_offset
347 end_col_offset = node.end_col_offset
348 except AttributeError:
349 return None
350
351 lines = _splitlines_no_ff(source, maxlines=end_lineno+1)
352 if end_lineno == lineno:
353 return lines[lineno].encode()[col_offset:end_col_offset].decode()
354
355 if padded:
356 padding = _pad_whitespace(lines[lineno].encode()[:col_offset].decode())
357 else:
358 padding = ''
359
360 first = padding + lines[lineno].encode()[col_offset:].decode()
361 last = lines[end_lineno].encode()[:end_col_offset].decode()
362 lines = lines[lineno+1:end_lineno]
363
364 lines.insert(0, first)
365 lines.append(last)
366 return ''.join(lines)
367
368
def walk(node):
    """
    Yield *node* and every node below it, breadth-first.  Callers should not
    rely on the order; this is meant for in-place modifications where the
    context of each node does not matter.
    """
    from collections import deque
    pending = deque((node,))
    while pending:
        current = pending.popleft()
        pending.extend(iter_child_nodes(current))
        yield current
381
382
class NodeVisitor:
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found.  The visitor function may return
    a value, which is forwarded by the `visit` method.

    Subclass this class and add visitor methods.  By default the visitor
    method for a node is ``'visit_'`` + the node's class name, so a `Try`
    node is handled by `visit_Try`; override `visit` to change this.  If no
    such method exists, `generic_visit` is used instead, which visits every
    child node.

    Don't use `NodeVisitor` to change nodes while traversing; the
    `NodeTransformer` subclass exists for that.
    """

    def visit(self, node):
        """Visit a node."""
        handler = getattr(self, 'visit_' + node.__class__.__name__,
                          self.generic_visit)
        return handler(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for _name, value in iter_fields(node):
            if isinstance(value, AST):
                self.visit(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)

    def visit_Constant(self, node):
        """Dispatch to a legacy per-type visitor (named after the entries of
        ``_const_node_type_names``) if a subclass still defines one, with a
        DeprecationWarning; otherwise fall back to `generic_visit`."""
        value = node.value
        type_name = _const_node_type_names.get(type(value))
        if type_name is None:
            # No exact match: scan for a base class of the value's type.
            for cls, name in _const_node_type_names.items():
                if isinstance(value, cls):
                    type_name = name
                    break
        if type_name is None:
            return self.generic_visit(node)
        method = 'visit_' + type_name
        try:
            legacy_visitor = getattr(self, method)
        except AttributeError:
            return self.generic_visit(node)
        import warnings
        warnings.warn(f"{method} is deprecated; add visit_Constant",
                      DeprecationWarning, 2)
        return legacy_visitor(node)
439
440
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
    allows modification of nodes.

    The return value of each visitor method replaces the visited node: return
    ``None`` to remove the node from its location, a different node to
    replace it, or the original node to keep it unchanged.

    Here is an example transformer that rewrites all occurrences of name lookups
    (``foo``) to ``data['foo']``::

       class RewriteName(NodeTransformer):

           def visit_Name(self, node):
               return Subscript(
                   value=Name(id='data', ctx=Load()),
                   slice=Constant(value=node.id),
                   ctx=node.ctx
               )

    If the node you're operating on has child nodes, you must either
    transform them yourself or call :meth:`generic_visit` for the node first.

    For nodes that live in a list field (for example statements), the visitor
    may also return a list of nodes, which is spliced into that list.

    Usually you use the transformer like this::

       node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        for field, old_value in iter_fields(node):
            if isinstance(old_value, AST):
                replacement = self.visit(old_value)
                if replacement is None:
                    delattr(node, field)
                else:
                    setattr(node, field, replacement)
                continue
            if not isinstance(old_value, list):
                continue
            new_values = []
            for value in old_value:
                if isinstance(value, AST):
                    value = self.visit(value)
                    if value is None:
                        continue
                    if not isinstance(value, AST):
                        # A sequence of nodes is spliced in place of one.
                        new_values.extend(value)
                        continue
                new_values.append(value)
            # Mutate the list in place so existing references see the change.
            old_value[:] = new_values
        return node
498
499
500 _DEPRECATED_VALUE_ALIAS_MESSAGE = (
501 "{name} is deprecated and will be removed in Python {remove}; use value instead"
502 )
503 _DEPRECATED_CLASS_MESSAGE = (
504 "{name} is deprecated and will be removed in Python {remove}; "
505 "use ast.Constant instead"
506 )
507
508
# If the ast module is loaded more than once, only add deprecated methods once
if not hasattr(Constant, 'n'):
    # Backward-compatibility aliases: Constant.n and Constant.s both read and
    # write Constant.value, reporting the deprecation through
    # warnings._deprecated.  They will be removed in the future.

    def _n_getter(self):
        """Deprecated. Use value instead."""
        import warnings
        warnings._deprecated("Attribute n",
                             message=_DEPRECATED_VALUE_ALIAS_MESSAGE,
                             remove=(3, 14))
        return self.value

    def _n_setter(self, value):
        import warnings
        warnings._deprecated("Attribute n",
                             message=_DEPRECATED_VALUE_ALIAS_MESSAGE,
                             remove=(3, 14))
        self.value = value

    def _s_getter(self):
        """Deprecated. Use value instead."""
        import warnings
        warnings._deprecated("Attribute s",
                             message=_DEPRECATED_VALUE_ALIAS_MESSAGE,
                             remove=(3, 14))
        return self.value

    def _s_setter(self, value):
        import warnings
        warnings._deprecated("Attribute s",
                             message=_DEPRECATED_VALUE_ALIAS_MESSAGE,
                             remove=(3, 14))
        self.value = value

    Constant.n = property(_n_getter, _n_setter)
    Constant.s = property(_s_getter, _s_setter)
546
class _ABC(type):
    # Metaclass of the deprecated Constant aliases (Num, Str, Bytes,
    # NameConstant, Ellipsis): isinstance() checks against an alias inspect
    # the type of Constant.value.

    def __init__(cls, *args):
        cls.__doc__ = """Deprecated AST node class. Use ast.Constant instead"""

    def __instancecheck__(cls, inst):
        is_alias = cls in _const_types
        if is_alias:
            import warnings
            warnings._deprecated(
                f"ast.{cls.__qualname__}",
                message=_DEPRECATED_CLASS_MESSAGE,
                remove=(3, 14)
            )
        if not isinstance(inst, Constant):
            return False
        if not is_alias:
            return type.__instancecheck__(cls, inst)
        try:
            value = inst.value
        except AttributeError:
            return False
        return (isinstance(value, _const_types[cls])
                and not isinstance(value, _const_types_not.get(cls, ())))
573
def _new(cls, *args, **kwargs):
    # Shared __new__ for the deprecated Constant aliases.
    for key in kwargs:
        # Arbitrary keyword arguments are accepted; only a keyword naming a
        # field that was already supplied positionally is an error.
        if key in cls._fields and cls._fields.index(key) < len(args):
            raise TypeError(f"{cls.__name__} got multiple values for argument {key!r}")
    if cls not in _const_types:
        # Any other class using this __new__ (presumably a user subclass of
        # an alias) is constructed as its own type.
        return Constant.__new__(cls, *args, **kwargs)
    import warnings
    warnings._deprecated(
        f"ast.{cls.__qualname__}", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14)
    )
    return Constant(*args, **kwargs)
589
class Num(Constant, metaclass=_ABC):
    # Deprecated alias for numeric constants; positional argument maps to "n".
    _fields = ("n",)
    __new__ = _new
593
class Str(Constant, metaclass=_ABC):
    # Deprecated alias for str constants; positional argument maps to "s".
    _fields = ("s",)
    __new__ = _new
597
class Bytes(Constant, metaclass=_ABC):
    # Deprecated alias for bytes constants; positional argument maps to "s".
    _fields = ("s",)
    __new__ = _new
601
class NameConstant(Constant, metaclass=_ABC):
    # Deprecated alias for None/True/False constants; keeps Constant's fields.
    __new__ = _new
604
class Ellipsis(Constant, metaclass=_ABC):
    # Deprecated alias for the ``...`` constant; takes no positional fields.
    _fields = ()

    def __new__(cls, *args, **kwargs):
        if cls is not _ast_Ellipsis:
            # Subclasses are constructed as their own type.
            return Constant.__new__(cls, *args, **kwargs)
        import warnings
        warnings._deprecated(
            "ast.Ellipsis", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14)
        )
        return Constant(..., *args, **kwargs)

# Keep a second module-level reference to this class so Ellipsis.__new__
# can still find it after the public "Ellipsis" name is removed from the
# global namespace later on.
_ast_Ellipsis = Ellipsis
621
# Value types that each deprecated alias matches in isinstance() checks.
_const_types = {
    Num: (int, float, complex),
    Str: (str,),
    Bytes: (bytes,),
    NameConstant: (type(None), bool),
    Ellipsis: (type(...),),
}
# Value types excluded from a match even though they pass the check above
# (bool is a subclass of int).
_const_types_not = {
    Num: (bool,),
}

# Value type -> legacy visitor suffix, used by NodeVisitor.visit_Constant.
# Insertion order matters for its isinstance() fallback scan.
_const_node_type_names = {
    bool: 'NameConstant',  # should be before int
    type(None): 'NameConstant',
    int: 'Num',
    float: 'Num',
    complex: 'Num',
    str: 'Str',
    bytes: 'Bytes',
    type(...): 'Ellipsis',
}
643
class slice(AST):
    """Deprecated AST node class; kept as the base of Index and ExtSlice."""
646
class Index(slice):
    """Deprecated AST node class. Use the index value directly instead."""

    def __new__(cls, value, **kwargs):
        # Index(x) simply evaluates to x; keyword arguments are discarded.
        return value
651
class ExtSlice(slice):
    """Deprecated AST node class. Use ast.Tuple instead."""

    def __new__(cls, dims=(), **kwargs):
        # Build the equivalent Load-context Tuple, forwarding keyword
        # arguments to it.
        elements = list(dims)
        return Tuple(elements, Load(), **kwargs)
656
# If the ast module is loaded more than once, only add deprecated methods once
if not hasattr(Tuple, 'dims'):
    # Backward-compatibility alias: Tuple.dims (named after ExtSlice's
    # ``dims`` parameter) reads and writes Tuple.elts.  It will be removed
    # in the future.

    def _dims_getter(self):
        """Deprecated. Use elts instead."""
        return self.elts

    def _dims_setter(self, value):
        self.elts = value

    Tuple.dims = property(_dims_getter, _dims_setter)
670
class Suite(mod):
    """Deprecated AST node class; not produced by the Python 3 parser."""
673
class AugLoad(expr_context):
    """Deprecated AST node class; not produced by the Python 3 parser."""
676
class AugStore(expr_context):
    """Deprecated AST node class; not produced by the Python 3 parser."""
679
class Param(expr_context):
    """Deprecated AST node class; not produced by the Python 3 parser."""
682
683
684 # Large float and imaginary literals get turned into infinities in the AST.
685 # We unparse those infinities to INFSTR.
686 _INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
687
688 @_simple_enum(IntEnum)
689 class ESC[4;38;5;81m_Precedence:
690 """Precedence table that originated from python grammar."""
691
692 NAMED_EXPR = auto() # <target> := <expr1>
693 TUPLE = auto() # <expr1>, <expr2>
694 YIELD = auto() # 'yield', 'yield from'
695 TEST = auto() # 'if'-'else', 'lambda'
696 OR = auto() # 'or'
697 AND = auto() # 'and'
698 NOT = auto() # 'not'
699 CMP = auto() # '<', '>', '==', '>=', '<=', '!=',
700 # 'in', 'not in', 'is', 'is not'
701 EXPR = auto()
702 BOR = EXPR # '|'
703 BXOR = auto() # '^'
704 BAND = auto() # '&'
705 SHIFT = auto() # '<<', '>>'
706 ARITH = auto() # '+', '-'
707 TERM = auto() # '*', '@', '/', '%', '//'
708 FACTOR = auto() # unary '+', '-', '~'
709 POWER = auto() # '**'
710 AWAIT = auto() # 'await'
711 ATOM = auto()
712
713 def next(self):
714 try:
715 return self.__class__(self + 1)
716 except ValueError:
717 return self
718
719
720 _SINGLE_QUOTES = ("'", '"')
721 _MULTI_QUOTES = ('"""', "'''")
722 _ALL_QUOTES = (*_SINGLE_QUOTES, *_MULTI_QUOTES)
723
724 class ESC[4;38;5;81m_Unparser(ESC[4;38;5;149mNodeVisitor):
725 """Methods in this class recursively traverse an AST and
726 output source code for the abstract syntax; original formatting
727 is disregarded."""
728
729 def __init__(self, *, _avoid_backslashes=False):
730 self._source = []
731 self._precedences = {}
732 self._type_ignores = {}
733 self._indent = 0
734 self._avoid_backslashes = _avoid_backslashes
735 self._in_try_star = False
736
737 def interleave(self, inter, f, seq):
738 """Call f on each item in seq, calling inter() in between."""
739 seq = iter(seq)
740 try:
741 f(next(seq))
742 except StopIteration:
743 pass
744 else:
745 for x in seq:
746 inter()
747 f(x)
748
749 def items_view(self, traverser, items):
750 """Traverse and separate the given *items* with a comma and append it to
751 the buffer. If *items* is a single item sequence, a trailing comma
752 will be added."""
753 if len(items) == 1:
754 traverser(items[0])
755 self.write(",")
756 else:
757 self.interleave(lambda: self.write(", "), traverser, items)
758
759 def maybe_newline(self):
760 """Adds a newline if it isn't the start of generated source"""
761 if self._source:
762 self.write("\n")
763
764 def fill(self, text=""):
765 """Indent a piece of text and append it, according to the current
766 indentation level"""
767 self.maybe_newline()
768 self.write(" " * self._indent + text)
769
770 def write(self, *text):
771 """Add new source parts"""
772 self._source.extend(text)
773
774 @contextmanager
775 def buffered(self, buffer = None):
776 if buffer is None:
777 buffer = []
778
779 original_source = self._source
780 self._source = buffer
781 yield buffer
782 self._source = original_source
783
784 @contextmanager
785 def block(self, *, extra = None):
786 """A context manager for preparing the source for blocks. It adds
787 the character':', increases the indentation on enter and decreases
788 the indentation on exit. If *extra* is given, it will be directly
789 appended after the colon character.
790 """
791 self.write(":")
792 if extra:
793 self.write(extra)
794 self._indent += 1
795 yield
796 self._indent -= 1
797
798 @contextmanager
799 def delimit(self, start, end):
800 """A context manager for preparing the source for expressions. It adds
801 *start* to the buffer and enters, after exit it adds *end*."""
802
803 self.write(start)
804 yield
805 self.write(end)
806
807 def delimit_if(self, start, end, condition):
808 if condition:
809 return self.delimit(start, end)
810 else:
811 return nullcontext()
812
813 def require_parens(self, precedence, node):
814 """Shortcut to adding precedence related parens"""
815 return self.delimit_if("(", ")", self.get_precedence(node) > precedence)
816
817 def get_precedence(self, node):
818 return self._precedences.get(node, _Precedence.TEST)
819
820 def set_precedence(self, precedence, *nodes):
821 for node in nodes:
822 self._precedences[node] = precedence
823
824 def get_raw_docstring(self, node):
825 """If a docstring node is found in the body of the *node* parameter,
826 return that docstring node, None otherwise.
827
828 Logic mirrored from ``_PyAST_GetDocString``."""
829 if not isinstance(
830 node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)
831 ) or len(node.body) < 1:
832 return None
833 node = node.body[0]
834 if not isinstance(node, Expr):
835 return None
836 node = node.value
837 if isinstance(node, Constant) and isinstance(node.value, str):
838 return node
839
840 def get_type_comment(self, node):
841 comment = self._type_ignores.get(node.lineno) or node.type_comment
842 if comment is not None:
843 return f" # type: {comment}"
844
845 def traverse(self, node):
846 if isinstance(node, list):
847 for item in node:
848 self.traverse(item)
849 else:
850 super().visit(node)
851
852 # Note: as visit() resets the output text, do NOT rely on
853 # NodeVisitor.generic_visit to handle any nodes (as it calls back in to
854 # the subclass visit() method, which resets self._source to an empty list)
855 def visit(self, node):
856 """Outputs a source code string that, if converted back to an ast
857 (using ast.parse) will generate an AST equivalent to *node*"""
858 self._source = []
859 self.traverse(node)
860 return "".join(self._source)
861
862 def _write_docstring_and_traverse_body(self, node):
863 if (docstring := self.get_raw_docstring(node)):
864 self._write_docstring(docstring)
865 self.traverse(node.body[1:])
866 else:
867 self.traverse(node.body)
868
869 def visit_Module(self, node):
870 self._type_ignores = {
871 ignore.lineno: f"ignore{ignore.tag}"
872 for ignore in node.type_ignores
873 }
874 self._write_docstring_and_traverse_body(node)
875 self._type_ignores.clear()
876
877 def visit_FunctionType(self, node):
878 with self.delimit("(", ")"):
879 self.interleave(
880 lambda: self.write(", "), self.traverse, node.argtypes
881 )
882
883 self.write(" -> ")
884 self.traverse(node.returns)
885
886 def visit_Expr(self, node):
887 self.fill()
888 self.set_precedence(_Precedence.YIELD, node.value)
889 self.traverse(node.value)
890
891 def visit_NamedExpr(self, node):
892 with self.require_parens(_Precedence.NAMED_EXPR, node):
893 self.set_precedence(_Precedence.ATOM, node.target, node.value)
894 self.traverse(node.target)
895 self.write(" := ")
896 self.traverse(node.value)
897
898 def visit_Import(self, node):
899 self.fill("import ")
900 self.interleave(lambda: self.write(", "), self.traverse, node.names)
901
902 def visit_ImportFrom(self, node):
903 self.fill("from ")
904 self.write("." * (node.level or 0))
905 if node.module:
906 self.write(node.module)
907 self.write(" import ")
908 self.interleave(lambda: self.write(", "), self.traverse, node.names)
909
910 def visit_Assign(self, node):
911 self.fill()
912 for target in node.targets:
913 self.set_precedence(_Precedence.TUPLE, target)
914 self.traverse(target)
915 self.write(" = ")
916 self.traverse(node.value)
917 if type_comment := self.get_type_comment(node):
918 self.write(type_comment)
919
920 def visit_AugAssign(self, node):
921 self.fill()
922 self.traverse(node.target)
923 self.write(" " + self.binop[node.op.__class__.__name__] + "= ")
924 self.traverse(node.value)
925
926 def visit_AnnAssign(self, node):
927 self.fill()
928 with self.delimit_if("(", ")", not node.simple and isinstance(node.target, Name)):
929 self.traverse(node.target)
930 self.write(": ")
931 self.traverse(node.annotation)
932 if node.value:
933 self.write(" = ")
934 self.traverse(node.value)
935
936 def visit_Return(self, node):
937 self.fill("return")
938 if node.value:
939 self.write(" ")
940 self.traverse(node.value)
941
942 def visit_Pass(self, node):
943 self.fill("pass")
944
945 def visit_Break(self, node):
946 self.fill("break")
947
948 def visit_Continue(self, node):
949 self.fill("continue")
950
951 def visit_Delete(self, node):
952 self.fill("del ")
953 self.interleave(lambda: self.write(", "), self.traverse, node.targets)
954
955 def visit_Assert(self, node):
956 self.fill("assert ")
957 self.traverse(node.test)
958 if node.msg:
959 self.write(", ")
960 self.traverse(node.msg)
961
962 def visit_Global(self, node):
963 self.fill("global ")
964 self.interleave(lambda: self.write(", "), self.write, node.names)
965
966 def visit_Nonlocal(self, node):
967 self.fill("nonlocal ")
968 self.interleave(lambda: self.write(", "), self.write, node.names)
969
970 def visit_Await(self, node):
971 with self.require_parens(_Precedence.AWAIT, node):
972 self.write("await")
973 if node.value:
974 self.write(" ")
975 self.set_precedence(_Precedence.ATOM, node.value)
976 self.traverse(node.value)
977
978 def visit_Yield(self, node):
979 with self.require_parens(_Precedence.YIELD, node):
980 self.write("yield")
981 if node.value:
982 self.write(" ")
983 self.set_precedence(_Precedence.ATOM, node.value)
984 self.traverse(node.value)
985
986 def visit_YieldFrom(self, node):
987 with self.require_parens(_Precedence.YIELD, node):
988 self.write("yield from ")
989 if not node.value:
990 raise ValueError("Node can't be used without a value attribute.")
991 self.set_precedence(_Precedence.ATOM, node.value)
992 self.traverse(node.value)
993
994 def visit_Raise(self, node):
995 self.fill("raise")
996 if not node.exc:
997 if node.cause:
998 raise ValueError(f"Node can't use cause without an exception.")
999 return
1000 self.write(" ")
1001 self.traverse(node.exc)
1002 if node.cause:
1003 self.write(" from ")
1004 self.traverse(node.cause)
1005
1006 def do_visit_try(self, node):
1007 self.fill("try")
1008 with self.block():
1009 self.traverse(node.body)
1010 for ex in node.handlers:
1011 self.traverse(ex)
1012 if node.orelse:
1013 self.fill("else")
1014 with self.block():
1015 self.traverse(node.orelse)
1016 if node.finalbody:
1017 self.fill("finally")
1018 with self.block():
1019 self.traverse(node.finalbody)
1020
1021 def visit_Try(self, node):
1022 prev_in_try_star = self._in_try_star
1023 try:
1024 self._in_try_star = False
1025 self.do_visit_try(node)
1026 finally:
1027 self._in_try_star = prev_in_try_star
1028
1029 def visit_TryStar(self, node):
1030 prev_in_try_star = self._in_try_star
1031 try:
1032 self._in_try_star = True
1033 self.do_visit_try(node)
1034 finally:
1035 self._in_try_star = prev_in_try_star
1036
1037 def visit_ExceptHandler(self, node):
1038 self.fill("except*" if self._in_try_star else "except")
1039 if node.type:
1040 self.write(" ")
1041 self.traverse(node.type)
1042 if node.name:
1043 self.write(" as ")
1044 self.write(node.name)
1045 with self.block():
1046 self.traverse(node.body)
1047
1048 def visit_ClassDef(self, node):
1049 self.maybe_newline()
1050 for deco in node.decorator_list:
1051 self.fill("@")
1052 self.traverse(deco)
1053 self.fill("class " + node.name)
1054 if hasattr(node, "type_params"):
1055 self._type_params_helper(node.type_params)
1056 with self.delimit_if("(", ")", condition = node.bases or node.keywords):
1057 comma = False
1058 for e in node.bases:
1059 if comma:
1060 self.write(", ")
1061 else:
1062 comma = True
1063 self.traverse(e)
1064 for e in node.keywords:
1065 if comma:
1066 self.write(", ")
1067 else:
1068 comma = True
1069 self.traverse(e)
1070
1071 with self.block():
1072 self._write_docstring_and_traverse_body(node)
1073
1074 def visit_FunctionDef(self, node):
1075 self._function_helper(node, "def")
1076
1077 def visit_AsyncFunctionDef(self, node):
1078 self._function_helper(node, "async def")
1079
1080 def _function_helper(self, node, fill_suffix):
1081 self.maybe_newline()
1082 for deco in node.decorator_list:
1083 self.fill("@")
1084 self.traverse(deco)
1085 def_str = fill_suffix + " " + node.name
1086 self.fill(def_str)
1087 if hasattr(node, "type_params"):
1088 self._type_params_helper(node.type_params)
1089 with self.delimit("(", ")"):
1090 self.traverse(node.args)
1091 if node.returns:
1092 self.write(" -> ")
1093 self.traverse(node.returns)
1094 with self.block(extra=self.get_type_comment(node)):
1095 self._write_docstring_and_traverse_body(node)
1096
1097 def _type_params_helper(self, type_params):
1098 if type_params is not None and len(type_params) > 0:
1099 with self.delimit("[", "]"):
1100 self.interleave(lambda: self.write(", "), self.traverse, type_params)
1101
1102 def visit_TypeVar(self, node):
1103 self.write(node.name)
1104 if node.bound:
1105 self.write(": ")
1106 self.traverse(node.bound)
1107
1108 def visit_TypeVarTuple(self, node):
1109 self.write("*" + node.name)
1110
1111 def visit_ParamSpec(self, node):
1112 self.write("**" + node.name)
1113
1114 def visit_TypeAlias(self, node):
1115 self.fill("type ")
1116 self.traverse(node.name)
1117 self._type_params_helper(node.type_params)
1118 self.write(" = ")
1119 self.traverse(node.value)
1120
1121 def visit_For(self, node):
1122 self._for_helper("for ", node)
1123
1124 def visit_AsyncFor(self, node):
1125 self._for_helper("async for ", node)
1126
1127 def _for_helper(self, fill, node):
1128 self.fill(fill)
1129 self.set_precedence(_Precedence.TUPLE, node.target)
1130 self.traverse(node.target)
1131 self.write(" in ")
1132 self.traverse(node.iter)
1133 with self.block(extra=self.get_type_comment(node)):
1134 self.traverse(node.body)
1135 if node.orelse:
1136 self.fill("else")
1137 with self.block():
1138 self.traverse(node.orelse)
1139
1140 def visit_If(self, node):
1141 self.fill("if ")
1142 self.traverse(node.test)
1143 with self.block():
1144 self.traverse(node.body)
1145 # collapse nested ifs into equivalent elifs.
1146 while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If):
1147 node = node.orelse[0]
1148 self.fill("elif ")
1149 self.traverse(node.test)
1150 with self.block():
1151 self.traverse(node.body)
1152 # final else
1153 if node.orelse:
1154 self.fill("else")
1155 with self.block():
1156 self.traverse(node.orelse)
1157
1158 def visit_While(self, node):
1159 self.fill("while ")
1160 self.traverse(node.test)
1161 with self.block():
1162 self.traverse(node.body)
1163 if node.orelse:
1164 self.fill("else")
1165 with self.block():
1166 self.traverse(node.orelse)
1167
1168 def visit_With(self, node):
1169 self.fill("with ")
1170 self.interleave(lambda: self.write(", "), self.traverse, node.items)
1171 with self.block(extra=self.get_type_comment(node)):
1172 self.traverse(node.body)
1173
1174 def visit_AsyncWith(self, node):
1175 self.fill("async with ")
1176 self.interleave(lambda: self.write(", "), self.traverse, node.items)
1177 with self.block(extra=self.get_type_comment(node)):
1178 self.traverse(node.body)
1179
    def _str_literal_helper(
        self, string, *, quote_types=_ALL_QUOTES, escape_special_whitespace=False
    ):
        """Helper for writing string literals, minimizing escapes.
        Returns the tuple (string literal to write, possible quote types).

        string -- the str value to render (the result has no surrounding
            quotes; callers wrap it in the first returned quote type).
        quote_types -- candidate delimiters, in order of preference.
        escape_special_whitespace -- when true, newlines and tabs are escaped
            like other non-printable characters; otherwise they stay raw.

        When no candidate can delimit the escaped text, the body is taken
        from repr() of the original string and a one-element quote list is
        returned.
        """
        def escape_char(c):
            # \n and \t are non-printable, but we only escape them if
            # escape_special_whitespace is True
            if not escape_special_whitespace and c in "\n\t":
                return c
            # Always escape backslashes and other non-printable characters
            if c == "\\" or not c.isprintable():
                return c.encode("unicode_escape").decode("ascii")
            return c

        escaped_string = "".join(map(escape_char, string))
        possible_quotes = quote_types
        # A raw newline can only appear inside a multi-line (_MULTI_QUOTES)
        # delimiter.
        if "\n" in escaped_string:
            possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES]
        # A delimiter that occurs in the body would end the literal early.
        possible_quotes = [q for q in possible_quotes if q not in escaped_string]
        if not possible_quotes:
            # If there aren't any possible_quotes, fallback to using repr
            # on the original string. Try to use a quote from quote_types,
            # e.g., so that we use triple quotes for docstrings.
            string = repr(string)
            quote = next((q for q in quote_types if string[0] in q), string[0])
            return string[1:-1], [quote]
        if escaped_string:
            # Sort so that we prefer '''"''' over """\""""
            # (stable sort: quotes starting with the body's last character
            # move to the end, otherwise preference order is kept).
            possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1])
            # If we're using triple quotes and we'd need to escape a final
            # quote, escape it
            if possible_quotes[0][0] == escaped_string[-1]:
                # A one-character quote equal to the last character would
                # already have been filtered out above as a substring.
                assert len(possible_quotes[0]) == 3
                escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1]
        return escaped_string, possible_quotes
1217
1218 def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES):
1219 """Write string literal value with a best effort attempt to avoid backslashes."""
1220 string, quote_types = self._str_literal_helper(string, quote_types=quote_types)
1221 quote_type = quote_types[0]
1222 self.write(f"{quote_type}{string}{quote_type}")
1223
    def visit_JoinedStr(self, node):
        """Emit an f-string, choosing one quote type that suits every part.

        Each value is first rendered to text in a side buffer. Literal
        (Constant) parts are escaped, and each one narrows the set of
        usable quote types. Non-constant parts containing a newline restrict
        the choice to multi-line quotes. If no common quote type survives,
        every literal part is re-rendered via repr() and the whole f-string
        is written with triple single quotes.
        """
        self.write("f")

        # (rendered text, is-literal-part) for every value of the f-string.
        fstring_parts = []
        for value in node.values:
            with self.buffered() as buffer:
                self._write_fstring_inner(value)
            fstring_parts.append(
                ("".join(buffer), isinstance(value, Constant))
            )

        new_fstring_parts = []
        quote_types = list(_ALL_QUOTES)
        fallback_to_repr = False
        for value, is_constant in fstring_parts:
            if is_constant:
                value, new_quote_types = self._str_literal_helper(
                    value,
                    quote_types=quote_types,
                    escape_special_whitespace=True,
                )
                # The helper's repr() fallback may return a quote type that
                # earlier parts already ruled out; then no common one exists.
                if set(new_quote_types).isdisjoint(quote_types):
                    fallback_to_repr = True
                    break
                quote_types = new_quote_types
            elif "\n" in value:
                quote_types = [q for q in quote_types if q in _MULTI_QUOTES]
                assert quote_types
            new_fstring_parts.append(value)

        if fallback_to_repr:
            # If we weren't able to find a quote type that works for all parts
            # of the JoinedStr, fallback to using repr and triple single quotes.
            quote_types = ["'''"]
            new_fstring_parts.clear()
            for value, is_constant in fstring_parts:
                if is_constant:
                    value = repr('"' + value)  # force repr to use single quotes
                    expected_prefix = "'\""
                    assert value.startswith(expected_prefix), repr(value)
                    # Strip the opening quote + injected '"' and the closing quote.
                    value = value[len(expected_prefix):-1]
                new_fstring_parts.append(value)

        value = "".join(new_fstring_parts)
        quote_type = quote_types[0]
        self.write(f"{quote_type}{value}{quote_type}")
1270
1271 def _write_fstring_inner(self, node):
1272 if isinstance(node, JoinedStr):
1273 # for both the f-string itself, and format_spec
1274 for value in node.values:
1275 self._write_fstring_inner(value)
1276 elif isinstance(node, Constant) and isinstance(node.value, str):
1277 value = node.value.replace("{", "{{").replace("}", "}}")
1278 self.write(value)
1279 elif isinstance(node, FormattedValue):
1280 self.visit_FormattedValue(node)
1281 else:
1282 raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
1283
1284 def visit_FormattedValue(self, node):
1285 def unparse_inner(inner):
1286 unparser = type(self)()
1287 unparser.set_precedence(_Precedence.TEST.next(), inner)
1288 return unparser.visit(inner)
1289
1290 with self.delimit("{", "}"):
1291 expr = unparse_inner(node.value)
1292 if expr.startswith("{"):
1293 # Separate pair of opening brackets as "{ {"
1294 self.write(" ")
1295 self.write(expr)
1296 if node.conversion != -1:
1297 self.write(f"!{chr(node.conversion)}")
1298 if node.format_spec:
1299 self.write(":")
1300 self._write_fstring_inner(node.format_spec)
1301
1302 def visit_Name(self, node):
1303 self.write(node.id)
1304
1305 def _write_docstring(self, node):
1306 self.fill()
1307 if node.kind == "u":
1308 self.write("u")
1309 self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES)
1310
1311 def _write_constant(self, value):
1312 if isinstance(value, (float, complex)):
1313 # Substitute overflowing decimal literal for AST infinities,
1314 # and inf - inf for NaNs.
1315 self.write(
1316 repr(value)
1317 .replace("inf", _INFSTR)
1318 .replace("nan", f"({_INFSTR}-{_INFSTR})")
1319 )
1320 elif self._avoid_backslashes and isinstance(value, str):
1321 self._write_str_avoiding_backslashes(value)
1322 else:
1323 self.write(repr(value))
1324
1325 def visit_Constant(self, node):
1326 value = node.value
1327 if isinstance(value, tuple):
1328 with self.delimit("(", ")"):
1329 self.items_view(self._write_constant, value)
1330 elif value is ...:
1331 self.write("...")
1332 else:
1333 if node.kind == "u":
1334 self.write("u")
1335 self._write_constant(node.value)
1336
1337 def visit_List(self, node):
1338 with self.delimit("[", "]"):
1339 self.interleave(lambda: self.write(", "), self.traverse, node.elts)
1340
1341 def visit_ListComp(self, node):
1342 with self.delimit("[", "]"):
1343 self.traverse(node.elt)
1344 for gen in node.generators:
1345 self.traverse(gen)
1346
1347 def visit_GeneratorExp(self, node):
1348 with self.delimit("(", ")"):
1349 self.traverse(node.elt)
1350 for gen in node.generators:
1351 self.traverse(gen)
1352
1353 def visit_SetComp(self, node):
1354 with self.delimit("{", "}"):
1355 self.traverse(node.elt)
1356 for gen in node.generators:
1357 self.traverse(gen)
1358
1359 def visit_DictComp(self, node):
1360 with self.delimit("{", "}"):
1361 self.traverse(node.key)
1362 self.write(": ")
1363 self.traverse(node.value)
1364 for gen in node.generators:
1365 self.traverse(gen)
1366
1367 def visit_comprehension(self, node):
1368 if node.is_async:
1369 self.write(" async for ")
1370 else:
1371 self.write(" for ")
1372 self.set_precedence(_Precedence.TUPLE, node.target)
1373 self.traverse(node.target)
1374 self.write(" in ")
1375 self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs)
1376 self.traverse(node.iter)
1377 for if_clause in node.ifs:
1378 self.write(" if ")
1379 self.traverse(if_clause)
1380
1381 def visit_IfExp(self, node):
1382 with self.require_parens(_Precedence.TEST, node):
1383 self.set_precedence(_Precedence.TEST.next(), node.body, node.test)
1384 self.traverse(node.body)
1385 self.write(" if ")
1386 self.traverse(node.test)
1387 self.write(" else ")
1388 self.set_precedence(_Precedence.TEST, node.orelse)
1389 self.traverse(node.orelse)
1390
1391 def visit_Set(self, node):
1392 if node.elts:
1393 with self.delimit("{", "}"):
1394 self.interleave(lambda: self.write(", "), self.traverse, node.elts)
1395 else:
1396 # `{}` would be interpreted as a dictionary literal, and
1397 # `set` might be shadowed. Thus:
1398 self.write('{*()}')
1399
1400 def visit_Dict(self, node):
1401 def write_key_value_pair(k, v):
1402 self.traverse(k)
1403 self.write(": ")
1404 self.traverse(v)
1405
1406 def write_item(item):
1407 k, v = item
1408 if k is None:
1409 # for dictionary unpacking operator in dicts {**{'y': 2}}
1410 # see PEP 448 for details
1411 self.write("**")
1412 self.set_precedence(_Precedence.EXPR, v)
1413 self.traverse(v)
1414 else:
1415 write_key_value_pair(k, v)
1416
1417 with self.delimit("{", "}"):
1418 self.interleave(
1419 lambda: self.write(", "), write_item, zip(node.keys, node.values)
1420 )
1421
1422 def visit_Tuple(self, node):
1423 with self.delimit_if(
1424 "(",
1425 ")",
1426 len(node.elts) == 0 or self.get_precedence(node) > _Precedence.TUPLE
1427 ):
1428 self.items_view(self.traverse, node.elts)
1429
1430 unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
1431 unop_precedence = {
1432 "not": _Precedence.NOT,
1433 "~": _Precedence.FACTOR,
1434 "+": _Precedence.FACTOR,
1435 "-": _Precedence.FACTOR,
1436 }
1437
1438 def visit_UnaryOp(self, node):
1439 operator = self.unop[node.op.__class__.__name__]
1440 operator_precedence = self.unop_precedence[operator]
1441 with self.require_parens(operator_precedence, node):
1442 self.write(operator)
1443 # factor prefixes (+, -, ~) shouldn't be separated
1444 # from the value they belong, (e.g: +1 instead of + 1)
1445 if operator_precedence is not _Precedence.FACTOR:
1446 self.write(" ")
1447 self.set_precedence(operator_precedence, node.operand)
1448 self.traverse(node.operand)
1449
1450 binop = {
1451 "Add": "+",
1452 "Sub": "-",
1453 "Mult": "*",
1454 "MatMult": "@",
1455 "Div": "/",
1456 "Mod": "%",
1457 "LShift": "<<",
1458 "RShift": ">>",
1459 "BitOr": "|",
1460 "BitXor": "^",
1461 "BitAnd": "&",
1462 "FloorDiv": "//",
1463 "Pow": "**",
1464 }
1465
1466 binop_precedence = {
1467 "+": _Precedence.ARITH,
1468 "-": _Precedence.ARITH,
1469 "*": _Precedence.TERM,
1470 "@": _Precedence.TERM,
1471 "/": _Precedence.TERM,
1472 "%": _Precedence.TERM,
1473 "<<": _Precedence.SHIFT,
1474 ">>": _Precedence.SHIFT,
1475 "|": _Precedence.BOR,
1476 "^": _Precedence.BXOR,
1477 "&": _Precedence.BAND,
1478 "//": _Precedence.TERM,
1479 "**": _Precedence.POWER,
1480 }
1481
1482 binop_rassoc = frozenset(("**",))
1483 def visit_BinOp(self, node):
1484 operator = self.binop[node.op.__class__.__name__]
1485 operator_precedence = self.binop_precedence[operator]
1486 with self.require_parens(operator_precedence, node):
1487 if operator in self.binop_rassoc:
1488 left_precedence = operator_precedence.next()
1489 right_precedence = operator_precedence
1490 else:
1491 left_precedence = operator_precedence
1492 right_precedence = operator_precedence.next()
1493
1494 self.set_precedence(left_precedence, node.left)
1495 self.traverse(node.left)
1496 self.write(f" {operator} ")
1497 self.set_precedence(right_precedence, node.right)
1498 self.traverse(node.right)
1499
1500 cmpops = {
1501 "Eq": "==",
1502 "NotEq": "!=",
1503 "Lt": "<",
1504 "LtE": "<=",
1505 "Gt": ">",
1506 "GtE": ">=",
1507 "Is": "is",
1508 "IsNot": "is not",
1509 "In": "in",
1510 "NotIn": "not in",
1511 }
1512
1513 def visit_Compare(self, node):
1514 with self.require_parens(_Precedence.CMP, node):
1515 self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators)
1516 self.traverse(node.left)
1517 for o, e in zip(node.ops, node.comparators):
1518 self.write(" " + self.cmpops[o.__class__.__name__] + " ")
1519 self.traverse(e)
1520
1521 boolops = {"And": "and", "Or": "or"}
1522 boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR}
1523
1524 def visit_BoolOp(self, node):
1525 operator = self.boolops[node.op.__class__.__name__]
1526 operator_precedence = self.boolop_precedence[operator]
1527
1528 def increasing_level_traverse(node):
1529 nonlocal operator_precedence
1530 operator_precedence = operator_precedence.next()
1531 self.set_precedence(operator_precedence, node)
1532 self.traverse(node)
1533
1534 with self.require_parens(operator_precedence, node):
1535 s = f" {operator} "
1536 self.interleave(lambda: self.write(s), increasing_level_traverse, node.values)
1537
1538 def visit_Attribute(self, node):
1539 self.set_precedence(_Precedence.ATOM, node.value)
1540 self.traverse(node.value)
1541 # Special case: 3.__abs__() is a syntax error, so if node.value
1542 # is an integer literal then we need to either parenthesize
1543 # it or add an extra space to get 3 .__abs__().
1544 if isinstance(node.value, Constant) and isinstance(node.value.value, int):
1545 self.write(" ")
1546 self.write(".")
1547 self.write(node.attr)
1548
1549 def visit_Call(self, node):
1550 self.set_precedence(_Precedence.ATOM, node.func)
1551 self.traverse(node.func)
1552 with self.delimit("(", ")"):
1553 comma = False
1554 for e in node.args:
1555 if comma:
1556 self.write(", ")
1557 else:
1558 comma = True
1559 self.traverse(e)
1560 for e in node.keywords:
1561 if comma:
1562 self.write(", ")
1563 else:
1564 comma = True
1565 self.traverse(e)
1566
1567 def visit_Subscript(self, node):
1568 def is_non_empty_tuple(slice_value):
1569 return (
1570 isinstance(slice_value, Tuple)
1571 and slice_value.elts
1572 )
1573
1574 self.set_precedence(_Precedence.ATOM, node.value)
1575 self.traverse(node.value)
1576 with self.delimit("[", "]"):
1577 if is_non_empty_tuple(node.slice):
1578 # parentheses can be omitted if the tuple isn't empty
1579 self.items_view(self.traverse, node.slice.elts)
1580 else:
1581 self.traverse(node.slice)
1582
1583 def visit_Starred(self, node):
1584 self.write("*")
1585 self.set_precedence(_Precedence.EXPR, node.value)
1586 self.traverse(node.value)
1587
1588 def visit_Ellipsis(self, node):
1589 self.write("...")
1590
1591 def visit_Slice(self, node):
1592 if node.lower:
1593 self.traverse(node.lower)
1594 self.write(":")
1595 if node.upper:
1596 self.traverse(node.upper)
1597 if node.step:
1598 self.write(":")
1599 self.traverse(node.step)
1600
1601 def visit_Match(self, node):
1602 self.fill("match ")
1603 self.traverse(node.subject)
1604 with self.block():
1605 for case in node.cases:
1606 self.traverse(case)
1607
1608 def visit_arg(self, node):
1609 self.write(node.arg)
1610 if node.annotation:
1611 self.write(": ")
1612 self.traverse(node.annotation)
1613
    def visit_arguments(self, node):
        """Emit a parameter list without the surrounding parentheses.

        Order: positional-only and regular parameters (with defaults), the
        ``/`` marker, ``*vararg`` or a bare ``*``, keyword-only parameters,
        then ``**kwarg``. ``first`` records whether a ", " separator is
        still unnecessary.
        """
        first = True
        # normal arguments
        all_args = node.posonlyargs + node.args
        # node.defaults belongs to the *last* len(node.defaults) parameters,
        # so pad the front with None ("no default").
        defaults = [None] * (len(all_args) - len(node.defaults)) + node.defaults
        for index, elements in enumerate(zip(all_args, defaults), 1):
            a, d = elements
            if first:
                first = False
            else:
                self.write(", ")
            self.traverse(a)
            if d:
                self.write("=")
                self.traverse(d)
            # The "/" marker follows the last positional-only parameter.
            if index == len(node.posonlyargs):
                self.write(", /")

        # varargs, or bare '*' if no varargs but keyword-only arguments present
        if node.vararg or node.kwonlyargs:
            if first:
                first = False
            else:
                self.write(", ")
            self.write("*")
            if node.vararg:
                self.write(node.vararg.arg)
                if node.vararg.annotation:
                    self.write(": ")
                    self.traverse(node.vararg.annotation)

        # keyword-only arguments
        # (a "*" was always written above, so each needs a leading ", ";
        # kw_defaults holds None for parameters without a default)
        if node.kwonlyargs:
            for a, d in zip(node.kwonlyargs, node.kw_defaults):
                self.write(", ")
                self.traverse(a)
                if d:
                    self.write("=")
                    self.traverse(d)

        # kwargs
        if node.kwarg:
            if first:
                first = False
            else:
                self.write(", ")
            self.write("**" + node.kwarg.arg)
            if node.kwarg.annotation:
                self.write(": ")
                self.traverse(node.kwarg.annotation)
1664
1665 def visit_keyword(self, node):
1666 if node.arg is None:
1667 self.write("**")
1668 else:
1669 self.write(node.arg)
1670 self.write("=")
1671 self.traverse(node.value)
1672
1673 def visit_Lambda(self, node):
1674 with self.require_parens(_Precedence.TEST, node):
1675 self.write("lambda")
1676 with self.buffered() as buffer:
1677 self.traverse(node.args)
1678 if buffer:
1679 self.write(" ", *buffer)
1680 self.write(": ")
1681 self.set_precedence(_Precedence.TEST, node.body)
1682 self.traverse(node.body)
1683
1684 def visit_alias(self, node):
1685 self.write(node.name)
1686 if node.asname:
1687 self.write(" as " + node.asname)
1688
1689 def visit_withitem(self, node):
1690 self.traverse(node.context_expr)
1691 if node.optional_vars:
1692 self.write(" as ")
1693 self.traverse(node.optional_vars)
1694
1695 def visit_match_case(self, node):
1696 self.fill("case ")
1697 self.traverse(node.pattern)
1698 if node.guard:
1699 self.write(" if ")
1700 self.traverse(node.guard)
1701 with self.block():
1702 self.traverse(node.body)
1703
1704 def visit_MatchValue(self, node):
1705 self.traverse(node.value)
1706
1707 def visit_MatchSingleton(self, node):
1708 self._write_constant(node.value)
1709
1710 def visit_MatchSequence(self, node):
1711 with self.delimit("[", "]"):
1712 self.interleave(
1713 lambda: self.write(", "), self.traverse, node.patterns
1714 )
1715
1716 def visit_MatchStar(self, node):
1717 name = node.name
1718 if name is None:
1719 name = "_"
1720 self.write(f"*{name}")
1721
1722 def visit_MatchMapping(self, node):
1723 def write_key_pattern_pair(pair):
1724 k, p = pair
1725 self.traverse(k)
1726 self.write(": ")
1727 self.traverse(p)
1728
1729 with self.delimit("{", "}"):
1730 keys = node.keys
1731 self.interleave(
1732 lambda: self.write(", "),
1733 write_key_pattern_pair,
1734 zip(keys, node.patterns, strict=True),
1735 )
1736 rest = node.rest
1737 if rest is not None:
1738 if keys:
1739 self.write(", ")
1740 self.write(f"**{rest}")
1741
1742 def visit_MatchClass(self, node):
1743 self.set_precedence(_Precedence.ATOM, node.cls)
1744 self.traverse(node.cls)
1745 with self.delimit("(", ")"):
1746 patterns = node.patterns
1747 self.interleave(
1748 lambda: self.write(", "), self.traverse, patterns
1749 )
1750 attrs = node.kwd_attrs
1751 if attrs:
1752 def write_attr_pattern(pair):
1753 attr, pattern = pair
1754 self.write(f"{attr}=")
1755 self.traverse(pattern)
1756
1757 if patterns:
1758 self.write(", ")
1759 self.interleave(
1760 lambda: self.write(", "),
1761 write_attr_pattern,
1762 zip(attrs, node.kwd_patterns, strict=True),
1763 )
1764
1765 def visit_MatchAs(self, node):
1766 name = node.name
1767 pattern = node.pattern
1768 if name is None:
1769 self.write("_")
1770 elif pattern is None:
1771 self.write(node.name)
1772 else:
1773 with self.require_parens(_Precedence.TEST, node):
1774 self.set_precedence(_Precedence.BOR, node.pattern)
1775 self.traverse(node.pattern)
1776 self.write(f" as {node.name}")
1777
1778 def visit_MatchOr(self, node):
1779 with self.require_parens(_Precedence.BOR, node):
1780 self.set_precedence(_Precedence.BOR.next(), *node.patterns)
1781 self.interleave(lambda: self.write(" | "), self.traverse, node.patterns)
1782
def unparse(ast_obj):
    """Return source text for *ast_obj*, using a new _Unparser per call."""
    return _Unparser().visit(ast_obj)
1786
1787
# Remove the deprecated names from the module namespace and keep them here;
# the module-level __getattr__ hands them back on access with a deprecation
# warning.  globals().pop() raises KeyError if a name is not already defined
# at this point.  A comprehension (rather than a for loop) keeps the loop
# variable out of the module's globals.
_deprecated_globals = {
    name: globals().pop(name)
    for name in ('Num', 'Str', 'Bytes', 'NameConstant', 'Ellipsis')
}
1792
def __getattr__(name):
    """Module-level attribute hook (PEP 562) for the deprecated names.

    On first access, a name stored in _deprecated_globals is re-bound in the
    module globals (so later lookups bypass this hook). A deprecation
    warning is issued and the object is returned. Any other name raises
    AttributeError.
    """
    if name not in _deprecated_globals:
        raise AttributeError(f"module 'ast' has no attribute '{name}'")
    value = _deprecated_globals[name]
    globals()[name] = value
    import warnings
    warnings._deprecated(
        f"ast.{name}", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14)
    )
    return value
1802
1803
def main():
    """Entry point for ``python -m ast``: parse source and print its AST.

    The source is read as bytes from the *infile* argument, or from standard
    input when it is ``-`` (the default). It is parsed with parse() in the
    requested --mode and printed with dump().
    """
    import argparse

    parser = argparse.ArgumentParser(prog='python -m ast')
    parser.add_argument('infile', nargs='?', default='-',
                        help='the file to parse; defaults to stdin')
    parser.add_argument('-m', '--mode', default='exec',
                        choices=('exec', 'single', 'eval', 'func_type'),
                        help='specify what kind of code must be parsed')
    parser.add_argument('--no-type-comments', default=True, action='store_false',
                        help="don't add information about type comments")
    parser.add_argument('-a', '--include-attributes', action='store_true',
                        help='include attributes such as line numbers and '
                             'column offsets')
    parser.add_argument('-i', '--indent', type=int, default=3,
                        help='indentation of nodes (number of spaces)')
    args = parser.parse_args()

    # Open the input only after every argument has been validated.
    # argparse.FileType would open it inside parse_args(), leaving the handle
    # open if a later argument is rejected, and FileType is deprecated as of
    # Python 3.14.  '-' reads sys.stdin.buffer under the '<stdin>' filename.
    if args.infile == '-':
        name = '<stdin>'
        source = sys.stdin.buffer.read()
    else:
        name = args.infile
        try:
            infile = open(args.infile, 'rb')
        except OSError as e:
            # Report as a usage error (exit status 2), as FileType did.
            parser.error(f"can't open '{args.infile}': {e}")
        with infile:
            source = infile.read()
    # --no-type-comments stores False into args.no_type_comments, so type
    # comments are requested unless that option was given.
    tree = parse(source, name, args.mode, type_comments=args.no_type_comments)
    print(dump(tree, include_attributes=args.include_attributes, indent=args.indent))
1827
if __name__ == "__main__":
    # Only when executed as a script (python -m ast), not on import.
    main()