1 """
2 ast
3 ~~~
4
5 The `ast` module helps Python applications to process trees of the Python
6 abstract syntax grammar. The abstract syntax itself might change with
7 each Python release; this module helps to find out programmatically what
8 the current grammar looks like and allows modifications of it.
9
10 An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
11 a flag to the `compile()` builtin function or by using the `parse()`
12 function from this module. The result will be a tree of objects whose
13 classes all inherit from `ast.AST`.
14
15 A modified abstract syntax tree can be compiled into a Python code object
16 using the built-in `compile()` function.
17
Additionally, various helper functions are provided that make working with
the trees simpler. The main intention of the helper functions, and of this
module in general, is to provide an easy-to-use interface for libraries
that work closely with the Python syntax (template engines, for example).
22
23
24 :copyright: Copyright 2008 by Armin Ronacher.
25 :license: Python License.
26 """
27 import sys
28 from _ast import *
29 from contextlib import contextmanager, nullcontext
30 from enum import IntEnum, auto, _simple_enum
31
32
def parse(source, filename='<unknown>', mode='exec', *,
          type_comments=False, feature_version=None):
    """
    Parse the source into an AST node.
    Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
    Pass type_comments=True to get back type comments where the syntax allows.

    *feature_version* selects the grammar: None (the default) means the
    current grammar, an int is the minor version for 3.x, and a
    ``(major, minor)`` tuple is also accepted.

    Raises ValueError when a tuple *feature_version* does not have exactly
    two items or names a major version other than 3.  Errors from compile()
    (e.g. SyntaxError) propagate unchanged.
    """
    flags = PyCF_ONLY_AST
    if type_comments:
        flags |= PyCF_TYPE_COMMENTS
    if isinstance(feature_version, tuple):
        major, minor = feature_version  # Should be a 2-tuple.
        # Validate explicitly: an ``assert`` disappears under ``python -O``
        # and would then silently treat e.g. (2, 7) as 3.7.
        if major != 3:
            raise ValueError(f"Unsupported major version: {major}")
        feature_version = minor
    elif feature_version is None:
        feature_version = -1
    # Else it should be an int giving the minor version for 3.x.
    return compile(source, filename, mode, flags,
                   _feature_version=feature_version)
52
53
def literal_eval(node_or_string):
    """
    Safely evaluate an expression node, or a string holding a single Python
    expression, that is built only from literals.

    Accepted structures are strings, bytes, numbers, tuples, lists, dicts,
    sets, booleans and None, plus ``set()`` for an empty set and
    ``real +/- complex`` sums so complex literals can be written out.
    Anything else raises ValueError ("malformed node or string ...").

    Caution: A complex expression can overflow the C stack and cause a crash.
    """
    if isinstance(node_or_string, str):
        node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')
    if isinstance(node_or_string, Expression):
        node_or_string = node_or_string.body

    def _malformed(node):
        # Always raises; mentions the line number when the node carries one.
        message = "malformed node or string"
        line = getattr(node, 'lineno', None)
        if line:
            message += f' on line {line}'
        raise ValueError(message + f': {node!r}')

    def _number(node):
        # A bare int/float/complex Constant (bool and other types rejected).
        is_number = (isinstance(node, Constant)
                     and type(node.value) in (int, float, complex))
        if not is_number:
            _malformed(node)
        return node.value

    def _signed_number(node):
        # A number, optionally preceded by a single unary + or -.
        if not (isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub))):
            return _number(node)
        magnitude = _number(node.operand)
        if isinstance(node.op, USub):
            return - magnitude
        return + magnitude

    def _to_value(node):
        if isinstance(node, Constant):
            return node.value
        if isinstance(node, Tuple):
            return tuple(map(_to_value, node.elts))
        if isinstance(node, List):
            return list(map(_to_value, node.elts))
        if isinstance(node, Set):
            return set(map(_to_value, node.elts))
        is_empty_set_call = (
            isinstance(node, Call)
            and isinstance(node.func, Name)
            and node.func.id == 'set'
            and node.args == node.keywords == []
        )
        if is_empty_set_call:
            return set()
        if isinstance(node, Dict):
            if len(node.keys) != len(node.values):
                _malformed(node)
            return dict(zip(map(_to_value, node.keys),
                            map(_to_value, node.values)))
        if isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
            # Only "real +/- complex" is allowed, e.g. 1+2j or -1-2j.
            lhs = _signed_number(node.left)
            rhs = _number(node.right)
            if isinstance(lhs, (int, float)) and isinstance(rhs, complex):
                return lhs + rhs if isinstance(node.op, Add) else lhs - rhs
        return _signed_number(node)

    return _to_value(node_or_string)
111
112
def dump(node, annotate_fields=True, include_attributes=False, *, indent=None):
    """
    Return a formatted string dump of the tree rooted at *node*.

    Intended mainly for debugging.  With *annotate_fields* true (the default)
    fields are shown as ``name=value``; when false, field names are omitted
    wherever that is unambiguous, giving a more compact result.  Attributes
    such as line numbers and column offsets appear only when
    *include_attributes* is true.  If *indent* is a non-negative integer or a
    string, the tree is pretty-printed with that indent per nesting level;
    None (the default) gives a single-line representation.

    Raises TypeError if *node* is not an AST instance.
    """
    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    if indent is not None and not isinstance(indent, str):
        indent = ' ' * indent

    def _render(item, depth=0):
        # Returns (text, is_simple); simple results may share a single line.
        if indent is None:
            opener, joiner = '', ', '
        else:
            depth += 1
            opener = '\n' + indent * depth
            joiner = ',\n' + indent * depth

        if isinstance(item, list):
            if not item:
                return '[]', True
            body = joiner.join(_render(child, depth)[0] for child in item)
            return '[%s%s]' % (opener, body), False

        if not isinstance(item, AST):
            return repr(item), True

        node_cls = type(item)
        parts = []
        all_simple = True
        # Once a field is skipped, later positional values would be
        # ambiguous, so switch to name=value form from then on.
        use_names = annotate_fields
        for field in item._fields:
            try:
                raw = getattr(item, field)
            except AttributeError:
                use_names = True
                continue
            if raw is None and getattr(node_cls, field, ...) is None:
                use_names = True
                continue
            text, simple = _render(raw, depth)
            all_simple = all_simple and simple
            parts.append('%s=%s' % (field, text) if use_names else text)
        if include_attributes and item._attributes:
            for attr in item._attributes:
                try:
                    raw = getattr(item, attr)
                except AttributeError:
                    continue
                if raw is None and getattr(node_cls, attr, ...) is None:
                    continue
                text, simple = _render(raw, depth)
                all_simple = all_simple and simple
                parts.append('%s=%s' % (attr, text))
        name = item.__class__.__name__
        if all_simple and len(parts) <= 3:
            return '%s(%s)' % (name, ', '.join(parts)), not parts
        return '%s(%s%s)' % (name, opener, joiner.join(parts)), False

    return _render(node)[0]
178
179
def copy_location(new_node, old_node):
    """
    Copy the source location of *old_node* onto *new_node*; return *new_node*.

    The location attributes are `lineno`, `col_offset`, `end_lineno` and
    `end_col_offset`.  Each is copied only when both node classes list it in
    ``_attributes``.  Start positions that are missing or None are skipped,
    whereas the optional ``end_*`` attributes are copied whenever *old_node*
    has them, even if their value is None.
    """
    shared = (
        name for name in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset')
        if name in old_node._attributes and name in new_node._attributes
    )
    for name in shared:
        location = getattr(old_node, name, None)
        keep_none = name.startswith("end_") and hasattr(old_node, name)
        if location is None and not keep_none:
            continue
        setattr(new_node, name, location)
    return new_node
195
196
def fix_missing_locations(node):
    """
    Fill in missing source locations throughout the tree rooted at *node*.

    compile() expects ``lineno`` and ``col_offset`` on every node that
    supports them, which is tedious to provide for generated nodes.  A node
    lacking ``lineno``/``col_offset``, or whose ``end_lineno``/
    ``end_col_offset`` is missing or None, takes the value from its nearest
    ancestor that has one; the root falls back to line 1, column 0.
    Returns *node*.
    """
    def _fill(current, inherited):
        # *inherited* holds the parent's effective value for each attribute.
        here = dict(inherited)
        for attr in ('lineno', 'end_lineno', 'col_offset', 'end_col_offset'):
            if attr not in current._attributes:
                continue
            if attr.startswith('end_'):
                # Optional end positions may exist but be None.
                missing = getattr(current, attr, None) is None
            else:
                missing = not hasattr(current, attr)
            if missing:
                setattr(current, attr, here[attr])
            else:
                here[attr] = getattr(current, attr)
        for child in iter_child_nodes(current):
            _fill(child, here)

    _fill(node, {'lineno': 1, 'end_lineno': 1, 'col_offset': 0, 'end_col_offset': 0})
    return node
230
231
def increment_lineno(node, n=1):
    """
    Shift every node in the tree rooted at *node* down by *n* lines.

    ``lineno`` and ``end_lineno`` are incremented wherever the node class
    declares them; an ``end_lineno`` that is None is left untouched.  This is
    handy for "moving" code to another place in a file.  Returns *node*.
    """
    for descendant in walk(node):
        # TypeIgnore carries lineno as a field, not as one of _attributes.
        if isinstance(descendant, TypeIgnore):
            descendant.lineno = getattr(descendant, 'lineno', 0) + n
            continue
        declared = descendant._attributes
        if 'lineno' in declared:
            descendant.lineno = getattr(descendant, 'lineno', 0) + n
        if "end_lineno" not in declared:
            continue
        old_end = getattr(descendant, "end_lineno", 0)
        if old_end is not None:
            descendant.end_lineno = old_end + n
    return node
253
254
def iter_fields(node):
    """
    Generate ``(fieldname, value)`` pairs for the names in ``node._fields``,
    skipping any field that is not set on *node*.
    """
    for name in node._fields:
        try:
            value = getattr(node, name)
        except AttributeError:
            continue
        yield name, value
265
266
def iter_child_nodes(node):
    """
    Generate the direct children of *node*: every field value that is itself
    an AST node, plus each AST node found inside list-valued fields.
    """
    for _, value in iter_fields(node):
        if isinstance(value, AST):
            yield value
        elif isinstance(value, list):
            yield from (item for item in value if isinstance(item, AST))
279
280
def get_docstring(node, clean=True):
    """
    Return the docstring for the given node or None if no docstring can
    be found. If the node provided does not have docstrings a TypeError
    will be raised.

    If *clean* is `True`, all tabs are expanded to spaces and any whitespace
    that can be uniformly removed from the second line onwards is removed
    (via inspect.cleandoc()).
    """
    if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    if not (node.body and isinstance(node.body[0], Expr)):
        return None
    candidate = node.body[0].value
    # A docstring is a str Constant.  The former ``isinstance(node, Str)``
    # branch was redundant: the deprecated Str alias matches exactly the
    # Constants whose value is a str, which this check already covers.
    if not (isinstance(candidate, Constant) and isinstance(candidate.value, str)):
        return None
    text = candidate.value
    if clean:
        import inspect
        text = inspect.cleandoc(text)
    return text
305
306
def _splitlines_no_ff(source):
    r"""Split a string into lines ignoring form feed and other chars.

    This mimics how the Python parser splits source code: only "\n", "\r"
    and "\r\n" end a line, unlike str.splitlines(), which also breaks on form
    feeds, vertical tabs, "\u2028" and similar characters.  Each line keeps
    its terminator; trailing text without one becomes the final line.
    """
    lines = []
    start = 0
    pos = 0
    end = len(source)
    while pos < end:
        char = source[pos]
        pos += 1
        if char == '\r':
            # Keep "\r\n" together as a single terminator.
            if pos < end and source[pos] == '\n':
                pos += 1
        elif char != '\n':
            continue
        # Slice each finished line out in one step rather than growing it
        # character by character (quadratic without CPython's in-place
        # string concatenation optimization).
        lines.append(source[start:pos])
        start = pos
    if start < end:
        lines.append(source[start:])
    return lines
330
331
def _pad_whitespace(source):
    r"""Replace all chars except '\f\t' in a line with spaces.

    Form feeds and tabs are preserved as-is (presumably so the padded prefix
    lines up the same way the original text did); every other character
    becomes exactly one space.
    """
    # One join instead of repeated string concatenation in a loop.
    return ''.join(c if c in '\f\t' else ' ' for c in source)
341
342
def get_source_segment(source, node, *, padded=False):
    """Return the part of *source* that *node* was parsed from.

    Returns None if any location attribute (`lineno`, `end_lineno`,
    `col_offset`, `end_col_offset`) is missing, or if an end position is
    None.  Column offsets are treated as UTF-8 byte offsets into their line.

    If *padded* is `True`, the first line of a multi-line segment is
    prefixed with whitespace so it keeps its original column.
    """
    try:
        if node.end_lineno is None or node.end_col_offset is None:
            return None
        first_index = node.lineno - 1
        last_index = node.end_lineno - 1
        start_col = node.col_offset
        stop_col = node.end_col_offset
    except AttributeError:
        return None

    lines = _splitlines_no_ff(source)
    if first_index == last_index:
        return lines[first_index].encode()[start_col:stop_col].decode()

    if padded:
        prefix = _pad_whitespace(lines[first_index].encode()[:start_col].decode())
    else:
        prefix = ''
    first_part = prefix + lines[first_index].encode()[start_col:].decode()
    last_part = lines[last_index].encode()[:stop_col].decode()
    return ''.join([first_part, *lines[first_index + 1:last_index], last_part])
378
379
def walk(node):
    """
    Generate *node* and every node beneath it.  The current implementation
    goes breadth-first, but callers should not rely on any particular order.
    Useful when nodes only need to be modified in place and the surrounding
    context does not matter.
    """
    from collections import deque
    pending = deque((node,))
    while pending:
        current = pending.popleft()
        pending.extend(iter_child_nodes(current))
        yield current
392
393
class NodeVisitor:
    """
    Base class for read-only traversals of an abstract syntax tree.

    `visit` dispatches a node to the method named ``'visit_'`` plus the
    node's class name (``visit_Try`` for a `Try` node, for instance) and
    returns whatever that method returns.  If the subclass defines no such
    method, `generic_visit` is used instead, which visits the node's
    children.  Override `visit` to change the dispatch rule.

    Subclass this and add ``visit_*`` methods.  To modify nodes during the
    traversal, use `NodeTransformer` instead.
    """

    def visit(self, node):
        """Dispatch *node* to its ``visit_<ClassName>`` method (or to
        `generic_visit` when there is none) and return the result."""
        handler = getattr(self, 'visit_' + node.__class__.__name__,
                          self.generic_visit)
        return handler(node)

    def generic_visit(self, node):
        """Visit every AST child of *node*; used when no visit_* method matches."""
        for _, value in iter_fields(node):
            children = value if isinstance(value, list) else [value]
            for child in children:
                if isinstance(child, AST):
                    self.visit(child)

    def visit_Constant(self, node):
        # Backward compatibility: route a Constant to a deprecated legacy
        # handler (visit_Num, visit_Str, visit_Bytes, visit_NameConstant or
        # visit_Ellipsis) when the subclass still defines one.
        const = node.value
        legacy = _const_node_type_names.get(type(const))
        if legacy is None:
            # Subclasses of the builtin constant types: first match wins.
            for base, candidate in _const_node_type_names.items():
                if isinstance(const, base):
                    legacy = candidate
                    break
        if legacy is None:
            return self.generic_visit(node)
        method_name = 'visit_' + legacy
        try:
            legacy_visitor = getattr(self, method_name)
        except AttributeError:
            return self.generic_visit(node)
        import warnings
        warnings.warn(f"{method_name} is deprecated; add visit_Constant",
                      DeprecationWarning, 2)
        return legacy_visitor(node)
450
451
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass whose visitor methods may replace or
    remove the nodes they visit.

    Each visitor's return value takes the place of the visited node: return
    the node itself to keep it, another node to replace it, or ``None`` to
    remove it.  For nodes stored in a list field (statement bodies, for
    example) a visitor may also return a list of nodes, which is spliced
    into the list in place of the original node.

    Example -- rewrite every name lookup ``foo`` into ``data['foo']``::

        class RewriteName(NodeTransformer):

            def visit_Name(self, node):
                return Subscript(
                    value=Name(id='data', ctx=Load()),
                    slice=Constant(value=node.id),
                    ctx=node.ctx
                )

    A visitor for a node that has children must either transform those
    children itself or call :meth:`generic_visit` on the node first.

    Typical usage::

        node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        """Visit each child of *node*, applying replacements and removals in
        place, and return *node*."""
        for field, current in iter_fields(node):
            if isinstance(current, AST):
                replacement = self.visit(current)
                if replacement is None:
                    delattr(node, field)
                else:
                    setattr(node, field, replacement)
            elif isinstance(current, list):
                rebuilt = []
                for item in current:
                    if not isinstance(item, AST):
                        rebuilt.append(item)
                        continue
                    result = self.visit(item)
                    if result is None:
                        continue
                    if isinstance(result, AST):
                        rebuilt.append(result)
                    else:
                        rebuilt.extend(result)
                # Mutate the existing list so other references see the change.
                current[:] = rebuilt
        return node
509
510
# If the ast module is loaded more than once, only add deprecated methods once
if not hasattr(Constant, 'n'):
    # The following code is for backward compatibility.
    # It will be removed in future.
    # Expose the legacy ``n`` (Num) and ``s`` (Str/Bytes) attributes as
    # read/write aliases of ``Constant.value``.

    def _getter(self):
        """Deprecated. Use value instead."""
        return self.value

    def _setter(self, value):
        # Deprecated setter: writes through to ``value``.
        self.value = value

    Constant.n = property(_getter, _setter)
    Constant.s = property(_getter, _setter)
525
class _ABC(type):
    # Metaclass of the deprecated constant classes (Num, Str, Bytes,
    # NameConstant, Ellipsis).

    def __init__(cls, *args):
        cls.__doc__ = "Deprecated AST node class. Use ast.Constant instead"

    def __instancecheck__(cls, inst):
        # For the deprecated aliases themselves, membership is decided by the
        # type of the Constant's value (see _const_types / _const_types_not);
        # any other class using this metaclass uses ordinary type checking.
        if not isinstance(inst, Constant):
            return False
        if cls not in _const_types:
            return type.__instancecheck__(cls, inst)
        try:
            value = inst.value
        except AttributeError:
            return False
        excluded = _const_types_not.get(cls, ())
        return isinstance(value, _const_types[cls]) and not isinstance(value, excluded)
545
def _new(cls, *args, **kwargs):
    # Shared __new__ for the deprecated constant classes.  Rejects a field
    # given both positionally and by keyword (unknown keywords are accepted
    # as-is); the deprecated aliases themselves build a plain Constant.
    for key in kwargs:
        if key in cls._fields and cls._fields.index(key) < len(args):
            raise TypeError(f"{cls.__name__} got multiple values for argument {key!r}")
    if cls in _const_types:
        return Constant(*args, **kwargs)
    return Constant.__new__(cls, *args, **kwargs)
557
class Num(Constant, metaclass=_ABC):
    # Deprecated alias matching Constants whose value is an int, float or
    # complex (bool excluded); the value is exposed as the legacy ``n``.
    _fields = ('n',)
    __new__ = _new
561
class Str(Constant, metaclass=_ABC):
    # Deprecated alias matching Constants whose value is a str; the value is
    # exposed as the legacy ``s``.
    _fields = ('s',)
    __new__ = _new
565
class Bytes(Constant, metaclass=_ABC):
    # Deprecated alias matching Constants whose value is bytes; the value is
    # exposed as the legacy ``s``.
    _fields = ('s',)
    __new__ = _new
569
class NameConstant(Constant, metaclass=_ABC):
    # Deprecated alias matching Constants whose value is None or a bool;
    # keeps Constant's own _fields.
    __new__ = _new
572
class Ellipsis(Constant, metaclass=_ABC):
    # Deprecated alias matching Constants whose value is ``...``.  Note that
    # this module-level name shadows the builtin ``Ellipsis`` in this module.
    _fields = ()

    def __new__(cls, *args, **kwargs):
        # Instantiating the alias itself yields Constant(...); subclasses
        # are created normally.
        if cls is not Ellipsis:
            return Constant.__new__(cls, *args, **kwargs)
        return Constant(..., *args, **kwargs)
580
# Value types matched by each deprecated constant class; consulted by
# _ABC.__instancecheck__ and _new.
_const_types = {
    Num: (int, float, complex),
    Str: (str,),
    Bytes: (bytes,),
    NameConstant: (type(None), bool),
    Ellipsis: (type(...),),
}
# Value types rejected even though they match _const_types: bool is an int
# subclass, but a bool Constant is not a Num.
_const_types_not = {
    Num: (bool,),
}

# Maps a constant's value type to the legacy visit_* suffix used by
# NodeVisitor.visit_Constant.  Its isinstance() fallback takes the first
# match in insertion order, so bool must come before int.
_const_node_type_names = {
    bool: 'NameConstant',  # should be before int
    type(None): 'NameConstant',
    int: 'Num',
    float: 'Num',
    complex: 'Num',
    str: 'Str',
    bytes: 'Bytes',
    type(...): 'Ellipsis',
}
602
class slice(AST):
    """Deprecated AST node class."""
    # Base class of the deprecated Index/ExtSlice; this name shadows the
    # builtin ``slice`` within this module.
605
class Index(slice):
    """Deprecated AST node class. Use the index value directly instead."""
    def __new__(cls, value, **kwargs):
        # Index(value) just returns *value*; keyword arguments are ignored.
        return value
610
class ExtSlice(slice):
    """Deprecated AST node class. Use ast.Tuple instead."""
    def __new__(cls, dims=(), **kwargs):
        # Builds a Load-context Tuple whose elts are *dims*; any keyword
        # arguments are forwarded to the Tuple constructor.
        return Tuple(list(dims), Load(), **kwargs)
615
# If the ast module is loaded more than once, only add deprecated methods once
if not hasattr(Tuple, 'dims'):
    # The following code is for backward compatibility.
    # It will be removed in future.
    # Expose the legacy ``dims`` attribute (used by ExtSlice) as a read/write
    # alias of ``Tuple.elts``.

    def _dims_getter(self):
        """Deprecated. Use elts instead."""
        return self.elts

    def _dims_setter(self, value):
        # Deprecated setter: writes through to ``elts``.
        self.elts = value

    Tuple.dims = property(_dims_getter, _dims_setter)
629
class Suite(mod):
    """Deprecated AST node class. Unused in Python 3."""
    # Empty subclass of mod: adds no fields or behavior.
632
class AugLoad(expr_context):
    """Deprecated AST node class. Unused in Python 3."""
    # Empty subclass of expr_context: adds no fields or behavior.
635
class AugStore(expr_context):
    """Deprecated AST node class. Unused in Python 3."""
    # Empty subclass of expr_context: adds no fields or behavior.
638
class Param(expr_context):
    """Deprecated AST node class. Unused in Python 3."""
    # Empty subclass of expr_context: adds no fields or behavior.
641
642
# Float and imaginary literals too large to represent become infinities in
# the AST.  They are unparsed as _INFSTR: the smallest power of ten that
# overflows a float, so it parses back to infinity.
_INFSTR = f"1e{sys.float_info.max_10_exp + 1}"
646
@_simple_enum(IntEnum)
class _Precedence:
    """Precedence table that originated from python grammar."""

    # Members run from loosest- to tightest-binding.  auto() numbers them in
    # declaration order, so integer comparison between members compares
    # binding strength (as _Unparser.require_parens does); do not reorder.
    NAMED_EXPR = auto()      # <target> := <expr1>
    TUPLE = auto()           # <expr1>, <expr2>
    YIELD = auto()           # 'yield', 'yield from'
    TEST = auto()            # 'if'-'else', 'lambda'
    OR = auto()              # 'or'
    AND = auto()             # 'and'
    NOT = auto()             # 'not'
    CMP = auto()             # '<', '>', '==', '>=', '<=', '!=',
                             # 'in', 'not in', 'is', 'is not'
    EXPR = auto()
    BOR = EXPR               # '|'  (bound to EXPR's auto(), so an alias of it)
    BXOR = auto()            # '^'
    BAND = auto()            # '&'
    SHIFT = auto()           # '<<', '>>'
    ARITH = auto()           # '+', '-'
    TERM = auto()            # '*', '@', '/', '%', '//'
    FACTOR = auto()          # unary '+', '-', '~'
    POWER = auto()           # '**'
    AWAIT = auto()           # 'await'
    ATOM = auto()

    def next(self):
        """Return the next tighter-binding level, or self if none exists."""
        try:
            return self.__class__(self + 1)
        except ValueError:
            return self
677
678
# String-literal delimiters: single-character quotes, then triple quotes.
_SINGLE_QUOTES = ("'", '"')
_MULTI_QUOTES = ('"""', "'''")
_ALL_QUOTES = _SINGLE_QUOTES + _MULTI_QUOTES
682
683 class ESC[4;38;5;81m_Unparser(ESC[4;38;5;149mNodeVisitor):
684 """Methods in this class recursively traverse an AST and
685 output source code for the abstract syntax; original formatting
686 is disregarded."""
687
688 def __init__(self, *, _avoid_backslashes=False):
689 self._source = []
690 self._precedences = {}
691 self._type_ignores = {}
692 self._indent = 0
693 self._avoid_backslashes = _avoid_backslashes
694 self._in_try_star = False
695
696 def interleave(self, inter, f, seq):
697 """Call f on each item in seq, calling inter() in between."""
698 seq = iter(seq)
699 try:
700 f(next(seq))
701 except StopIteration:
702 pass
703 else:
704 for x in seq:
705 inter()
706 f(x)
707
708 def items_view(self, traverser, items):
709 """Traverse and separate the given *items* with a comma and append it to
710 the buffer. If *items* is a single item sequence, a trailing comma
711 will be added."""
712 if len(items) == 1:
713 traverser(items[0])
714 self.write(",")
715 else:
716 self.interleave(lambda: self.write(", "), traverser, items)
717
718 def maybe_newline(self):
719 """Adds a newline if it isn't the start of generated source"""
720 if self._source:
721 self.write("\n")
722
723 def fill(self, text=""):
724 """Indent a piece of text and append it, according to the current
725 indentation level"""
726 self.maybe_newline()
727 self.write(" " * self._indent + text)
728
729 def write(self, *text):
730 """Add new source parts"""
731 self._source.extend(text)
732
733 @contextmanager
734 def buffered(self, buffer = None):
735 if buffer is None:
736 buffer = []
737
738 original_source = self._source
739 self._source = buffer
740 yield buffer
741 self._source = original_source
742
743 @contextmanager
744 def block(self, *, extra = None):
745 """A context manager for preparing the source for blocks. It adds
746 the character':', increases the indentation on enter and decreases
747 the indentation on exit. If *extra* is given, it will be directly
748 appended after the colon character.
749 """
750 self.write(":")
751 if extra:
752 self.write(extra)
753 self._indent += 1
754 yield
755 self._indent -= 1
756
757 @contextmanager
758 def delimit(self, start, end):
759 """A context manager for preparing the source for expressions. It adds
760 *start* to the buffer and enters, after exit it adds *end*."""
761
762 self.write(start)
763 yield
764 self.write(end)
765
766 def delimit_if(self, start, end, condition):
767 if condition:
768 return self.delimit(start, end)
769 else:
770 return nullcontext()
771
772 def require_parens(self, precedence, node):
773 """Shortcut to adding precedence related parens"""
774 return self.delimit_if("(", ")", self.get_precedence(node) > precedence)
775
776 def get_precedence(self, node):
777 return self._precedences.get(node, _Precedence.TEST)
778
779 def set_precedence(self, precedence, *nodes):
780 for node in nodes:
781 self._precedences[node] = precedence
782
783 def get_raw_docstring(self, node):
784 """If a docstring node is found in the body of the *node* parameter,
785 return that docstring node, None otherwise.
786
787 Logic mirrored from ``_PyAST_GetDocString``."""
788 if not isinstance(
789 node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)
790 ) or len(node.body) < 1:
791 return None
792 node = node.body[0]
793 if not isinstance(node, Expr):
794 return None
795 node = node.value
796 if isinstance(node, Constant) and isinstance(node.value, str):
797 return node
798
799 def get_type_comment(self, node):
800 comment = self._type_ignores.get(node.lineno) or node.type_comment
801 if comment is not None:
802 return f" # type: {comment}"
803
804 def traverse(self, node):
805 if isinstance(node, list):
806 for item in node:
807 self.traverse(item)
808 else:
809 super().visit(node)
810
811 # Note: as visit() resets the output text, do NOT rely on
812 # NodeVisitor.generic_visit to handle any nodes (as it calls back in to
813 # the subclass visit() method, which resets self._source to an empty list)
814 def visit(self, node):
815 """Outputs a source code string that, if converted back to an ast
816 (using ast.parse) will generate an AST equivalent to *node*"""
817 self._source = []
818 self.traverse(node)
819 return "".join(self._source)
820
821 def _write_docstring_and_traverse_body(self, node):
822 if (docstring := self.get_raw_docstring(node)):
823 self._write_docstring(docstring)
824 self.traverse(node.body[1:])
825 else:
826 self.traverse(node.body)
827
828 def visit_Module(self, node):
829 self._type_ignores = {
830 ignore.lineno: f"ignore{ignore.tag}"
831 for ignore in node.type_ignores
832 }
833 self._write_docstring_and_traverse_body(node)
834 self._type_ignores.clear()
835
836 def visit_FunctionType(self, node):
837 with self.delimit("(", ")"):
838 self.interleave(
839 lambda: self.write(", "), self.traverse, node.argtypes
840 )
841
842 self.write(" -> ")
843 self.traverse(node.returns)
844
845 def visit_Expr(self, node):
846 self.fill()
847 self.set_precedence(_Precedence.YIELD, node.value)
848 self.traverse(node.value)
849
850 def visit_NamedExpr(self, node):
851 with self.require_parens(_Precedence.NAMED_EXPR, node):
852 self.set_precedence(_Precedence.ATOM, node.target, node.value)
853 self.traverse(node.target)
854 self.write(" := ")
855 self.traverse(node.value)
856
857 def visit_Import(self, node):
858 self.fill("import ")
859 self.interleave(lambda: self.write(", "), self.traverse, node.names)
860
861 def visit_ImportFrom(self, node):
862 self.fill("from ")
863 self.write("." * (node.level or 0))
864 if node.module:
865 self.write(node.module)
866 self.write(" import ")
867 self.interleave(lambda: self.write(", "), self.traverse, node.names)
868
869 def visit_Assign(self, node):
870 self.fill()
871 for target in node.targets:
872 self.set_precedence(_Precedence.TUPLE, target)
873 self.traverse(target)
874 self.write(" = ")
875 self.traverse(node.value)
876 if type_comment := self.get_type_comment(node):
877 self.write(type_comment)
878
879 def visit_AugAssign(self, node):
880 self.fill()
881 self.traverse(node.target)
882 self.write(" " + self.binop[node.op.__class__.__name__] + "= ")
883 self.traverse(node.value)
884
885 def visit_AnnAssign(self, node):
886 self.fill()
887 with self.delimit_if("(", ")", not node.simple and isinstance(node.target, Name)):
888 self.traverse(node.target)
889 self.write(": ")
890 self.traverse(node.annotation)
891 if node.value:
892 self.write(" = ")
893 self.traverse(node.value)
894
895 def visit_Return(self, node):
896 self.fill("return")
897 if node.value:
898 self.write(" ")
899 self.traverse(node.value)
900
901 def visit_Pass(self, node):
902 self.fill("pass")
903
904 def visit_Break(self, node):
905 self.fill("break")
906
907 def visit_Continue(self, node):
908 self.fill("continue")
909
910 def visit_Delete(self, node):
911 self.fill("del ")
912 self.interleave(lambda: self.write(", "), self.traverse, node.targets)
913
914 def visit_Assert(self, node):
915 self.fill("assert ")
916 self.traverse(node.test)
917 if node.msg:
918 self.write(", ")
919 self.traverse(node.msg)
920
921 def visit_Global(self, node):
922 self.fill("global ")
923 self.interleave(lambda: self.write(", "), self.write, node.names)
924
925 def visit_Nonlocal(self, node):
926 self.fill("nonlocal ")
927 self.interleave(lambda: self.write(", "), self.write, node.names)
928
929 def visit_Await(self, node):
930 with self.require_parens(_Precedence.AWAIT, node):
931 self.write("await")
932 if node.value:
933 self.write(" ")
934 self.set_precedence(_Precedence.ATOM, node.value)
935 self.traverse(node.value)
936
937 def visit_Yield(self, node):
938 with self.require_parens(_Precedence.YIELD, node):
939 self.write("yield")
940 if node.value:
941 self.write(" ")
942 self.set_precedence(_Precedence.ATOM, node.value)
943 self.traverse(node.value)
944
945 def visit_YieldFrom(self, node):
946 with self.require_parens(_Precedence.YIELD, node):
947 self.write("yield from ")
948 if not node.value:
949 raise ValueError("Node can't be used without a value attribute.")
950 self.set_precedence(_Precedence.ATOM, node.value)
951 self.traverse(node.value)
952
953 def visit_Raise(self, node):
954 self.fill("raise")
955 if not node.exc:
956 if node.cause:
957 raise ValueError(f"Node can't use cause without an exception.")
958 return
959 self.write(" ")
960 self.traverse(node.exc)
961 if node.cause:
962 self.write(" from ")
963 self.traverse(node.cause)
964
965 def do_visit_try(self, node):
966 self.fill("try")
967 with self.block():
968 self.traverse(node.body)
969 for ex in node.handlers:
970 self.traverse(ex)
971 if node.orelse:
972 self.fill("else")
973 with self.block():
974 self.traverse(node.orelse)
975 if node.finalbody:
976 self.fill("finally")
977 with self.block():
978 self.traverse(node.finalbody)
979
980 def visit_Try(self, node):
981 prev_in_try_star = self._in_try_star
982 try:
983 self._in_try_star = False
984 self.do_visit_try(node)
985 finally:
986 self._in_try_star = prev_in_try_star
987
988 def visit_TryStar(self, node):
989 prev_in_try_star = self._in_try_star
990 try:
991 self._in_try_star = True
992 self.do_visit_try(node)
993 finally:
994 self._in_try_star = prev_in_try_star
995
996 def visit_ExceptHandler(self, node):
997 self.fill("except*" if self._in_try_star else "except")
998 if node.type:
999 self.write(" ")
1000 self.traverse(node.type)
1001 if node.name:
1002 self.write(" as ")
1003 self.write(node.name)
1004 with self.block():
1005 self.traverse(node.body)
1006
1007 def visit_ClassDef(self, node):
1008 self.maybe_newline()
1009 for deco in node.decorator_list:
1010 self.fill("@")
1011 self.traverse(deco)
1012 self.fill("class " + node.name)
1013 with self.delimit_if("(", ")", condition = node.bases or node.keywords):
1014 comma = False
1015 for e in node.bases:
1016 if comma:
1017 self.write(", ")
1018 else:
1019 comma = True
1020 self.traverse(e)
1021 for e in node.keywords:
1022 if comma:
1023 self.write(", ")
1024 else:
1025 comma = True
1026 self.traverse(e)
1027
1028 with self.block():
1029 self._write_docstring_and_traverse_body(node)
1030
1031 def visit_FunctionDef(self, node):
1032 self._function_helper(node, "def")
1033
1034 def visit_AsyncFunctionDef(self, node):
1035 self._function_helper(node, "async def")
1036
1037 def _function_helper(self, node, fill_suffix):
1038 self.maybe_newline()
1039 for deco in node.decorator_list:
1040 self.fill("@")
1041 self.traverse(deco)
1042 def_str = fill_suffix + " " + node.name
1043 self.fill(def_str)
1044 with self.delimit("(", ")"):
1045 self.traverse(node.args)
1046 if node.returns:
1047 self.write(" -> ")
1048 self.traverse(node.returns)
1049 with self.block(extra=self.get_type_comment(node)):
1050 self._write_docstring_and_traverse_body(node)
1051
1052 def visit_For(self, node):
1053 self._for_helper("for ", node)
1054
1055 def visit_AsyncFor(self, node):
1056 self._for_helper("async for ", node)
1057
1058 def _for_helper(self, fill, node):
1059 self.fill(fill)
1060 self.set_precedence(_Precedence.TUPLE, node.target)
1061 self.traverse(node.target)
1062 self.write(" in ")
1063 self.traverse(node.iter)
1064 with self.block(extra=self.get_type_comment(node)):
1065 self.traverse(node.body)
1066 if node.orelse:
1067 self.fill("else")
1068 with self.block():
1069 self.traverse(node.orelse)
1070
1071 def visit_If(self, node):
1072 self.fill("if ")
1073 self.traverse(node.test)
1074 with self.block():
1075 self.traverse(node.body)
1076 # collapse nested ifs into equivalent elifs.
1077 while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If):
1078 node = node.orelse[0]
1079 self.fill("elif ")
1080 self.traverse(node.test)
1081 with self.block():
1082 self.traverse(node.body)
1083 # final else
1084 if node.orelse:
1085 self.fill("else")
1086 with self.block():
1087 self.traverse(node.orelse)
1088
1089 def visit_While(self, node):
1090 self.fill("while ")
1091 self.traverse(node.test)
1092 with self.block():
1093 self.traverse(node.body)
1094 if node.orelse:
1095 self.fill("else")
1096 with self.block():
1097 self.traverse(node.orelse)
1098
1099 def visit_With(self, node):
1100 self.fill("with ")
1101 self.interleave(lambda: self.write(", "), self.traverse, node.items)
1102 with self.block(extra=self.get_type_comment(node)):
1103 self.traverse(node.body)
1104
1105 def visit_AsyncWith(self, node):
1106 self.fill("async with ")
1107 self.interleave(lambda: self.write(", "), self.traverse, node.items)
1108 with self.block(extra=self.get_type_comment(node)):
1109 self.traverse(node.body)
1110
    def _str_literal_helper(
        self, string, *, quote_types=_ALL_QUOTES, escape_special_whitespace=False
    ):
        """Helper for writing string literals, minimizing escapes.
        Returns the tuple (string literal to write, possible quote types).

        ``string`` is the text to render. ``quote_types`` lists candidate
        delimiters in order of preference. When ``escape_special_whitespace``
        is true, newline and tab characters are escaped instead of being
        kept literally.

        The returned literal body carries no delimiters. The caller wraps it
        in one of the returned quote types; index 0 is the preferred one.
        """
        def escape_char(c):
            # \n and \t are non-printable, but we only escape them if
            # escape_special_whitespace is True
            if not escape_special_whitespace and c in "\n\t":
                return c
            # Always escape backslashes and other non-printable characters
            if c == "\\" or not c.isprintable():
                return c.encode("unicode_escape").decode("ascii")
            return c

        escaped_string = "".join(map(escape_char, string))
        possible_quotes = quote_types
        # A literal (unescaped) newline can only sit inside a multi-line
        # (triple-quoted) delimiter.
        if "\n" in escaped_string:
            possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES]
        # Drop delimiters that occur in the body: they would end the literal early.
        possible_quotes = [q for q in possible_quotes if q not in escaped_string]
        if not possible_quotes:
            # If there aren't any possible_quotes, fallback to using repr
            # on the original string. Try to use a quote from quote_types,
            # e.g., so that we use triple quotes for docstrings.
            string = repr(string)
            # string[0] is the quote character repr() chose. Use the first
            # candidate containing that character, otherwise the character itself.
            quote = next((q for q in quote_types if string[0] in q), string[0])
            return string[1:-1], [quote]
        if escaped_string:
            # Sort so that we prefer '''"''' over """\""""
            # False sorts before True and the sort is stable, so candidates
            # whose quote character differs from the body's last character
            # move to the front. possible_quotes is a fresh list built above,
            # so this does not mutate the caller's quote_types.
            possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1])
            # If we're using triple quotes and we'd need to escape a final
            # quote, escape it
            # (A single-character quote equal to the last character was
            # already filtered out above, so only a triple quote can match.)
            if possible_quotes[0][0] == escaped_string[-1]:
                assert len(possible_quotes[0]) == 3
                escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1]
        return escaped_string, possible_quotes
1148
1149 def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES):
1150 """Write string literal value with a best effort attempt to avoid backslashes."""
1151 string, quote_types = self._str_literal_helper(string, quote_types=quote_types)
1152 quote_type = quote_types[0]
1153 self.write(f"{quote_type}{string}{quote_type}")
1154
    def visit_JoinedStr(self, node):
        """Write an f-string literal (``f"..."``).

        When ``self._avoid_backslashes`` is set (visit_FormattedValue creates
        such an unparser for replacement-field expressions), the whole body
        is rendered first and then quoted in one go. Otherwise each part is
        quoted separately, so literal parts may use whitespace escapes, and a
        single quote type valid for every part is chosen.
        """
        self.write("f")
        if self._avoid_backslashes:
            # Render the complete body into a side buffer, then quote it.
            with self.buffered() as buffer:
                self._write_fstring_inner(node)
            return self._write_str_avoiding_backslashes("".join(buffer))

        # If we don't need to avoid backslashes globally (i.e., we only need
        # to avoid them inside FormattedValues), it's cosmetically preferred
        # to use escaped whitespace. That is, it's preferred to use backslashes
        # for cases like: f"{x}\n". To accomplish this, we keep track of what
        # in our buffer corresponds to FormattedValues and what corresponds to
        # Constant parts of the f-string, and allow escapes accordingly.
        fstring_parts = []
        for value in node.values:
            with self.buffered() as buffer:
                self._write_fstring_inner(value)
            fstring_parts.append(
                ("".join(buffer), isinstance(value, Constant))
            )

        new_fstring_parts = []
        # Delimiters still usable for every part processed so far.
        quote_types = list(_ALL_QUOTES)
        fallback_to_repr = False
        for value, is_constant in fstring_parts:
            # Whitespace is escaped only in literal (Constant) parts. In
            # replacement-field parts it is kept as-is.
            value, new_quote_types = self._str_literal_helper(
                value,
                quote_types=quote_types,
                escape_special_whitespace=is_constant,
            )
            new_fstring_parts.append(value)
            # The helper normally returns a non-empty subset of quote_types.
            # A disjoint result means it fell back to repr() with a delimiter
            # that the earlier parts were not checked against.
            if set(new_quote_types).isdisjoint(quote_types):
                fallback_to_repr = True
                break
            quote_types = new_quote_types

        if fallback_to_repr:
            # If we weren't able to find a quote type that works for all parts
            # of the JoinedStr, fallback to using repr and triple single quotes.
            quote_types = ["'''"]
            new_fstring_parts.clear()
            # is_constant is not consulted in this pass; every part is repr-escaped.
            for value, is_constant in fstring_parts:
                value = repr('"' + value) # force repr to use single quotes
                expected_prefix = "'\""
                assert value.startswith(expected_prefix), repr(value)
                # Strip the leading quote, the injected '"', and the trailing quote.
                new_fstring_parts.append(value[len(expected_prefix):-1])

        value = "".join(new_fstring_parts)
        quote_type = quote_types[0]
        self.write(f"{quote_type}{value}{quote_type}")
1205
1206 def _write_fstring_inner(self, node):
1207 if isinstance(node, JoinedStr):
1208 # for both the f-string itself, and format_spec
1209 for value in node.values:
1210 self._write_fstring_inner(value)
1211 elif isinstance(node, Constant) and isinstance(node.value, str):
1212 value = node.value.replace("{", "{{").replace("}", "}}")
1213 self.write(value)
1214 elif isinstance(node, FormattedValue):
1215 self.visit_FormattedValue(node)
1216 else:
1217 raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
1218
1219 def visit_FormattedValue(self, node):
1220 def unparse_inner(inner):
1221 unparser = type(self)(_avoid_backslashes=True)
1222 unparser.set_precedence(_Precedence.TEST.next(), inner)
1223 return unparser.visit(inner)
1224
1225 with self.delimit("{", "}"):
1226 expr = unparse_inner(node.value)
1227 if "\\" in expr:
1228 raise ValueError(
1229 "Unable to avoid backslash in f-string expression part"
1230 )
1231 if expr.startswith("{"):
1232 # Separate pair of opening brackets as "{ {"
1233 self.write(" ")
1234 self.write(expr)
1235 if node.conversion != -1:
1236 self.write(f"!{chr(node.conversion)}")
1237 if node.format_spec:
1238 self.write(":")
1239 self._write_fstring_inner(node.format_spec)
1240
1241 def visit_Name(self, node):
1242 self.write(node.id)
1243
1244 def _write_docstring(self, node):
1245 self.fill()
1246 if node.kind == "u":
1247 self.write("u")
1248 self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES)
1249
1250 def _write_constant(self, value):
1251 if isinstance(value, (float, complex)):
1252 # Substitute overflowing decimal literal for AST infinities,
1253 # and inf - inf for NaNs.
1254 self.write(
1255 repr(value)
1256 .replace("inf", _INFSTR)
1257 .replace("nan", f"({_INFSTR}-{_INFSTR})")
1258 )
1259 elif self._avoid_backslashes and isinstance(value, str):
1260 self._write_str_avoiding_backslashes(value)
1261 else:
1262 self.write(repr(value))
1263
1264 def visit_Constant(self, node):
1265 value = node.value
1266 if isinstance(value, tuple):
1267 with self.delimit("(", ")"):
1268 self.items_view(self._write_constant, value)
1269 elif value is ...:
1270 self.write("...")
1271 else:
1272 if node.kind == "u":
1273 self.write("u")
1274 self._write_constant(node.value)
1275
1276 def visit_List(self, node):
1277 with self.delimit("[", "]"):
1278 self.interleave(lambda: self.write(", "), self.traverse, node.elts)
1279
1280 def visit_ListComp(self, node):
1281 with self.delimit("[", "]"):
1282 self.traverse(node.elt)
1283 for gen in node.generators:
1284 self.traverse(gen)
1285
1286 def visit_GeneratorExp(self, node):
1287 with self.delimit("(", ")"):
1288 self.traverse(node.elt)
1289 for gen in node.generators:
1290 self.traverse(gen)
1291
1292 def visit_SetComp(self, node):
1293 with self.delimit("{", "}"):
1294 self.traverse(node.elt)
1295 for gen in node.generators:
1296 self.traverse(gen)
1297
1298 def visit_DictComp(self, node):
1299 with self.delimit("{", "}"):
1300 self.traverse(node.key)
1301 self.write(": ")
1302 self.traverse(node.value)
1303 for gen in node.generators:
1304 self.traverse(gen)
1305
1306 def visit_comprehension(self, node):
1307 if node.is_async:
1308 self.write(" async for ")
1309 else:
1310 self.write(" for ")
1311 self.set_precedence(_Precedence.TUPLE, node.target)
1312 self.traverse(node.target)
1313 self.write(" in ")
1314 self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs)
1315 self.traverse(node.iter)
1316 for if_clause in node.ifs:
1317 self.write(" if ")
1318 self.traverse(if_clause)
1319
1320 def visit_IfExp(self, node):
1321 with self.require_parens(_Precedence.TEST, node):
1322 self.set_precedence(_Precedence.TEST.next(), node.body, node.test)
1323 self.traverse(node.body)
1324 self.write(" if ")
1325 self.traverse(node.test)
1326 self.write(" else ")
1327 self.set_precedence(_Precedence.TEST, node.orelse)
1328 self.traverse(node.orelse)
1329
1330 def visit_Set(self, node):
1331 if node.elts:
1332 with self.delimit("{", "}"):
1333 self.interleave(lambda: self.write(", "), self.traverse, node.elts)
1334 else:
1335 # `{}` would be interpreted as a dictionary literal, and
1336 # `set` might be shadowed. Thus:
1337 self.write('{*()}')
1338
1339 def visit_Dict(self, node):
1340 def write_key_value_pair(k, v):
1341 self.traverse(k)
1342 self.write(": ")
1343 self.traverse(v)
1344
1345 def write_item(item):
1346 k, v = item
1347 if k is None:
1348 # for dictionary unpacking operator in dicts {**{'y': 2}}
1349 # see PEP 448 for details
1350 self.write("**")
1351 self.set_precedence(_Precedence.EXPR, v)
1352 self.traverse(v)
1353 else:
1354 write_key_value_pair(k, v)
1355
1356 with self.delimit("{", "}"):
1357 self.interleave(
1358 lambda: self.write(", "), write_item, zip(node.keys, node.values)
1359 )
1360
1361 def visit_Tuple(self, node):
1362 with self.delimit_if(
1363 "(",
1364 ")",
1365 len(node.elts) == 0 or self.get_precedence(node) > _Precedence.TUPLE
1366 ):
1367 self.items_view(self.traverse, node.elts)
1368
1369 unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
1370 unop_precedence = {
1371 "not": _Precedence.NOT,
1372 "~": _Precedence.FACTOR,
1373 "+": _Precedence.FACTOR,
1374 "-": _Precedence.FACTOR,
1375 }
1376
1377 def visit_UnaryOp(self, node):
1378 operator = self.unop[node.op.__class__.__name__]
1379 operator_precedence = self.unop_precedence[operator]
1380 with self.require_parens(operator_precedence, node):
1381 self.write(operator)
1382 # factor prefixes (+, -, ~) shouldn't be separated
1383 # from the value they belong, (e.g: +1 instead of + 1)
1384 if operator_precedence is not _Precedence.FACTOR:
1385 self.write(" ")
1386 self.set_precedence(operator_precedence, node.operand)
1387 self.traverse(node.operand)
1388
1389 binop = {
1390 "Add": "+",
1391 "Sub": "-",
1392 "Mult": "*",
1393 "MatMult": "@",
1394 "Div": "/",
1395 "Mod": "%",
1396 "LShift": "<<",
1397 "RShift": ">>",
1398 "BitOr": "|",
1399 "BitXor": "^",
1400 "BitAnd": "&",
1401 "FloorDiv": "//",
1402 "Pow": "**",
1403 }
1404
1405 binop_precedence = {
1406 "+": _Precedence.ARITH,
1407 "-": _Precedence.ARITH,
1408 "*": _Precedence.TERM,
1409 "@": _Precedence.TERM,
1410 "/": _Precedence.TERM,
1411 "%": _Precedence.TERM,
1412 "<<": _Precedence.SHIFT,
1413 ">>": _Precedence.SHIFT,
1414 "|": _Precedence.BOR,
1415 "^": _Precedence.BXOR,
1416 "&": _Precedence.BAND,
1417 "//": _Precedence.TERM,
1418 "**": _Precedence.POWER,
1419 }
1420
1421 binop_rassoc = frozenset(("**",))
1422 def visit_BinOp(self, node):
1423 operator = self.binop[node.op.__class__.__name__]
1424 operator_precedence = self.binop_precedence[operator]
1425 with self.require_parens(operator_precedence, node):
1426 if operator in self.binop_rassoc:
1427 left_precedence = operator_precedence.next()
1428 right_precedence = operator_precedence
1429 else:
1430 left_precedence = operator_precedence
1431 right_precedence = operator_precedence.next()
1432
1433 self.set_precedence(left_precedence, node.left)
1434 self.traverse(node.left)
1435 self.write(f" {operator} ")
1436 self.set_precedence(right_precedence, node.right)
1437 self.traverse(node.right)
1438
1439 cmpops = {
1440 "Eq": "==",
1441 "NotEq": "!=",
1442 "Lt": "<",
1443 "LtE": "<=",
1444 "Gt": ">",
1445 "GtE": ">=",
1446 "Is": "is",
1447 "IsNot": "is not",
1448 "In": "in",
1449 "NotIn": "not in",
1450 }
1451
1452 def visit_Compare(self, node):
1453 with self.require_parens(_Precedence.CMP, node):
1454 self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators)
1455 self.traverse(node.left)
1456 for o, e in zip(node.ops, node.comparators):
1457 self.write(" " + self.cmpops[o.__class__.__name__] + " ")
1458 self.traverse(e)
1459
1460 boolops = {"And": "and", "Or": "or"}
1461 boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR}
1462
    def visit_BoolOp(self, node):
        """Write an ``and``/``or`` chain over ``node.values``.

        Each successive operand is assigned a strictly higher precedence than
        the one before it, starting one level above the operator's own.
        """
        operator = self.boolops[node.op.__class__.__name__]
        operator_precedence = self.boolop_precedence[operator]

        def increasing_level_traverse(node):
            # Rebinds the enclosing operator_precedence, so every call raises
            # the level one step above the previous operand's.
            # NOTE(review): because the level keeps rising, later operands may
            # be parenthesized even when not required (e.g. the `and` in
            # `a or b and c`, if require_parens uses a strict `>` test) --
            # output stays valid but not minimal; confirm before changing.
            nonlocal operator_precedence
            operator_precedence = operator_precedence.next()
            self.set_precedence(operator_precedence, node)
            self.traverse(node)

        # The with-expression is evaluated before interleave() runs, so this
        # node's own parenthesization uses the operator's base precedence.
        with self.require_parens(operator_precedence, node):
            s = f" {operator} "
            self.interleave(lambda: self.write(s), increasing_level_traverse, node.values)
1476
1477 def visit_Attribute(self, node):
1478 self.set_precedence(_Precedence.ATOM, node.value)
1479 self.traverse(node.value)
1480 # Special case: 3.__abs__() is a syntax error, so if node.value
1481 # is an integer literal then we need to either parenthesize
1482 # it or add an extra space to get 3 .__abs__().
1483 if isinstance(node.value, Constant) and isinstance(node.value.value, int):
1484 self.write(" ")
1485 self.write(".")
1486 self.write(node.attr)
1487
1488 def visit_Call(self, node):
1489 self.set_precedence(_Precedence.ATOM, node.func)
1490 self.traverse(node.func)
1491 with self.delimit("(", ")"):
1492 comma = False
1493 for e in node.args:
1494 if comma:
1495 self.write(", ")
1496 else:
1497 comma = True
1498 self.traverse(e)
1499 for e in node.keywords:
1500 if comma:
1501 self.write(", ")
1502 else:
1503 comma = True
1504 self.traverse(e)
1505
1506 def visit_Subscript(self, node):
1507 def is_non_empty_tuple(slice_value):
1508 return (
1509 isinstance(slice_value, Tuple)
1510 and slice_value.elts
1511 )
1512
1513 self.set_precedence(_Precedence.ATOM, node.value)
1514 self.traverse(node.value)
1515 with self.delimit("[", "]"):
1516 if is_non_empty_tuple(node.slice):
1517 # parentheses can be omitted if the tuple isn't empty
1518 self.items_view(self.traverse, node.slice.elts)
1519 else:
1520 self.traverse(node.slice)
1521
1522 def visit_Starred(self, node):
1523 self.write("*")
1524 self.set_precedence(_Precedence.EXPR, node.value)
1525 self.traverse(node.value)
1526
1527 def visit_Ellipsis(self, node):
1528 self.write("...")
1529
1530 def visit_Slice(self, node):
1531 if node.lower:
1532 self.traverse(node.lower)
1533 self.write(":")
1534 if node.upper:
1535 self.traverse(node.upper)
1536 if node.step:
1537 self.write(":")
1538 self.traverse(node.step)
1539
1540 def visit_Match(self, node):
1541 self.fill("match ")
1542 self.traverse(node.subject)
1543 with self.block():
1544 for case in node.cases:
1545 self.traverse(case)
1546
1547 def visit_arg(self, node):
1548 self.write(node.arg)
1549 if node.annotation:
1550 self.write(": ")
1551 self.traverse(node.annotation)
1552
    def visit_arguments(self, node):
        """Write a parameter list (without the surrounding parentheses).

        Order: positional-only and regular positional parameters (with their
        defaults, and ``/`` after the last positional-only one), then
        ``*vararg`` or a bare ``*``, keyword-only parameters, and ``**kwarg``.
        """
        first = True
        # normal arguments
        all_args = node.posonlyargs + node.args
        # Defaults belong to the *last* len(node.defaults) positional
        # parameters, so left-pad with None to pair them one-to-one.
        defaults = [None] * (len(all_args) - len(node.defaults)) + node.defaults
        for index, elements in enumerate(zip(all_args, defaults), 1):
            a, d = elements
            if first:
                first = False
            else:
                self.write(", ")
            self.traverse(a)
            if d:
                self.write("=")
                self.traverse(d)
            # index is 1-based: "/" goes right after the last positional-only
            # parameter (never matches when there are none).
            if index == len(node.posonlyargs):
                self.write(", /")

        # varargs, or bare '*' if no varargs but keyword-only arguments present
        if node.vararg or node.kwonlyargs:
            if first:
                first = False
            else:
                self.write(", ")
            self.write("*")
            if node.vararg:
                self.write(node.vararg.arg)
                if node.vararg.annotation:
                    self.write(": ")
                    self.traverse(node.vararg.annotation)

        # keyword-only arguments
        if node.kwonlyargs:
            # A "*" was always written above, so each keyword-only parameter
            # gets a leading separator. kw_defaults is walked in parallel with
            # kwonlyargs; a falsy entry (presumably None) means no default.
            for a, d in zip(node.kwonlyargs, node.kw_defaults):
                self.write(", ")
                self.traverse(a)
                if d:
                    self.write("=")
                    self.traverse(d)

        # kwargs
        if node.kwarg:
            if first:
                first = False
            else:
                self.write(", ")
            self.write("**" + node.kwarg.arg)
            if node.kwarg.annotation:
                self.write(": ")
                self.traverse(node.kwarg.annotation)
1603
1604 def visit_keyword(self, node):
1605 if node.arg is None:
1606 self.write("**")
1607 else:
1608 self.write(node.arg)
1609 self.write("=")
1610 self.traverse(node.value)
1611
1612 def visit_Lambda(self, node):
1613 with self.require_parens(_Precedence.TEST, node):
1614 self.write("lambda")
1615 with self.buffered() as buffer:
1616 self.traverse(node.args)
1617 if buffer:
1618 self.write(" ", *buffer)
1619 self.write(": ")
1620 self.set_precedence(_Precedence.TEST, node.body)
1621 self.traverse(node.body)
1622
1623 def visit_alias(self, node):
1624 self.write(node.name)
1625 if node.asname:
1626 self.write(" as " + node.asname)
1627
1628 def visit_withitem(self, node):
1629 self.traverse(node.context_expr)
1630 if node.optional_vars:
1631 self.write(" as ")
1632 self.traverse(node.optional_vars)
1633
1634 def visit_match_case(self, node):
1635 self.fill("case ")
1636 self.traverse(node.pattern)
1637 if node.guard:
1638 self.write(" if ")
1639 self.traverse(node.guard)
1640 with self.block():
1641 self.traverse(node.body)
1642
1643 def visit_MatchValue(self, node):
1644 self.traverse(node.value)
1645
1646 def visit_MatchSingleton(self, node):
1647 self._write_constant(node.value)
1648
1649 def visit_MatchSequence(self, node):
1650 with self.delimit("[", "]"):
1651 self.interleave(
1652 lambda: self.write(", "), self.traverse, node.patterns
1653 )
1654
1655 def visit_MatchStar(self, node):
1656 name = node.name
1657 if name is None:
1658 name = "_"
1659 self.write(f"*{name}")
1660
1661 def visit_MatchMapping(self, node):
1662 def write_key_pattern_pair(pair):
1663 k, p = pair
1664 self.traverse(k)
1665 self.write(": ")
1666 self.traverse(p)
1667
1668 with self.delimit("{", "}"):
1669 keys = node.keys
1670 self.interleave(
1671 lambda: self.write(", "),
1672 write_key_pattern_pair,
1673 zip(keys, node.patterns, strict=True),
1674 )
1675 rest = node.rest
1676 if rest is not None:
1677 if keys:
1678 self.write(", ")
1679 self.write(f"**{rest}")
1680
1681 def visit_MatchClass(self, node):
1682 self.set_precedence(_Precedence.ATOM, node.cls)
1683 self.traverse(node.cls)
1684 with self.delimit("(", ")"):
1685 patterns = node.patterns
1686 self.interleave(
1687 lambda: self.write(", "), self.traverse, patterns
1688 )
1689 attrs = node.kwd_attrs
1690 if attrs:
1691 def write_attr_pattern(pair):
1692 attr, pattern = pair
1693 self.write(f"{attr}=")
1694 self.traverse(pattern)
1695
1696 if patterns:
1697 self.write(", ")
1698 self.interleave(
1699 lambda: self.write(", "),
1700 write_attr_pattern,
1701 zip(attrs, node.kwd_patterns, strict=True),
1702 )
1703
1704 def visit_MatchAs(self, node):
1705 name = node.name
1706 pattern = node.pattern
1707 if name is None:
1708 self.write("_")
1709 elif pattern is None:
1710 self.write(node.name)
1711 else:
1712 with self.require_parens(_Precedence.TEST, node):
1713 self.set_precedence(_Precedence.BOR, node.pattern)
1714 self.traverse(node.pattern)
1715 self.write(f" as {node.name}")
1716
1717 def visit_MatchOr(self, node):
1718 with self.require_parens(_Precedence.BOR, node):
1719 self.set_precedence(_Precedence.BOR.next(), *node.patterns)
1720 self.interleave(lambda: self.write(" | "), self.traverse, node.patterns)
1721
def unparse(ast_obj):
    """Return Python source text for ``ast_obj``, produced by a fresh _Unparser."""
    return _Unparser().visit(ast_obj)
1725
1726
def main():
    """Command-line entry point: parse a file (or stdin) and print its AST dump.

    Options select the parse mode, whether type comments are kept, whether
    node attributes (line numbers, column offsets) are included, and the
    dump indentation.
    """
    import argparse

    parser = argparse.ArgumentParser(prog='python -m ast')
    parser.add_argument('infile', type=argparse.FileType(mode='rb'), nargs='?',
                        default='-',
                        help='the file to parse; defaults to stdin')
    parser.add_argument('-m', '--mode', default='exec',
                        choices=('exec', 'single', 'eval', 'func_type'),
                        help='specify what kind of code must be parsed')
    # Store under a positive name: args.type_comments is True unless
    # --no-type-comments is passed, so the value reads the way it is used.
    parser.add_argument('--no-type-comments', dest='type_comments',
                        default=True, action='store_false',
                        help="don't add information about type comments")
    parser.add_argument('-a', '--include-attributes', action='store_true',
                        help='include attributes such as line numbers and '
                             'column offsets')
    parser.add_argument('-i', '--indent', type=int, default=3,
                        help='indentation of nodes (number of spaces)')
    args = parser.parse_args()

    with args.infile as infile:
        source = infile.read()
    tree = parse(source, args.infile.name, args.mode,
                 type_comments=args.type_comments)
    print(dump(tree, include_attributes=args.include_attributes,
               indent=args.indent))
1750
if __name__ == "__main__":
    # Executed as a script (e.g. ``python -m ast``): run the command-line interface.
    main()