1 /* AST Optimizer */
2 #include "Python.h"
3 #include "pycore_ast.h" // _PyAST_GetDocString()
4 #include "pycore_compile.h" // _PyASTOptimizeState
5 #include "pycore_long.h" // _PyLong
6 #include "pycore_pystate.h" // _PyThreadState_GET()
7 #include "pycore_format.h" // F_LJUST
8
9
10 static int
11 make_const(expr_ty node, PyObject *val, PyArena *arena)
12 {
13 // Even if no new value was calculated, make_const may still
14 // need to clear an error (e.g. for division by zero)
15 if (val == NULL) {
16 if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
17 return 0;
18 }
19 PyErr_Clear();
20 return 1;
21 }
22 if (_PyArena_AddPyObject(arena, val) < 0) {
23 Py_DECREF(val);
24 return 0;
25 }
26 node->kind = Constant_kind;
27 node->v.Constant.kind = NULL;
28 node->v.Constant.value = val;
29 return 1;
30 }
31
/* Overwrite the expression node TO with a byte-wise copy of the expression
   node FROM (sizeof(struct _expr) bytes). */
#define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(struct _expr)))
33
34 static int
35 has_starred(asdl_expr_seq *elts)
36 {
37 Py_ssize_t n = asdl_seq_LEN(elts);
38 for (Py_ssize_t i = 0; i < n; i++) {
39 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
40 if (e->kind == Starred_kind) {
41 return 1;
42 }
43 }
44 return 0;
45 }
46
47
48 static PyObject*
49 unary_not(PyObject *v)
50 {
51 int r = PyObject_IsTrue(v);
52 if (r < 0)
53 return NULL;
54 return PyBool_FromLong(!r);
55 }
56
57 static int
58 fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
59 {
60 expr_ty arg = node->v.UnaryOp.operand;
61
62 if (arg->kind != Constant_kind) {
63 /* Fold not into comparison */
64 if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
65 asdl_seq_LEN(arg->v.Compare.ops) == 1) {
66 /* Eq and NotEq are often implemented in terms of one another, so
67 folding not (self == other) into self != other breaks implementation
68 of !=. Detecting such cases doesn't seem worthwhile.
69 Python uses </> for 'is subset'/'is superset' operations on sets.
70 They don't satisfy not folding laws. */
71 cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
72 switch (op) {
73 case Is:
74 op = IsNot;
75 break;
76 case IsNot:
77 op = Is;
78 break;
79 case In:
80 op = NotIn;
81 break;
82 case NotIn:
83 op = In;
84 break;
85 // The remaining comparison operators can't be safely inverted
86 case Eq:
87 case NotEq:
88 case Lt:
89 case LtE:
90 case Gt:
91 case GtE:
92 op = 0; // The AST enums leave "0" free as an "unused" marker
93 break;
94 // No default case, so the compiler will emit a warning if new
95 // comparison operators are added without being handled here
96 }
97 if (op) {
98 asdl_seq_SET(arg->v.Compare.ops, 0, op);
99 COPY_NODE(node, arg);
100 return 1;
101 }
102 }
103 return 1;
104 }
105
106 typedef PyObject *(*unary_op)(PyObject*);
107 static const unary_op ops[] = {
108 [Invert] = PyNumber_Invert,
109 [Not] = unary_not,
110 [UAdd] = PyNumber_Positive,
111 [USub] = PyNumber_Negative,
112 };
113 PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
114 return make_const(node, newval, arena);
115 }
116
117 /* Check whether a collection doesn't containing too much items (including
118 subcollections). This protects from creating a constant that needs
119 too much time for calculating a hash.
120 "limit" is the maximal number of items.
121 Returns the negative number if the total number of items exceeds the
122 limit. Otherwise returns the limit minus the total number of items.
123 */
124
125 static Py_ssize_t
126 check_complexity(PyObject *obj, Py_ssize_t limit)
127 {
128 if (PyTuple_Check(obj)) {
129 Py_ssize_t i;
130 limit -= PyTuple_GET_SIZE(obj);
131 for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
132 limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
133 }
134 return limit;
135 }
136 else if (PyFrozenSet_Check(obj)) {
137 Py_ssize_t i = 0;
138 PyObject *item;
139 Py_hash_t hash;
140 limit -= PySet_GET_SIZE(obj);
141 while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
142 limit = check_complexity(item, limit);
143 }
144 }
145 return limit;
146 }
147
/* Upper bounds checked by the safe_*() helpers below: an operation whose
   result would exceed them is not folded. */
#define MAX_INT_SIZE 128  /* bits */
#define MAX_COLLECTION_SIZE 256  /* items */
#define MAX_STR_SIZE 4096  /* characters */
#define MAX_TOTAL_ITEMS 1024  /* including nested collections */
152
153 static PyObject *
154 safe_multiply(PyObject *v, PyObject *w)
155 {
156 if (PyLong_Check(v) && PyLong_Check(w) &&
157 !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
158 ) {
159 size_t vbits = _PyLong_NumBits(v);
160 size_t wbits = _PyLong_NumBits(w);
161 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
162 return NULL;
163 }
164 if (vbits + wbits > MAX_INT_SIZE) {
165 return NULL;
166 }
167 }
168 else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
169 Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
170 PySet_GET_SIZE(w);
171 if (size) {
172 long n = PyLong_AsLong(v);
173 if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
174 return NULL;
175 }
176 if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
177 return NULL;
178 }
179 }
180 }
181 else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
182 Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
183 PyBytes_GET_SIZE(w);
184 if (size) {
185 long n = PyLong_AsLong(v);
186 if (n < 0 || n > MAX_STR_SIZE / size) {
187 return NULL;
188 }
189 }
190 }
191 else if (PyLong_Check(w) &&
192 (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
193 PyUnicode_Check(v) || PyBytes_Check(v)))
194 {
195 return safe_multiply(w, v);
196 }
197
198 return PyNumber_Multiply(v, w);
199 }
200
201 static PyObject *
202 safe_power(PyObject *v, PyObject *w)
203 {
204 if (PyLong_Check(v) && PyLong_Check(w) &&
205 !_PyLong_IsZero((PyLongObject *)v) && _PyLong_IsPositive((PyLongObject *)w)
206 ) {
207 size_t vbits = _PyLong_NumBits(v);
208 size_t wbits = PyLong_AsSize_t(w);
209 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
210 return NULL;
211 }
212 if (vbits > MAX_INT_SIZE / wbits) {
213 return NULL;
214 }
215 }
216
217 return PyNumber_Power(v, w, Py_None);
218 }
219
220 static PyObject *
221 safe_lshift(PyObject *v, PyObject *w)
222 {
223 if (PyLong_Check(v) && PyLong_Check(w) &&
224 !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
225 ) {
226 size_t vbits = _PyLong_NumBits(v);
227 size_t wbits = PyLong_AsSize_t(w);
228 if (vbits == (size_t)-1 || wbits == (size_t)-1) {
229 return NULL;
230 }
231 if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
232 return NULL;
233 }
234 }
235
236 return PyNumber_Lshift(v, w);
237 }
238
239 static PyObject *
240 safe_mod(PyObject *v, PyObject *w)
241 {
242 if (PyUnicode_Check(v) || PyBytes_Check(v)) {
243 return NULL;
244 }
245
246 return PyNumber_Remainder(v, w);
247 }
248
249
250 static expr_ty
251 parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena)
252 {
253 const void *data = PyUnicode_DATA(fmt);
254 int kind = PyUnicode_KIND(fmt);
255 Py_ssize_t size = PyUnicode_GET_LENGTH(fmt);
256 Py_ssize_t start, pos;
257 int has_percents = 0;
258 start = pos = *ppos;
259 while (pos < size) {
260 if (PyUnicode_READ(kind, data, pos) != '%') {
261 pos++;
262 }
263 else if (pos+1 < size && PyUnicode_READ(kind, data, pos+1) == '%') {
264 has_percents = 1;
265 pos += 2;
266 }
267 else {
268 break;
269 }
270 }
271 *ppos = pos;
272 if (pos == start) {
273 return NULL;
274 }
275 PyObject *str = PyUnicode_Substring(fmt, start, pos);
276 /* str = str.replace('%%', '%') */
277 if (str && has_percents) {
278 _Py_DECLARE_STR(percent, "%");
279 _Py_DECLARE_STR(dbl_percent, "%%");
280 Py_SETREF(str, PyUnicode_Replace(str, &_Py_STR(dbl_percent),
281 &_Py_STR(percent), -1));
282 }
283 if (!str) {
284 return NULL;
285 }
286
287 if (_PyArena_AddPyObject(arena, str) < 0) {
288 Py_DECREF(str);
289 return NULL;
290 }
291 return _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
292 }
293
294 #define MAXDIGITS 3
295
296 static int
297 simple_format_arg_parse(PyObject *fmt, Py_ssize_t *ppos,
298 int *spec, int *flags, int *width, int *prec)
299 {
300 Py_ssize_t pos = *ppos, len = PyUnicode_GET_LENGTH(fmt);
301 Py_UCS4 ch;
302
303 #define NEXTC do { \
304 if (pos >= len) { \
305 return 0; \
306 } \
307 ch = PyUnicode_READ_CHAR(fmt, pos); \
308 pos++; \
309 } while (0)
310
311 *flags = 0;
312 while (1) {
313 NEXTC;
314 switch (ch) {
315 case '-': *flags |= F_LJUST; continue;
316 case '+': *flags |= F_SIGN; continue;
317 case ' ': *flags |= F_BLANK; continue;
318 case '#': *flags |= F_ALT; continue;
319 case '0': *flags |= F_ZERO; continue;
320 }
321 break;
322 }
323 if ('0' <= ch && ch <= '9') {
324 *width = 0;
325 int digits = 0;
326 while ('0' <= ch && ch <= '9') {
327 *width = *width * 10 + (ch - '0');
328 NEXTC;
329 if (++digits >= MAXDIGITS) {
330 return 0;
331 }
332 }
333 }
334
335 if (ch == '.') {
336 NEXTC;
337 *prec = 0;
338 if ('0' <= ch && ch <= '9') {
339 int digits = 0;
340 while ('0' <= ch && ch <= '9') {
341 *prec = *prec * 10 + (ch - '0');
342 NEXTC;
343 if (++digits >= MAXDIGITS) {
344 return 0;
345 }
346 }
347 }
348 }
349 *spec = ch;
350 *ppos = pos;
351 return 1;
352
353 #undef NEXTC
354 }
355
356 static expr_ty
357 parse_format(PyObject *fmt, Py_ssize_t *ppos, expr_ty arg, PyArena *arena)
358 {
359 int spec, flags, width = -1, prec = -1;
360 if (!simple_format_arg_parse(fmt, ppos, &spec, &flags, &width, &prec)) {
361 // Unsupported format.
362 return NULL;
363 }
364 if (spec == 's' || spec == 'r' || spec == 'a') {
365 char buf[1 + MAXDIGITS + 1 + MAXDIGITS + 1], *p = buf;
366 if (!(flags & F_LJUST) && width > 0) {
367 *p++ = '>';
368 }
369 if (width >= 0) {
370 p += snprintf(p, MAXDIGITS + 1, "%d", width);
371 }
372 if (prec >= 0) {
373 p += snprintf(p, MAXDIGITS + 2, ".%d", prec);
374 }
375 expr_ty format_spec = NULL;
376 if (p != buf) {
377 PyObject *str = PyUnicode_FromString(buf);
378 if (str == NULL) {
379 return NULL;
380 }
381 if (_PyArena_AddPyObject(arena, str) < 0) {
382 Py_DECREF(str);
383 return NULL;
384 }
385 format_spec = _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
386 if (format_spec == NULL) {
387 return NULL;
388 }
389 }
390 return _PyAST_FormattedValue(arg, spec, format_spec,
391 arg->lineno, arg->col_offset,
392 arg->end_lineno, arg->end_col_offset,
393 arena);
394 }
395 // Unsupported format.
396 return NULL;
397 }
398
399 static int
400 optimize_format(expr_ty node, PyObject *fmt, asdl_expr_seq *elts, PyArena *arena)
401 {
402 Py_ssize_t pos = 0;
403 Py_ssize_t cnt = 0;
404 asdl_expr_seq *seq = _Py_asdl_expr_seq_new(asdl_seq_LEN(elts) * 2 + 1, arena);
405 if (!seq) {
406 return 0;
407 }
408 seq->size = 0;
409
410 while (1) {
411 expr_ty lit = parse_literal(fmt, &pos, arena);
412 if (lit) {
413 asdl_seq_SET(seq, seq->size++, lit);
414 }
415 else if (PyErr_Occurred()) {
416 return 0;
417 }
418
419 if (pos >= PyUnicode_GET_LENGTH(fmt)) {
420 break;
421 }
422 if (cnt >= asdl_seq_LEN(elts)) {
423 // More format units than items.
424 return 1;
425 }
426 assert(PyUnicode_READ_CHAR(fmt, pos) == '%');
427 pos++;
428 expr_ty expr = parse_format(fmt, &pos, asdl_seq_GET(elts, cnt), arena);
429 cnt++;
430 if (!expr) {
431 return !PyErr_Occurred();
432 }
433 asdl_seq_SET(seq, seq->size++, expr);
434 }
435 if (cnt < asdl_seq_LEN(elts)) {
436 // More items than format units.
437 return 1;
438 }
439 expr_ty res = _PyAST_JoinedStr(seq,
440 node->lineno, node->col_offset,
441 node->end_lineno, node->end_col_offset,
442 arena);
443 if (!res) {
444 return 0;
445 }
446 COPY_NODE(node, res);
447 // PySys_FormatStderr("format = %R\n", fmt);
448 return 1;
449 }
450
451 static int
452 fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
453 {
454 expr_ty lhs, rhs;
455 lhs = node->v.BinOp.left;
456 rhs = node->v.BinOp.right;
457 if (lhs->kind != Constant_kind) {
458 return 1;
459 }
460 PyObject *lv = lhs->v.Constant.value;
461
462 if (node->v.BinOp.op == Mod &&
463 rhs->kind == Tuple_kind &&
464 PyUnicode_Check(lv) &&
465 !has_starred(rhs->v.Tuple.elts))
466 {
467 return optimize_format(node, lv, rhs->v.Tuple.elts, arena);
468 }
469
470 if (rhs->kind != Constant_kind) {
471 return 1;
472 }
473
474 PyObject *rv = rhs->v.Constant.value;
475 PyObject *newval = NULL;
476
477 switch (node->v.BinOp.op) {
478 case Add:
479 newval = PyNumber_Add(lv, rv);
480 break;
481 case Sub:
482 newval = PyNumber_Subtract(lv, rv);
483 break;
484 case Mult:
485 newval = safe_multiply(lv, rv);
486 break;
487 case Div:
488 newval = PyNumber_TrueDivide(lv, rv);
489 break;
490 case FloorDiv:
491 newval = PyNumber_FloorDivide(lv, rv);
492 break;
493 case Mod:
494 newval = safe_mod(lv, rv);
495 break;
496 case Pow:
497 newval = safe_power(lv, rv);
498 break;
499 case LShift:
500 newval = safe_lshift(lv, rv);
501 break;
502 case RShift:
503 newval = PyNumber_Rshift(lv, rv);
504 break;
505 case BitOr:
506 newval = PyNumber_Or(lv, rv);
507 break;
508 case BitXor:
509 newval = PyNumber_Xor(lv, rv);
510 break;
511 case BitAnd:
512 newval = PyNumber_And(lv, rv);
513 break;
514 // No builtin constants implement the following operators
515 case MatMult:
516 return 1;
517 // No default case, so the compiler will emit a warning if new binary
518 // operators are added without being handled here
519 }
520
521 return make_const(node, newval, arena);
522 }
523
524 static PyObject*
525 make_const_tuple(asdl_expr_seq *elts)
526 {
527 for (int i = 0; i < asdl_seq_LEN(elts); i++) {
528 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
529 if (e->kind != Constant_kind) {
530 return NULL;
531 }
532 }
533
534 PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
535 if (newval == NULL) {
536 return NULL;
537 }
538
539 for (int i = 0; i < asdl_seq_LEN(elts); i++) {
540 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
541 PyObject *v = e->v.Constant.value;
542 PyTuple_SET_ITEM(newval, i, Py_NewRef(v));
543 }
544 return newval;
545 }
546
547 static int
548 fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
549 {
550 PyObject *newval;
551
552 if (node->v.Tuple.ctx != Load)
553 return 1;
554
555 newval = make_const_tuple(node->v.Tuple.elts);
556 return make_const(node, newval, arena);
557 }
558
559 static int
560 fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
561 {
562 PyObject *newval;
563 expr_ty arg, idx;
564
565 arg = node->v.Subscript.value;
566 idx = node->v.Subscript.slice;
567 if (node->v.Subscript.ctx != Load ||
568 arg->kind != Constant_kind ||
569 idx->kind != Constant_kind)
570 {
571 return 1;
572 }
573
574 newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
575 return make_const(node, newval, arena);
576 }
577
578 /* Change literal list or set of constants into constant
579 tuple or frozenset respectively. Change literal list of
580 non-constants into tuple.
581 Used for right operand of "in" and "not in" tests and for iterable
582 in "for" loop and comprehensions.
583 */
584 static int
585 fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
586 {
587 PyObject *newval;
588 if (arg->kind == List_kind) {
589 /* First change a list into tuple. */
590 asdl_expr_seq *elts = arg->v.List.elts;
591 if (has_starred(elts)) {
592 return 1;
593 }
594 expr_context_ty ctx = arg->v.List.ctx;
595 arg->kind = Tuple_kind;
596 arg->v.Tuple.elts = elts;
597 arg->v.Tuple.ctx = ctx;
598 /* Try to create a constant tuple. */
599 newval = make_const_tuple(elts);
600 }
601 else if (arg->kind == Set_kind) {
602 newval = make_const_tuple(arg->v.Set.elts);
603 if (newval) {
604 Py_SETREF(newval, PyFrozenSet_New(newval));
605 }
606 }
607 else {
608 return 1;
609 }
610 return make_const(arg, newval, arena);
611 }
612
613 static int
614 fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
615 {
616 asdl_int_seq *ops;
617 asdl_expr_seq *args;
618 Py_ssize_t i;
619
620 ops = node->v.Compare.ops;
621 args = node->v.Compare.comparators;
622 /* Change literal list or set in 'in' or 'not in' into
623 tuple or frozenset respectively. */
624 i = asdl_seq_LEN(ops) - 1;
625 int op = asdl_seq_GET(ops, i);
626 if (op == In || op == NotIn) {
627 if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
628 return 0;
629 }
630 }
631 return 1;
632 }
633
/* Forward declarations of the astfold_*() traversal functions.  Each takes
   the node to process, the arena and the optimizer state. */
static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
646
/* Traversal helpers for the astfold_*() functions.

   They expect "ctx_" (the arena) and "state" to be in scope and make the
   enclosing function return 0 as soon as FUNC reports failure.
     CALL      applies FUNC to ARG.
     CALL_OPT  applies FUNC to ARG unless ARG is NULL.
     CALL_SEQ  applies FUNC to every non-NULL element of the sequence ARG.
   TYPE is only used by CALL_SEQ, to name the sequence and element types.

   Each macro expands to a single do { } while (0) statement, so that it
   can be used safely in an unbraced if/else; always follow it with ';'.
*/
#define CALL(FUNC, TYPE, ARG) \
    do { \
        if (!FUNC((ARG), ctx_, state)) { \
            return 0; \
        } \
    } while (0)

#define CALL_OPT(FUNC, TYPE, ARG) \
    do { \
        if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) { \
            return 0; \
        } \
    } while (0)

#define CALL_SEQ(FUNC, TYPE, ARG) \
    do { \
        asdl_ ## TYPE ## _seq *seq = (ARG); /* avoid variable capture */ \
        /* asdl_seq_LEN() yields a Py_ssize_t: index with the same type */ \
        for (Py_ssize_t i = 0; i < asdl_seq_LEN(seq); i++) { \
            TYPE ## _ty elt = (TYPE ## _ty)asdl_seq_GET(seq, i); \
            if (elt != NULL && !FUNC(elt, ctx_, state)) { \
                return 0; \
            } \
        } \
    } while (0)
664
665
666 static int
667 astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
668 {
669 int docstring = _PyAST_GetDocString(stmts) != NULL;
670 CALL_SEQ(astfold_stmt, stmt, stmts);
671 if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
672 stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
673 asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_);
674 if (!values) {
675 return 0;
676 }
677 asdl_seq_SET(values, 0, st->v.Expr.value);
678 expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset,
679 st->end_lineno, st->end_col_offset,
680 ctx_);
681 if (!expr) {
682 return 0;
683 }
684 st->v.Expr.value = expr;
685 }
686 return 1;
687 }
688
/* Fold a top level node: the body of a Module (via astfold_body()), the
   statements of an Interactive node or the expression of an Expression
   node.  Returns 1 on success and 0 on error. */
static int
astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    switch (node_->kind) {
    case Module_kind:
        CALL(astfold_body, asdl_seq, node_->v.Module.body);
        break;
    case Interactive_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body);
        break;
    case Expression_kind:
        CALL(astfold_expr, expr_ty, node_->v.Expression.body);
        break;
    // The following top level nodes don't participate in constant folding
    case FunctionType_kind:
        break;
    // No default case, so the compiler will emit a warning if new top level
    // compilation nodes are added without being handled here
    }
    return 1;
}
710
/* Fold an expression node.

   Subexpressions are folded first; afterwards the node itself is folded
   where a fold_*() helper exists (BinOp, UnaryOp, Compare, Subscript,
   Tuple).  A Name node loading "__debug__" is replaced by a bool constant
   derived from state->optimize.

   state->recursion_depth is incremented on entry and decremented again on
   every successful return; exceeding state->recursion_limit raises
   RecursionError.  On failure 0 is returned without restoring the depth.
   Returns 1 on success.
*/
static int
astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    case BoolOp_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
        break;
    case BinOp_kind:
        CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
        CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
        CALL(fold_binop, expr_ty, node_);
        break;
    case UnaryOp_kind:
        CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
        CALL(fold_unaryop, expr_ty, node_);
        break;
    case Lambda_kind:
        CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
        CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
        break;
    case IfExp_kind:
        CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
        CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
        CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
        break;
    case Dict_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Dict.keys);
        CALL_SEQ(astfold_expr, expr, node_->v.Dict.values);
        break;
    case Set_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Set.elts);
        break;
    case ListComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.ListComp.generators);
        break;
    case SetComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.SetComp.generators);
        break;
    case DictComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
        CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.DictComp.generators);
        break;
    case GeneratorExp_kind:
        CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.GeneratorExp.generators);
        break;
    case Await_kind:
        CALL(astfold_expr, expr_ty, node_->v.Await.value);
        break;
    case Yield_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
        break;
    case YieldFrom_kind:
        CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
        break;
    case Compare_kind:
        CALL(astfold_expr, expr_ty, node_->v.Compare.left);
        CALL_SEQ(astfold_expr, expr, node_->v.Compare.comparators);
        CALL(fold_compare, expr_ty, node_);
        break;
    case Call_kind:
        CALL(astfold_expr, expr_ty, node_->v.Call.func);
        CALL_SEQ(astfold_expr, expr, node_->v.Call.args);
        CALL_SEQ(astfold_keyword, keyword, node_->v.Call.keywords);
        break;
    case FormattedValue_kind:
        CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
        CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
        break;
    case JoinedStr_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.JoinedStr.values);
        break;
    case Attribute_kind:
        CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
        break;
    case Subscript_kind:
        CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
        CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
        CALL(fold_subscr, expr_ty, node_);
        break;
    case Starred_kind:
        CALL(astfold_expr, expr_ty, node_->v.Starred.value);
        break;
    case Slice_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
        break;
    case List_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.List.elts);
        break;
    case Tuple_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Tuple.elts);
        CALL(fold_tuple, expr_ty, node_);
        break;
    case Name_kind:
        if (node_->v.Name.ctx == Load &&
                _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
            // This path returns directly, so the depth counter has to be
            // restored here instead of at the end of the function.
            state->recursion_depth--;
            return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
        }
        break;
    case NamedExpr_kind:
        CALL(astfold_expr, expr_ty, node_->v.NamedExpr.value);
        break;
    case Constant_kind:
        // Already a constant, nothing further to do
        break;
    // No default case, so the compiler will emit a warning if new expression
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
833
834 static int
835 astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
836 {
837 CALL(astfold_expr, expr_ty, node_->value);
838 return 1;
839 }
840
841 static int
842 astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
843 {
844 CALL(astfold_expr, expr_ty, node_->target);
845 CALL(astfold_expr, expr_ty, node_->iter);
846 CALL_SEQ(astfold_expr, expr, node_->ifs);
847
848 CALL(fold_iter, expr_ty, node_->iter);
849 return 1;
850 }
851
/* Fold everything foldable in an argument list: each argument is passed
   to astfold_arg(), and the default values (kw_defaults and defaults) are
   folded as expressions.
   Returns 1 on success and 0 on error. */
static int
astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    CALL_SEQ(astfold_arg, arg, node_->posonlyargs);
    CALL_SEQ(astfold_arg, arg, node_->args);
    CALL_OPT(astfold_arg, arg_ty, node_->vararg);
    CALL_SEQ(astfold_arg, arg, node_->kwonlyargs);
    CALL_SEQ(astfold_expr, expr, node_->kw_defaults);
    CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
    CALL_SEQ(astfold_expr, expr, node_->defaults);
    return 1;
}
864
865 static int
866 astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
867 {
868 if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
869 CALL_OPT(astfold_expr, expr_ty, node_->annotation);
870 }
871 return 1;
872 }
873
/* Fold a statement node by folding all expressions and statements it
   contains.

   Function and class bodies go through astfold_body().  Annotations
   (return annotations and AnnAssign annotations) are skipped when
   CO_FUTURE_ANNOTATIONS is set in state->ff_features.

   state->recursion_depth is incremented on entry and decremented again on
   successful return; exceeding state->recursion_limit raises
   RecursionError.  On failure 0 is returned without restoring the depth.
   Returns 1 on success.
*/
static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    case FunctionDef_kind:
        CALL_SEQ(astfold_type_param, type_param, node_->v.FunctionDef.type_params);
        CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
        CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
        }
        break;
    case AsyncFunctionDef_kind:
        CALL_SEQ(astfold_type_param, type_param, node_->v.AsyncFunctionDef.type_params);
        CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
        CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
        }
        break;
    case ClassDef_kind:
        CALL_SEQ(astfold_type_param, type_param, node_->v.ClassDef.type_params);
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases);
        CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords);
        CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.decorator_list);
        break;
    case Return_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
        break;
    case Delete_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Delete.targets);
        break;
    case Assign_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Assign.targets);
        CALL(astfold_expr, expr_ty, node_->v.Assign.value);
        break;
    case AugAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
        break;
    case AnnAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
        }
        CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
        break;
    case TypeAlias_kind:
        CALL(astfold_expr, expr_ty, node_->v.TypeAlias.name);
        CALL_SEQ(astfold_type_param, type_param, node_->v.TypeAlias.type_params);
        CALL(astfold_expr, expr_ty, node_->v.TypeAlias.value);
        break;
    case For_kind:
        CALL(astfold_expr, expr_ty, node_->v.For.target);
        CALL(astfold_expr, expr_ty, node_->v.For.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.orelse);

        // Turn a literal list or set used as the iterable into a tuple or
        // frozenset (see fold_iter())
        CALL(fold_iter, expr_ty, node_->v.For.iter);
        break;
    case AsyncFor_kind:
        // Unlike For, fold_iter() is not applied to the iterable here
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.orelse);
        break;
    case While_kind:
        CALL(astfold_expr, expr_ty, node_->v.While.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.orelse);
        break;
    case If_kind:
        CALL(astfold_expr, expr_ty, node_->v.If.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
        break;
    case With_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.With.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.With.body);
        break;
    case AsyncWith_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.AsyncWith.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncWith.body);
        break;
    case Raise_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
        break;
    case Try_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.Try.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.finalbody);
        break;
    case TryStar_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.TryStar.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.finalbody);
        break;
    case Assert_kind:
        CALL(astfold_expr, expr_ty, node_->v.Assert.test);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
        break;
    case Expr_kind:
        CALL(astfold_expr, expr_ty, node_->v.Expr.value);
        break;
    case Match_kind:
        CALL(astfold_expr, expr_ty, node_->v.Match.subject);
        CALL_SEQ(astfold_match_case, match_case, node_->v.Match.cases);
        break;
    // The following statements don't contain any subexpressions to be folded
    case Import_kind:
    case ImportFrom_kind:
    case Global_kind:
    case Nonlocal_kind:
    case Pass_kind:
    case Break_kind:
    case Continue_kind:
        break;
    // No default case, so the compiler will emit a warning if new statement
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
1008
/* Fold the optional exception type expression and the body of an except
   handler.  Returns 1 on success and 0 on error. */
static int
astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    switch (node_->kind) {
    case ExceptHandler_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
        CALL_SEQ(astfold_stmt, stmt, node_->v.ExceptHandler.body);
        break;
    // No default case, so the compiler will emit a warning if new handler
    // kinds are added without being handled here
    }
    return 1;
}
1022
1023 static int
1024 astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1025 {
1026 CALL(astfold_expr, expr_ty, node_->context_expr);
1027 CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
1028 return 1;
1029 }
1030
/* Fold the expressions and subpatterns contained in a match pattern.

   state->recursion_depth is incremented on entry and decremented again on
   successful return; exceeding state->recursion_limit raises
   RecursionError.  On failure 0 is returned without restoring the depth.
   Returns 1 on success.
*/
static int
astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    // Currently, this is really only used to form complex/negative numeric
    // constants in MatchValue and MatchMapping nodes
    // We still recurse into all subexpressions and subpatterns anyway
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    case MatchValue_kind:
        CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
        break;
    case MatchSingleton_kind:
        break;
    case MatchSequence_kind:
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
        break;
    case MatchMapping_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
        break;
    case MatchClass_kind:
        CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls);
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns);
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns);
        break;
    case MatchStar_kind:
        break;
    case MatchAs_kind:
        if (node_->v.MatchAs.pattern) {
            CALL(astfold_pattern, pattern_ty, node_->v.MatchAs.pattern);
        }
        break;
    case MatchOr_kind:
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns);
        break;
    // No default case, so the compiler will emit a warning if new pattern
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
1076
1077 static int
1078 astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1079 {
1080 CALL(astfold_pattern, expr_ty, node_->pattern);
1081 CALL_OPT(astfold_expr, expr_ty, node_->guard);
1082 CALL_SEQ(astfold_stmt, stmt, node_->body);
1083 return 1;
1084 }
1085
/* Fold the optional bound expression of a TypeVar.  ParamSpec and
   TypeVarTuple nodes are left untouched.
   Returns 1 on success and 0 on error. */
static int
astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    switch (node_->kind) {
    case TypeVar_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.bound);
        break;
    case ParamSpec_kind:
        break;
    case TypeVarTuple_kind:
        break;
    }
    return 1;
}
1100
/* The traversal helper macros are not used past this point. */
#undef CALL
#undef CALL_OPT
#undef CALL_SEQ
1104
1105 /* See comments in symtable.c. */
1106 #define COMPILER_STACK_FRAME_SCALE 2
1107
1108 int
1109 _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
1110 {
1111 PyThreadState *tstate;
1112 int starting_recursion_depth;
1113
1114 /* Setup recursion depth check counters */
1115 tstate = _PyThreadState_GET();
1116 if (!tstate) {
1117 return 0;
1118 }
1119 /* Be careful here to prevent overflow. */
1120 int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
1121 starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
1122 state->recursion_depth = starting_recursion_depth;
1123 state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
1124
1125 int ret = astfold_mod(mod, arena, state);
1126 assert(ret || PyErr_Occurred());
1127
1128 /* Check that the recursion depth counting balanced correctly */
1129 if (ret && state->recursion_depth != starting_recursion_depth) {
1130 PyErr_Format(PyExc_SystemError,
1131 "AST optimizer recursion depth mismatch (before=%d, after=%d)",
1132 starting_recursion_depth, state->recursion_depth);
1133 return 0;
1134 }
1135
1136 return ret;
1137 }