(root)/
Python-3.12.0/
Python/
ast_opt.c
       1  /* AST Optimizer */
       2  #include "Python.h"
       3  #include "pycore_ast.h"           // _PyAST_GetDocString()
       4  #include "pycore_compile.h"       // _PyASTOptimizeState
       5  #include "pycore_long.h"           // _PyLong
       6  #include "pycore_pystate.h"       // _PyThreadState_GET()
       7  #include "pycore_format.h"        // F_LJUST
       8  
       9  
      10  static int
      11  make_const(expr_ty node, PyObject *val, PyArena *arena)
      12  {
      13      // Even if no new value was calculated, make_const may still
      14      // need to clear an error (e.g. for division by zero)
      15      if (val == NULL) {
      16          if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
      17              return 0;
      18          }
      19          PyErr_Clear();
      20          return 1;
      21      }
      22      if (_PyArena_AddPyObject(arena, val) < 0) {
      23          Py_DECREF(val);
      24          return 0;
      25      }
      26      node->kind = Constant_kind;
      27      node->v.Constant.kind = NULL;
      28      node->v.Constant.value = val;
      29      return 1;
      30  }
      31  
      32  #define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(struct _expr)))
      33  
      34  static int
      35  has_starred(asdl_expr_seq *elts)
      36  {
      37      Py_ssize_t n = asdl_seq_LEN(elts);
      38      for (Py_ssize_t i = 0; i < n; i++) {
      39          expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
      40          if (e->kind == Starred_kind) {
      41              return 1;
      42          }
      43      }
      44      return 0;
      45  }
      46  
      47  
      48  static PyObject*
      49  unary_not(PyObject *v)
      50  {
      51      int r = PyObject_IsTrue(v);
      52      if (r < 0)
      53          return NULL;
      54      return PyBool_FromLong(!r);
      55  }
      56  
      57  static int
      58  fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
      59  {
      60      expr_ty arg = node->v.UnaryOp.operand;
      61  
      62      if (arg->kind != Constant_kind) {
      63          /* Fold not into comparison */
      64          if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
      65                  asdl_seq_LEN(arg->v.Compare.ops) == 1) {
      66              /* Eq and NotEq are often implemented in terms of one another, so
      67                 folding not (self == other) into self != other breaks implementation
      68                 of !=. Detecting such cases doesn't seem worthwhile.
      69                 Python uses </> for 'is subset'/'is superset' operations on sets.
      70                 They don't satisfy not folding laws. */
      71              cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
      72              switch (op) {
      73              case Is:
      74                  op = IsNot;
      75                  break;
      76              case IsNot:
      77                  op = Is;
      78                  break;
      79              case In:
      80                  op = NotIn;
      81                  break;
      82              case NotIn:
      83                  op = In;
      84                  break;
      85              // The remaining comparison operators can't be safely inverted
      86              case Eq:
      87              case NotEq:
      88              case Lt:
      89              case LtE:
      90              case Gt:
      91              case GtE:
      92                  op = 0; // The AST enums leave "0" free as an "unused" marker
      93                  break;
      94              // No default case, so the compiler will emit a warning if new
      95              // comparison operators are added without being handled here
      96              }
      97              if (op) {
      98                  asdl_seq_SET(arg->v.Compare.ops, 0, op);
      99                  COPY_NODE(node, arg);
     100                  return 1;
     101              }
     102          }
     103          return 1;
     104      }
     105  
     106      typedef PyObject *(*unary_op)(PyObject*);
     107      static const unary_op ops[] = {
     108          [Invert] = PyNumber_Invert,
     109          [Not] = unary_not,
     110          [UAdd] = PyNumber_Positive,
     111          [USub] = PyNumber_Negative,
     112      };
     113      PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
     114      return make_const(node, newval, arena);
     115  }
     116  
     117  /* Check whether a collection doesn't containing too much items (including
     118     subcollections).  This protects from creating a constant that needs
     119     too much time for calculating a hash.
     120     "limit" is the maximal number of items.
     121     Returns the negative number if the total number of items exceeds the
     122     limit.  Otherwise returns the limit minus the total number of items.
     123  */
     124  
     125  static Py_ssize_t
     126  check_complexity(PyObject *obj, Py_ssize_t limit)
     127  {
     128      if (PyTuple_Check(obj)) {
     129          Py_ssize_t i;
     130          limit -= PyTuple_GET_SIZE(obj);
     131          for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
     132              limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
     133          }
     134          return limit;
     135      }
     136      else if (PyFrozenSet_Check(obj)) {
     137          Py_ssize_t i = 0;
     138          PyObject *item;
     139          Py_hash_t hash;
     140          limit -= PySet_GET_SIZE(obj);
     141          while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
     142              limit = check_complexity(item, limit);
     143          }
     144      }
     145      return limit;
     146  }
     147  
     148  #define MAX_INT_SIZE           128  /* bits */
     149  #define MAX_COLLECTION_SIZE    256  /* items */
     150  #define MAX_STR_SIZE          4096  /* characters */
     151  #define MAX_TOTAL_ITEMS       1024  /* including nested collections */
     152  
     153  static PyObject *
     154  safe_multiply(PyObject *v, PyObject *w)
     155  {
     156      if (PyLong_Check(v) && PyLong_Check(w) &&
     157          !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
     158      ) {
     159          size_t vbits = _PyLong_NumBits(v);
     160          size_t wbits = _PyLong_NumBits(w);
     161          if (vbits == (size_t)-1 || wbits == (size_t)-1) {
     162              return NULL;
     163          }
     164          if (vbits + wbits > MAX_INT_SIZE) {
     165              return NULL;
     166          }
     167      }
     168      else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
     169          Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
     170                                               PySet_GET_SIZE(w);
     171          if (size) {
     172              long n = PyLong_AsLong(v);
     173              if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
     174                  return NULL;
     175              }
     176              if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
     177                  return NULL;
     178              }
     179          }
     180      }
     181      else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
     182          Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
     183                                                 PyBytes_GET_SIZE(w);
     184          if (size) {
     185              long n = PyLong_AsLong(v);
     186              if (n < 0 || n > MAX_STR_SIZE / size) {
     187                  return NULL;
     188              }
     189          }
     190      }
     191      else if (PyLong_Check(w) &&
     192               (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
     193                PyUnicode_Check(v) || PyBytes_Check(v)))
     194      {
     195          return safe_multiply(w, v);
     196      }
     197  
     198      return PyNumber_Multiply(v, w);
     199  }
     200  
     201  static PyObject *
     202  safe_power(PyObject *v, PyObject *w)
     203  {
     204      if (PyLong_Check(v) && PyLong_Check(w) &&
     205          !_PyLong_IsZero((PyLongObject *)v) && _PyLong_IsPositive((PyLongObject *)w)
     206      ) {
     207          size_t vbits = _PyLong_NumBits(v);
     208          size_t wbits = PyLong_AsSize_t(w);
     209          if (vbits == (size_t)-1 || wbits == (size_t)-1) {
     210              return NULL;
     211          }
     212          if (vbits > MAX_INT_SIZE / wbits) {
     213              return NULL;
     214          }
     215      }
     216  
     217      return PyNumber_Power(v, w, Py_None);
     218  }
     219  
     220  static PyObject *
     221  safe_lshift(PyObject *v, PyObject *w)
     222  {
     223      if (PyLong_Check(v) && PyLong_Check(w) &&
     224          !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
     225      ) {
     226          size_t vbits = _PyLong_NumBits(v);
     227          size_t wbits = PyLong_AsSize_t(w);
     228          if (vbits == (size_t)-1 || wbits == (size_t)-1) {
     229              return NULL;
     230          }
     231          if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
     232              return NULL;
     233          }
     234      }
     235  
     236      return PyNumber_Lshift(v, w);
     237  }
     238  
     239  static PyObject *
     240  safe_mod(PyObject *v, PyObject *w)
     241  {
     242      if (PyUnicode_Check(v) || PyBytes_Check(v)) {
     243          return NULL;
     244      }
     245  
     246      return PyNumber_Remainder(v, w);
     247  }
     248  
     249  
     250  static expr_ty
     251  parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena)
     252  {
     253      const void *data = PyUnicode_DATA(fmt);
     254      int kind = PyUnicode_KIND(fmt);
     255      Py_ssize_t size = PyUnicode_GET_LENGTH(fmt);
     256      Py_ssize_t start, pos;
     257      int has_percents = 0;
     258      start = pos = *ppos;
     259      while (pos < size) {
     260          if (PyUnicode_READ(kind, data, pos) != '%') {
     261              pos++;
     262          }
     263          else if (pos+1 < size && PyUnicode_READ(kind, data, pos+1) == '%') {
     264              has_percents = 1;
     265              pos += 2;
     266          }
     267          else {
     268              break;
     269          }
     270      }
     271      *ppos = pos;
     272      if (pos == start) {
     273          return NULL;
     274      }
     275      PyObject *str = PyUnicode_Substring(fmt, start, pos);
     276      /* str = str.replace('%%', '%') */
     277      if (str && has_percents) {
     278          _Py_DECLARE_STR(percent, "%");
     279          _Py_DECLARE_STR(dbl_percent, "%%");
     280          Py_SETREF(str, PyUnicode_Replace(str, &_Py_STR(dbl_percent),
     281                                           &_Py_STR(percent), -1));
     282      }
     283      if (!str) {
     284          return NULL;
     285      }
     286  
     287      if (_PyArena_AddPyObject(arena, str) < 0) {
     288          Py_DECREF(str);
     289          return NULL;
     290      }
     291      return _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
     292  }
     293  
     294  #define MAXDIGITS 3
     295  
     296  static int
     297  simple_format_arg_parse(PyObject *fmt, Py_ssize_t *ppos,
     298                          int *spec, int *flags, int *width, int *prec)
     299  {
     300      Py_ssize_t pos = *ppos, len = PyUnicode_GET_LENGTH(fmt);
     301      Py_UCS4 ch;
     302  
     303  #define NEXTC do {                      \
     304      if (pos >= len) {                   \
     305          return 0;                       \
     306      }                                   \
     307      ch = PyUnicode_READ_CHAR(fmt, pos); \
     308      pos++;                              \
     309  } while (0)
     310  
     311      *flags = 0;
     312      while (1) {
     313          NEXTC;
     314          switch (ch) {
     315              case '-': *flags |= F_LJUST; continue;
     316              case '+': *flags |= F_SIGN; continue;
     317              case ' ': *flags |= F_BLANK; continue;
     318              case '#': *flags |= F_ALT; continue;
     319              case '0': *flags |= F_ZERO; continue;
     320          }
     321          break;
     322      }
     323      if ('0' <= ch && ch <= '9') {
     324          *width = 0;
     325          int digits = 0;
     326          while ('0' <= ch && ch <= '9') {
     327              *width = *width * 10 + (ch - '0');
     328              NEXTC;
     329              if (++digits >= MAXDIGITS) {
     330                  return 0;
     331              }
     332          }
     333      }
     334  
     335      if (ch == '.') {
     336          NEXTC;
     337          *prec = 0;
     338          if ('0' <= ch && ch <= '9') {
     339              int digits = 0;
     340              while ('0' <= ch && ch <= '9') {
     341                  *prec = *prec * 10 + (ch - '0');
     342                  NEXTC;
     343                  if (++digits >= MAXDIGITS) {
     344                      return 0;
     345                  }
     346              }
     347          }
     348      }
     349      *spec = ch;
     350      *ppos = pos;
     351      return 1;
     352  
     353  #undef NEXTC
     354  }
     355  
     356  static expr_ty
     357  parse_format(PyObject *fmt, Py_ssize_t *ppos, expr_ty arg, PyArena *arena)
     358  {
     359      int spec, flags, width = -1, prec = -1;
     360      if (!simple_format_arg_parse(fmt, ppos, &spec, &flags, &width, &prec)) {
     361          // Unsupported format.
     362          return NULL;
     363      }
     364      if (spec == 's' || spec == 'r' || spec == 'a') {
     365          char buf[1 + MAXDIGITS + 1 + MAXDIGITS + 1], *p = buf;
     366          if (!(flags & F_LJUST) && width > 0) {
     367              *p++ = '>';
     368          }
     369          if (width >= 0) {
     370              p += snprintf(p, MAXDIGITS + 1, "%d", width);
     371          }
     372          if (prec >= 0) {
     373              p += snprintf(p, MAXDIGITS + 2, ".%d", prec);
     374          }
     375          expr_ty format_spec = NULL;
     376          if (p != buf) {
     377              PyObject *str = PyUnicode_FromString(buf);
     378              if (str == NULL) {
     379                  return NULL;
     380              }
     381              if (_PyArena_AddPyObject(arena, str) < 0) {
     382                  Py_DECREF(str);
     383                  return NULL;
     384              }
     385              format_spec = _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
     386              if (format_spec == NULL) {
     387                  return NULL;
     388              }
     389          }
     390          return _PyAST_FormattedValue(arg, spec, format_spec,
     391                                       arg->lineno, arg->col_offset,
     392                                       arg->end_lineno, arg->end_col_offset,
     393                                       arena);
     394      }
     395      // Unsupported format.
     396      return NULL;
     397  }
     398  
     399  static int
     400  optimize_format(expr_ty node, PyObject *fmt, asdl_expr_seq *elts, PyArena *arena)
     401  {
     402      Py_ssize_t pos = 0;
     403      Py_ssize_t cnt = 0;
     404      asdl_expr_seq *seq = _Py_asdl_expr_seq_new(asdl_seq_LEN(elts) * 2 + 1, arena);
     405      if (!seq) {
     406          return 0;
     407      }
     408      seq->size = 0;
     409  
     410      while (1) {
     411          expr_ty lit = parse_literal(fmt, &pos, arena);
     412          if (lit) {
     413              asdl_seq_SET(seq, seq->size++, lit);
     414          }
     415          else if (PyErr_Occurred()) {
     416              return 0;
     417          }
     418  
     419          if (pos >= PyUnicode_GET_LENGTH(fmt)) {
     420              break;
     421          }
     422          if (cnt >= asdl_seq_LEN(elts)) {
     423              // More format units than items.
     424              return 1;
     425          }
     426          assert(PyUnicode_READ_CHAR(fmt, pos) == '%');
     427          pos++;
     428          expr_ty expr = parse_format(fmt, &pos, asdl_seq_GET(elts, cnt), arena);
     429          cnt++;
     430          if (!expr) {
     431              return !PyErr_Occurred();
     432          }
     433          asdl_seq_SET(seq, seq->size++, expr);
     434      }
     435      if (cnt < asdl_seq_LEN(elts)) {
     436          // More items than format units.
     437          return 1;
     438      }
     439      expr_ty res = _PyAST_JoinedStr(seq,
     440                                     node->lineno, node->col_offset,
     441                                     node->end_lineno, node->end_col_offset,
     442                                     arena);
     443      if (!res) {
     444          return 0;
     445      }
     446      COPY_NODE(node, res);
     447  //     PySys_FormatStderr("format = %R\n", fmt);
     448      return 1;
     449  }
     450  
     451  static int
     452  fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
     453  {
     454      expr_ty lhs, rhs;
     455      lhs = node->v.BinOp.left;
     456      rhs = node->v.BinOp.right;
     457      if (lhs->kind != Constant_kind) {
     458          return 1;
     459      }
     460      PyObject *lv = lhs->v.Constant.value;
     461  
     462      if (node->v.BinOp.op == Mod &&
     463          rhs->kind == Tuple_kind &&
     464          PyUnicode_Check(lv) &&
     465          !has_starred(rhs->v.Tuple.elts))
     466      {
     467          return optimize_format(node, lv, rhs->v.Tuple.elts, arena);
     468      }
     469  
     470      if (rhs->kind != Constant_kind) {
     471          return 1;
     472      }
     473  
     474      PyObject *rv = rhs->v.Constant.value;
     475      PyObject *newval = NULL;
     476  
     477      switch (node->v.BinOp.op) {
     478      case Add:
     479          newval = PyNumber_Add(lv, rv);
     480          break;
     481      case Sub:
     482          newval = PyNumber_Subtract(lv, rv);
     483          break;
     484      case Mult:
     485          newval = safe_multiply(lv, rv);
     486          break;
     487      case Div:
     488          newval = PyNumber_TrueDivide(lv, rv);
     489          break;
     490      case FloorDiv:
     491          newval = PyNumber_FloorDivide(lv, rv);
     492          break;
     493      case Mod:
     494          newval = safe_mod(lv, rv);
     495          break;
     496      case Pow:
     497          newval = safe_power(lv, rv);
     498          break;
     499      case LShift:
     500          newval = safe_lshift(lv, rv);
     501          break;
     502      case RShift:
     503          newval = PyNumber_Rshift(lv, rv);
     504          break;
     505      case BitOr:
     506          newval = PyNumber_Or(lv, rv);
     507          break;
     508      case BitXor:
     509          newval = PyNumber_Xor(lv, rv);
     510          break;
     511      case BitAnd:
     512          newval = PyNumber_And(lv, rv);
     513          break;
     514      // No builtin constants implement the following operators
     515      case MatMult:
     516          return 1;
     517      // No default case, so the compiler will emit a warning if new binary
     518      // operators are added without being handled here
     519      }
     520  
     521      return make_const(node, newval, arena);
     522  }
     523  
     524  static PyObject*
     525  make_const_tuple(asdl_expr_seq *elts)
     526  {
     527      for (int i = 0; i < asdl_seq_LEN(elts); i++) {
     528          expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
     529          if (e->kind != Constant_kind) {
     530              return NULL;
     531          }
     532      }
     533  
     534      PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
     535      if (newval == NULL) {
     536          return NULL;
     537      }
     538  
     539      for (int i = 0; i < asdl_seq_LEN(elts); i++) {
     540          expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
     541          PyObject *v = e->v.Constant.value;
     542          PyTuple_SET_ITEM(newval, i, Py_NewRef(v));
     543      }
     544      return newval;
     545  }
     546  
     547  static int
     548  fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
     549  {
     550      PyObject *newval;
     551  
     552      if (node->v.Tuple.ctx != Load)
     553          return 1;
     554  
     555      newval = make_const_tuple(node->v.Tuple.elts);
     556      return make_const(node, newval, arena);
     557  }
     558  
     559  static int
     560  fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
     561  {
     562      PyObject *newval;
     563      expr_ty arg, idx;
     564  
     565      arg = node->v.Subscript.value;
     566      idx = node->v.Subscript.slice;
     567      if (node->v.Subscript.ctx != Load ||
     568              arg->kind != Constant_kind ||
     569              idx->kind != Constant_kind)
     570      {
     571          return 1;
     572      }
     573  
     574      newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
     575      return make_const(node, newval, arena);
     576  }
     577  
     578  /* Change literal list or set of constants into constant
     579     tuple or frozenset respectively.  Change literal list of
     580     non-constants into tuple.
     581     Used for right operand of "in" and "not in" tests and for iterable
     582     in "for" loop and comprehensions.
     583  */
     584  static int
     585  fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
     586  {
     587      PyObject *newval;
     588      if (arg->kind == List_kind) {
     589          /* First change a list into tuple. */
     590          asdl_expr_seq *elts = arg->v.List.elts;
     591          if (has_starred(elts)) {
     592              return 1;
     593          }
     594          expr_context_ty ctx = arg->v.List.ctx;
     595          arg->kind = Tuple_kind;
     596          arg->v.Tuple.elts = elts;
     597          arg->v.Tuple.ctx = ctx;
     598          /* Try to create a constant tuple. */
     599          newval = make_const_tuple(elts);
     600      }
     601      else if (arg->kind == Set_kind) {
     602          newval = make_const_tuple(arg->v.Set.elts);
     603          if (newval) {
     604              Py_SETREF(newval, PyFrozenSet_New(newval));
     605          }
     606      }
     607      else {
     608          return 1;
     609      }
     610      return make_const(arg, newval, arena);
     611  }
     612  
     613  static int
     614  fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
     615  {
     616      asdl_int_seq *ops;
     617      asdl_expr_seq *args;
     618      Py_ssize_t i;
     619  
     620      ops = node->v.Compare.ops;
     621      args = node->v.Compare.comparators;
     622      /* Change literal list or set in 'in' or 'not in' into
     623         tuple or frozenset respectively. */
     624      i = asdl_seq_LEN(ops) - 1;
     625      int op = asdl_seq_GET(ops, i);
     626      if (op == In || op == NotIn) {
     627          if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
     628              return 0;
     629          }
     630      }
     631      return 1;
     632  }
     633  
     634  static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     635  static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     636  static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     637  static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     638  static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     639  static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     640  static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     641  static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     642  static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     643  static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     644  static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     645  static int astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
     646  
     647  #define CALL(FUNC, TYPE, ARG) \
     648      if (!FUNC((ARG), ctx_, state)) \
     649          return 0;
     650  
     651  #define CALL_OPT(FUNC, TYPE, ARG) \
     652      if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) \
     653          return 0;
     654  
     655  #define CALL_SEQ(FUNC, TYPE, ARG) { \
     656      int i; \
     657      asdl_ ## TYPE ## _seq *seq = (ARG); /* avoid variable capture */ \
     658      for (i = 0; i < asdl_seq_LEN(seq); i++) { \
     659          TYPE ## _ty elt = (TYPE ## _ty)asdl_seq_GET(seq, i); \
     660          if (elt != NULL && !FUNC(elt, ctx_, state)) \
     661              return 0; \
     662      } \
     663  }
     664  
     665  
     666  static int
     667  astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
     668  {
     669      int docstring = _PyAST_GetDocString(stmts) != NULL;
     670      CALL_SEQ(astfold_stmt, stmt, stmts);
     671      if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
     672          stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
     673          asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_);
     674          if (!values) {
     675              return 0;
     676          }
     677          asdl_seq_SET(values, 0, st->v.Expr.value);
     678          expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset,
     679                                          st->end_lineno, st->end_col_offset,
     680                                          ctx_);
     681          if (!expr) {
     682              return 0;
     683          }
     684          st->v.Expr.value = expr;
     685      }
     686      return 1;
     687  }
     688  
     689  static int
     690  astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
     691  {
     692      switch (node_->kind) {
     693      case Module_kind:
     694          CALL(astfold_body, asdl_seq, node_->v.Module.body);
     695          break;
     696      case Interactive_kind:
     697          CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body);
     698          break;
     699      case Expression_kind:
     700          CALL(astfold_expr, expr_ty, node_->v.Expression.body);
     701          break;
     702      // The following top level nodes don't participate in constant folding
     703      case FunctionType_kind:
     704          break;
     705      // No default case, so the compiler will emit a warning if new top level
     706      // compilation nodes are added without being handled here
     707      }
     708      return 1;
     709  }
     710  
     711  static int
     712  astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
     713  {
     714      if (++state->recursion_depth > state->recursion_limit) {
     715          PyErr_SetString(PyExc_RecursionError,
     716                          "maximum recursion depth exceeded during compilation");
     717          return 0;
     718      }
     719      switch (node_->kind) {
     720      case BoolOp_kind:
     721          CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
     722          break;
     723      case BinOp_kind:
     724          CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
     725          CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
     726          CALL(fold_binop, expr_ty, node_);
     727          break;
     728      case UnaryOp_kind:
     729          CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
     730          CALL(fold_unaryop, expr_ty, node_);
     731          break;
     732      case Lambda_kind:
     733          CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
     734          CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
     735          break;
     736      case IfExp_kind:
     737          CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
     738          CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
     739          CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
     740          break;
     741      case Dict_kind:
     742          CALL_SEQ(astfold_expr, expr, node_->v.Dict.keys);
     743          CALL_SEQ(astfold_expr, expr, node_->v.Dict.values);
     744          break;
     745      case Set_kind:
     746          CALL_SEQ(astfold_expr, expr, node_->v.Set.elts);
     747          break;
     748      case ListComp_kind:
     749          CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
     750          CALL_SEQ(astfold_comprehension, comprehension, node_->v.ListComp.generators);
     751          break;
     752      case SetComp_kind:
     753          CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
     754          CALL_SEQ(astfold_comprehension, comprehension, node_->v.SetComp.generators);
     755          break;
     756      case DictComp_kind:
     757          CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
     758          CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
     759          CALL_SEQ(astfold_comprehension, comprehension, node_->v.DictComp.generators);
     760          break;
     761      case GeneratorExp_kind:
     762          CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
     763          CALL_SEQ(astfold_comprehension, comprehension, node_->v.GeneratorExp.generators);
     764          break;
     765      case Await_kind:
     766          CALL(astfold_expr, expr_ty, node_->v.Await.value);
     767          break;
     768      case Yield_kind:
     769          CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
     770          break;
     771      case YieldFrom_kind:
     772          CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
     773          break;
     774      case Compare_kind:
     775          CALL(astfold_expr, expr_ty, node_->v.Compare.left);
     776          CALL_SEQ(astfold_expr, expr, node_->v.Compare.comparators);
     777          CALL(fold_compare, expr_ty, node_);
     778          break;
     779      case Call_kind:
     780          CALL(astfold_expr, expr_ty, node_->v.Call.func);
     781          CALL_SEQ(astfold_expr, expr, node_->v.Call.args);
     782          CALL_SEQ(astfold_keyword, keyword, node_->v.Call.keywords);
     783          break;
     784      case FormattedValue_kind:
     785          CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
     786          CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
     787          break;
     788      case JoinedStr_kind:
     789          CALL_SEQ(astfold_expr, expr, node_->v.JoinedStr.values);
     790          break;
     791      case Attribute_kind:
     792          CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
     793          break;
     794      case Subscript_kind:
     795          CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
     796          CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
     797          CALL(fold_subscr, expr_ty, node_);
     798          break;
     799      case Starred_kind:
     800          CALL(astfold_expr, expr_ty, node_->v.Starred.value);
     801          break;
     802      case Slice_kind:
     803          CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
     804          CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
     805          CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
     806          break;
     807      case List_kind:
     808          CALL_SEQ(astfold_expr, expr, node_->v.List.elts);
     809          break;
     810      case Tuple_kind:
     811          CALL_SEQ(astfold_expr, expr, node_->v.Tuple.elts);
     812          CALL(fold_tuple, expr_ty, node_);
     813          break;
     814      case Name_kind:
     815          if (node_->v.Name.ctx == Load &&
     816                  _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
     817              state->recursion_depth--;
     818              return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
     819          }
     820          break;
     821      case NamedExpr_kind:
     822          CALL(astfold_expr, expr_ty, node_->v.NamedExpr.value);
     823          break;
     824      case Constant_kind:
     825          // Already a constant, nothing further to do
     826          break;
     827      // No default case, so the compiler will emit a warning if new expression
     828      // kinds are added without being handled here
     829      }
     830      state->recursion_depth--;
     831      return 1;
     832  }
     833  
     834  static int
     835  astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
     836  {
     837      CALL(astfold_expr, expr_ty, node_->value);
     838      return 1;
     839  }
     840  
     841  static int
     842  astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
     843  {
     844      CALL(astfold_expr, expr_ty, node_->target);
     845      CALL(astfold_expr, expr_ty, node_->iter);
     846      CALL_SEQ(astfold_expr, expr, node_->ifs);
     847  
     848      CALL(fold_iter, expr_ty, node_->iter);
     849      return 1;
     850  }
     851  
     852  static int
     853  astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
     854  {
     855      CALL_SEQ(astfold_arg, arg, node_->posonlyargs);
     856      CALL_SEQ(astfold_arg, arg, node_->args);
     857      CALL_OPT(astfold_arg, arg_ty, node_->vararg);
     858      CALL_SEQ(astfold_arg, arg, node_->kwonlyargs);
     859      CALL_SEQ(astfold_expr, expr, node_->kw_defaults);
     860      CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
     861      CALL_SEQ(astfold_expr, expr, node_->defaults);
     862      return 1;
     863  }
     864  
     865  static int
     866  astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
     867  {
     868      if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
     869          CALL_OPT(astfold_expr, expr_ty, node_->annotation);
     870      }
     871      return 1;
     872  }
     873  
/* Fold constants inside one statement node, in place.

   Visits every sub-expression, nested statement, parameter list, type
   parameter, exception handler, with-item and match case reachable from
   node_.  Returns 1 on success and 0 on failure.  A RecursionError is
   raised here when the nesting limit is exceeded; other failures come from
   the callees.

   state->recursion_depth is incremented on entry and decremented only on
   the successful exit path.  _PyAST_Optimize checks that the counter
   balances only when the whole pass succeeded, so an early error return
   that skips the decrement is tolerated. */
static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    /* The limit is derived from the C recursion budget in _PyAST_Optimize,
       so this keeps deeply nested input from exhausting the C stack. */
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    case FunctionDef_kind:
        CALL_SEQ(astfold_type_param, type_param, node_->v.FunctionDef.type_params);
        CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
        /* Function and class bodies go through astfold_body rather than a
           plain astfold_stmt sequence.  NOTE(review): presumably so the
           body can receive special handling (e.g. docstrings); confirm in
           astfold_body. */
        CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
        /* Under "from __future__ import annotations" the return annotation
           is left untouched. */
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
        }
        break;
    case AsyncFunctionDef_kind:
        /* Same traversal as FunctionDef_kind. */
        CALL_SEQ(astfold_type_param, type_param, node_->v.AsyncFunctionDef.type_params);
        CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
        CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
        }
        break;
    case ClassDef_kind:
        CALL_SEQ(astfold_type_param, type_param, node_->v.ClassDef.type_params);
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases);
        CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords);
        CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.decorator_list);
        break;
    case Return_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
        break;
    case Delete_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Delete.targets);
        break;
    case Assign_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Assign.targets);
        CALL(astfold_expr, expr_ty, node_->v.Assign.value);
        break;
    case AugAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
        break;
    case AnnAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
        /* As for function returns, the annotation is only folded when
           postponed evaluation of annotations is not enabled. */
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
        }
        CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
        break;
    case TypeAlias_kind:
        CALL(astfold_expr, expr_ty, node_->v.TypeAlias.name);
        CALL_SEQ(astfold_type_param, type_param, node_->v.TypeAlias.type_params);
        CALL(astfold_expr, expr_ty, node_->v.TypeAlias.value);
        break;
    case For_kind:
        CALL(astfold_expr, expr_ty, node_->v.For.target);
        CALL(astfold_expr, expr_ty, node_->v.For.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.orelse);

        /* fold_iter runs last, on the already-folded iterable. */
        CALL(fold_iter, expr_ty, node_->v.For.iter);
        break;
    case AsyncFor_kind:
        /* NOTE(review): unlike For_kind, fold_iter is not applied to the
           async iterable; presumably intentional, since the iterable must
           support async iteration.  Confirm before changing. */
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.orelse);
        break;
    case While_kind:
        CALL(astfold_expr, expr_ty, node_->v.While.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.orelse);
        break;
    case If_kind:
        CALL(astfold_expr, expr_ty, node_->v.If.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
        break;
    case With_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.With.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.With.body);
        break;
    case AsyncWith_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.AsyncWith.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncWith.body);
        break;
    case Raise_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
        break;
    case Try_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.Try.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.finalbody);
        break;
    case TryStar_kind:
        /* Same traversal as Try_kind. */
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.TryStar.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.finalbody);
        break;
    case Assert_kind:
        CALL(astfold_expr, expr_ty, node_->v.Assert.test);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
        break;
    case Expr_kind:
        CALL(astfold_expr, expr_ty, node_->v.Expr.value);
        break;
    case Match_kind:
        CALL(astfold_expr, expr_ty, node_->v.Match.subject);
        CALL_SEQ(astfold_match_case, match_case, node_->v.Match.cases);
        break;
    // The following statements don't contain any subexpressions to be folded
    case Import_kind:
    case ImportFrom_kind:
    case Global_kind:
    case Nonlocal_kind:
    case Pass_kind:
    case Break_kind:
    case Continue_kind:
        break;
    // No default case, so the compiler will emit a warning if new statement
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
    1008  
    1009  static int
    1010  astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
    1011  {
    1012      switch (node_->kind) {
    1013      case ExceptHandler_kind:
    1014          CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
    1015          CALL_SEQ(astfold_stmt, stmt, node_->v.ExceptHandler.body);
    1016          break;
    1017      // No default case, so the compiler will emit a warning if new handler
    1018      // kinds are added without being handled here
    1019      }
    1020      return 1;
    1021  }
    1022  
    1023  static int
    1024  astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
    1025  {
    1026      CALL(astfold_expr, expr_ty, node_->context_expr);
    1027      CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
    1028      return 1;
    1029  }
    1030  
    1031  static int
    1032  astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
    1033  {
    1034      // Currently, this is really only used to form complex/negative numeric
    1035      // constants in MatchValue and MatchMapping nodes
    1036      // We still recurse into all subexpressions and subpatterns anyway
    1037      if (++state->recursion_depth > state->recursion_limit) {
    1038          PyErr_SetString(PyExc_RecursionError,
    1039                          "maximum recursion depth exceeded during compilation");
    1040          return 0;
    1041      }
    1042      switch (node_->kind) {
    1043          case MatchValue_kind:
    1044              CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
    1045              break;
    1046          case MatchSingleton_kind:
    1047              break;
    1048          case MatchSequence_kind:
    1049              CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
    1050              break;
    1051          case MatchMapping_kind:
    1052              CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
    1053              CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
    1054              break;
    1055          case MatchClass_kind:
    1056              CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls);
    1057              CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns);
    1058              CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns);
    1059              break;
    1060          case MatchStar_kind:
    1061              break;
    1062          case MatchAs_kind:
    1063              if (node_->v.MatchAs.pattern) {
    1064                  CALL(astfold_pattern, pattern_ty, node_->v.MatchAs.pattern);
    1065              }
    1066              break;
    1067          case MatchOr_kind:
    1068              CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns);
    1069              break;
    1070      // No default case, so the compiler will emit a warning if new pattern
    1071      // kinds are added without being handled here
    1072      }
    1073      state->recursion_depth--;
    1074      return 1;
    1075  }
    1076  
    1077  static int
    1078  astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
    1079  {
    1080      CALL(astfold_pattern, expr_ty, node_->pattern);
    1081      CALL_OPT(astfold_expr, expr_ty, node_->guard);
    1082      CALL_SEQ(astfold_stmt, stmt, node_->body);
    1083      return 1;
    1084  }
    1085  
    1086  static int
    1087  astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
    1088  {
    1089      switch (node_->kind) {
    1090          case TypeVar_kind:
    1091              CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.bound);
    1092              break;
    1093          case ParamSpec_kind:
    1094              break;
    1095          case TypeVarTuple_kind:
    1096              break;
    1097      }
    1098      return 1;
    1099  }
    1100  
/* The CALL* traversal helpers are only meant for the astfold_* functions
   above; undefine them so they are not usable past this point. */
#undef CALL
#undef CALL_OPT
#undef CALL_SEQ

/* See comments in symtable.c. */
/* _PyAST_Optimize multiplies both the starting depth and the depth limit
   by this factor, while each nested statement/expression/pattern adds only
   1 to the counter.  The optimizer may therefore nest roughly this many
   times deeper than the remaining C recursion budget.  NOTE(review): the
   rationale for the value 2 lives in symtable.c; confirm there. */
#define COMPILER_STACK_FRAME_SCALE 2
    1107  
    1108  int
    1109  _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
    1110  {
    1111      PyThreadState *tstate;
    1112      int starting_recursion_depth;
    1113  
    1114      /* Setup recursion depth check counters */
    1115      tstate = _PyThreadState_GET();
    1116      if (!tstate) {
    1117          return 0;
    1118      }
    1119      /* Be careful here to prevent overflow. */
    1120      int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
    1121      starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
    1122      state->recursion_depth = starting_recursion_depth;
    1123      state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
    1124  
    1125      int ret = astfold_mod(mod, arena, state);
    1126      assert(ret || PyErr_Occurred());
    1127  
    1128      /* Check that the recursion depth counting balanced correctly */
    1129      if (ret && state->recursion_depth != starting_recursion_depth) {
    1130          PyErr_Format(PyExc_SystemError,
    1131              "AST optimizer recursion depth mismatch (before=%d, after=%d)",
    1132              starting_recursion_depth, state->recursion_depth);
    1133          return 0;
    1134      }
    1135  
    1136      return ret;
    1137  }