1 /* AST Optimizer */
2 #include "Python.h"
3 #include "Python-ast.h"
4 #include "ast.h"
5 
6 
7 static int
make_const(expr_ty node,PyObject * val,PyArena * arena)8 make_const(expr_ty node, PyObject *val, PyArena *arena)
9 {
10     if (val == NULL) {
11         if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
12             return 0;
13         }
14         PyErr_Clear();
15         return 1;
16     }
17     if (PyArena_AddPyObject(arena, val) < 0) {
18         Py_DECREF(val);
19         return 0;
20     }
21     node->kind = Constant_kind;
22     node->v.Constant.kind = NULL;
23     node->v.Constant.value = val;
24     return 1;
25 }
26 
/* Overwrite the expression node TO with a byte-for-byte copy of FROM
   (both are expected to point at a struct _expr). */
#define COPY_NODE(TO, FROM) \
    (memcpy((TO), (FROM), sizeof(struct _expr)))
28 
29 static PyObject*
unary_not(PyObject * v)30 unary_not(PyObject *v)
31 {
32     int r = PyObject_IsTrue(v);
33     if (r < 0)
34         return NULL;
35     return PyBool_FromLong(!r);
36 }
37 
38 static int
fold_unaryop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)39 fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
40 {
41     expr_ty arg = node->v.UnaryOp.operand;
42 
43     if (arg->kind != Constant_kind) {
44         /* Fold not into comparison */
45         if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
46                 asdl_seq_LEN(arg->v.Compare.ops) == 1) {
47             /* Eq and NotEq are often implemented in terms of one another, so
48                folding not (self == other) into self != other breaks implementation
49                of !=. Detecting such cases doesn't seem worthwhile.
50                Python uses </> for 'is subset'/'is superset' operations on sets.
51                They don't satisfy not folding laws. */
52             int op = asdl_seq_GET(arg->v.Compare.ops, 0);
53             switch (op) {
54             case Is:
55                 op = IsNot;
56                 break;
57             case IsNot:
58                 op = Is;
59                 break;
60             case In:
61                 op = NotIn;
62                 break;
63             case NotIn:
64                 op = In;
65                 break;
66             default:
67                 op = 0;
68             }
69             if (op) {
70                 asdl_seq_SET(arg->v.Compare.ops, 0, op);
71                 COPY_NODE(node, arg);
72                 return 1;
73             }
74         }
75         return 1;
76     }
77 
78     typedef PyObject *(*unary_op)(PyObject*);
79     static const unary_op ops[] = {
80         [Invert] = PyNumber_Invert,
81         [Not] = unary_not,
82         [UAdd] = PyNumber_Positive,
83         [USub] = PyNumber_Negative,
84     };
85     PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
86     return make_const(node, newval, arena);
87 }
88 
89 /* Check whether a collection doesn't containing too much items (including
90    subcollections).  This protects from creating a constant that needs
91    too much time for calculating a hash.
92    "limit" is the maximal number of items.
93    Returns the negative number if the total number of items exceeds the
94    limit.  Otherwise returns the limit minus the total number of items.
95 */
96 
97 static Py_ssize_t
check_complexity(PyObject * obj,Py_ssize_t limit)98 check_complexity(PyObject *obj, Py_ssize_t limit)
99 {
100     if (PyTuple_Check(obj)) {
101         Py_ssize_t i;
102         limit -= PyTuple_GET_SIZE(obj);
103         for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
104             limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
105         }
106         return limit;
107     }
108     else if (PyFrozenSet_Check(obj)) {
109         Py_ssize_t i = 0;
110         PyObject *item;
111         Py_hash_t hash;
112         limit -= PySet_GET_SIZE(obj);
113         while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
114             limit = check_complexity(item, limit);
115         }
116     }
117     return limit;
118 }
119 
/* Upper bounds on values produced by the safe_* folding helpers; an
   operation whose result would exceed them is left unfolded. */
enum {
    MAX_INT_SIZE        = 128,   /* bits */
    MAX_COLLECTION_SIZE = 256,   /* items */
    MAX_STR_SIZE        = 4096,  /* characters */
    MAX_TOTAL_ITEMS     = 1024,  /* including nested collections */
};
124 
125 static PyObject *
safe_multiply(PyObject * v,PyObject * w)126 safe_multiply(PyObject *v, PyObject *w)
127 {
128     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
129         size_t vbits = _PyLong_NumBits(v);
130         size_t wbits = _PyLong_NumBits(w);
131         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
132             return NULL;
133         }
134         if (vbits + wbits > MAX_INT_SIZE) {
135             return NULL;
136         }
137     }
138     else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
139         Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
140                                              PySet_GET_SIZE(w);
141         if (size) {
142             long n = PyLong_AsLong(v);
143             if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
144                 return NULL;
145             }
146             if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
147                 return NULL;
148             }
149         }
150     }
151     else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
152         Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
153                                                PyBytes_GET_SIZE(w);
154         if (size) {
155             long n = PyLong_AsLong(v);
156             if (n < 0 || n > MAX_STR_SIZE / size) {
157                 return NULL;
158             }
159         }
160     }
161     else if (PyLong_Check(w) &&
162              (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
163               PyUnicode_Check(v) || PyBytes_Check(v)))
164     {
165         return safe_multiply(w, v);
166     }
167 
168     return PyNumber_Multiply(v, w);
169 }
170 
171 static PyObject *
safe_power(PyObject * v,PyObject * w)172 safe_power(PyObject *v, PyObject *w)
173 {
174     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w) > 0) {
175         size_t vbits = _PyLong_NumBits(v);
176         size_t wbits = PyLong_AsSize_t(w);
177         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
178             return NULL;
179         }
180         if (vbits > MAX_INT_SIZE / wbits) {
181             return NULL;
182         }
183     }
184 
185     return PyNumber_Power(v, w, Py_None);
186 }
187 
188 static PyObject *
safe_lshift(PyObject * v,PyObject * w)189 safe_lshift(PyObject *v, PyObject *w)
190 {
191     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
192         size_t vbits = _PyLong_NumBits(v);
193         size_t wbits = PyLong_AsSize_t(w);
194         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
195             return NULL;
196         }
197         if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
198             return NULL;
199         }
200     }
201 
202     return PyNumber_Lshift(v, w);
203 }
204 
205 static PyObject *
safe_mod(PyObject * v,PyObject * w)206 safe_mod(PyObject *v, PyObject *w)
207 {
208     if (PyUnicode_Check(v) || PyBytes_Check(v)) {
209         return NULL;
210     }
211 
212     return PyNumber_Remainder(v, w);
213 }
214 
215 static int
fold_binop(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)216 fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
217 {
218     expr_ty lhs, rhs;
219     lhs = node->v.BinOp.left;
220     rhs = node->v.BinOp.right;
221     if (lhs->kind != Constant_kind || rhs->kind != Constant_kind) {
222         return 1;
223     }
224 
225     PyObject *lv = lhs->v.Constant.value;
226     PyObject *rv = rhs->v.Constant.value;
227     PyObject *newval;
228 
229     switch (node->v.BinOp.op) {
230     case Add:
231         newval = PyNumber_Add(lv, rv);
232         break;
233     case Sub:
234         newval = PyNumber_Subtract(lv, rv);
235         break;
236     case Mult:
237         newval = safe_multiply(lv, rv);
238         break;
239     case Div:
240         newval = PyNumber_TrueDivide(lv, rv);
241         break;
242     case FloorDiv:
243         newval = PyNumber_FloorDivide(lv, rv);
244         break;
245     case Mod:
246         newval = safe_mod(lv, rv);
247         break;
248     case Pow:
249         newval = safe_power(lv, rv);
250         break;
251     case LShift:
252         newval = safe_lshift(lv, rv);
253         break;
254     case RShift:
255         newval = PyNumber_Rshift(lv, rv);
256         break;
257     case BitOr:
258         newval = PyNumber_Or(lv, rv);
259         break;
260     case BitXor:
261         newval = PyNumber_Xor(lv, rv);
262         break;
263     case BitAnd:
264         newval = PyNumber_And(lv, rv);
265         break;
266     default: // Unknown operator
267         return 1;
268     }
269 
270     return make_const(node, newval, arena);
271 }
272 
273 static PyObject*
make_const_tuple(asdl_seq * elts)274 make_const_tuple(asdl_seq *elts)
275 {
276     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
277         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
278         if (e->kind != Constant_kind) {
279             return NULL;
280         }
281     }
282 
283     PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
284     if (newval == NULL) {
285         return NULL;
286     }
287 
288     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
289         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
290         PyObject *v = e->v.Constant.value;
291         Py_INCREF(v);
292         PyTuple_SET_ITEM(newval, i, v);
293     }
294     return newval;
295 }
296 
297 static int
fold_tuple(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)298 fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
299 {
300     PyObject *newval;
301 
302     if (node->v.Tuple.ctx != Load)
303         return 1;
304 
305     newval = make_const_tuple(node->v.Tuple.elts);
306     return make_const(node, newval, arena);
307 }
308 
309 static int
fold_subscr(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)310 fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
311 {
312     PyObject *newval;
313     expr_ty arg, idx;
314 
315     arg = node->v.Subscript.value;
316     idx = node->v.Subscript.slice;
317     if (node->v.Subscript.ctx != Load ||
318             arg->kind != Constant_kind ||
319             idx->kind != Constant_kind)
320     {
321         return 1;
322     }
323 
324     newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
325     return make_const(node, newval, arena);
326 }
327 
328 /* Change literal list or set of constants into constant
329    tuple or frozenset respectively.  Change literal list of
330    non-constants into tuple.
331    Used for right operand of "in" and "not in" tests and for iterable
332    in "for" loop and comprehensions.
333 */
334 static int
fold_iter(expr_ty arg,PyArena * arena,_PyASTOptimizeState * state)335 fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
336 {
337     PyObject *newval;
338     if (arg->kind == List_kind) {
339         /* First change a list into tuple. */
340         asdl_seq *elts = arg->v.List.elts;
341         Py_ssize_t n = asdl_seq_LEN(elts);
342         for (Py_ssize_t i = 0; i < n; i++) {
343             expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
344             if (e->kind == Starred_kind) {
345                 return 1;
346             }
347         }
348         expr_context_ty ctx = arg->v.List.ctx;
349         arg->kind = Tuple_kind;
350         arg->v.Tuple.elts = elts;
351         arg->v.Tuple.ctx = ctx;
352         /* Try to create a constant tuple. */
353         newval = make_const_tuple(elts);
354     }
355     else if (arg->kind == Set_kind) {
356         newval = make_const_tuple(arg->v.Set.elts);
357         if (newval) {
358             Py_SETREF(newval, PyFrozenSet_New(newval));
359         }
360     }
361     else {
362         return 1;
363     }
364     return make_const(arg, newval, arena);
365 }
366 
367 static int
fold_compare(expr_ty node,PyArena * arena,_PyASTOptimizeState * state)368 fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
369 {
370     asdl_int_seq *ops;
371     asdl_seq *args;
372     Py_ssize_t i;
373 
374     ops = node->v.Compare.ops;
375     args = node->v.Compare.comparators;
376     /* TODO: optimize cases with literal arguments. */
377     /* Change literal list or set in 'in' or 'not in' into
378        tuple or frozenset respectively. */
379     i = asdl_seq_LEN(ops) - 1;
380     int op = asdl_seq_GET(ops, i);
381     if (op == In || op == NotIn) {
382         if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
383             return 0;
384         }
385     }
386     return 1;
387 }
388 
389 static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
390 static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
391 static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
392 static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
393 static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
394 static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
395 static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
396 static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
397 static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
/* Traversal helpers for the astfold_* visitors.
 *
 * Each macro calls FUNC(node, ctx_, state) and makes the *enclosing*
 * function return 0 as soon as FUNC reports failure (returns 0).  They
 * therefore rely on the enclosing function having locals/parameters named
 * `ctx_` and `state`.  TYPE is the element type; CALL and CALL_OPT ignore
 * it and keep it only for a uniform signature.
 *
 * Every body is wrapped in do { ... } while (0), so each use is exactly
 * one statement and is safe inside an unbraced if/else.  The sequence
 * macros evaluate ARG exactly once, before declaring their own loop
 * variables, so ARG cannot accidentally refer to them.
 */
#define CALL(FUNC, TYPE, ARG) \
    do { \
        if (!FUNC((ARG), ctx_, state)) \
            return 0; \
    } while (0)

/* Like CALL, but skips a NULL ARG (ARG is evaluated twice). */
#define CALL_OPT(FUNC, TYPE, ARG) \
    do { \
        if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) \
            return 0; \
    } while (0)

/* Apply FUNC to every non-NULL element of the asdl_seq ARG. */
#define CALL_SEQ(FUNC, TYPE, ARG) \
    do { \
        asdl_seq *seq_ = (ARG); /* evaluated once, before the loop locals */ \
        for (Py_ssize_t i_ = 0; i_ < asdl_seq_LEN(seq_); i_++) { \
            TYPE elt_ = (TYPE)asdl_seq_GET(seq_, i_); \
            if (elt_ != NULL && !FUNC(elt_, ctx_, state)) \
                return 0; \
        } \
    } while (0)

/* Apply FUNC to every element of the asdl_int_seq ARG. */
#define CALL_INT_SEQ(FUNC, TYPE, ARG) \
    do { \
        asdl_int_seq *seq_ = (ARG); /* evaluated once, before the loop locals */ \
        for (Py_ssize_t i_ = 0; i_ < asdl_seq_LEN(seq_); i_++) { \
            TYPE elt_ = (TYPE)asdl_seq_GET(seq_, i_); \
            if (!FUNC(elt_, ctx_, state)) \
                return 0; \
        } \
    } while (0)
425 
426 static int
astfold_body(asdl_seq * stmts,PyArena * ctx_,_PyASTOptimizeState * state)427 astfold_body(asdl_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
428 {
429     int docstring = _PyAST_GetDocString(stmts) != NULL;
430     CALL_SEQ(astfold_stmt, stmt_ty, stmts);
431     if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
432         stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
433         asdl_seq *values = _Py_asdl_seq_new(1, ctx_);
434         if (!values) {
435             return 0;
436         }
437         asdl_seq_SET(values, 0, st->v.Expr.value);
438         expr_ty expr = JoinedStr(values, st->lineno, st->col_offset,
439                                  st->end_lineno, st->end_col_offset, ctx_);
440         if (!expr) {
441             return 0;
442         }
443         st->v.Expr.value = expr;
444     }
445     return 1;
446 }
447 
448 static int
astfold_mod(mod_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)449 astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
450 {
451     switch (node_->kind) {
452     case Module_kind:
453         CALL(astfold_body, asdl_seq, node_->v.Module.body);
454         break;
455     case Interactive_kind:
456         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Interactive.body);
457         break;
458     case Expression_kind:
459         CALL(astfold_expr, expr_ty, node_->v.Expression.body);
460         break;
461     default:
462         break;
463     }
464     return 1;
465 }
466 
467 static int
astfold_expr(expr_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)468 astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
469 {
470     switch (node_->kind) {
471     case BoolOp_kind:
472         CALL_SEQ(astfold_expr, expr_ty, node_->v.BoolOp.values);
473         break;
474     case BinOp_kind:
475         CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
476         CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
477         CALL(fold_binop, expr_ty, node_);
478         break;
479     case UnaryOp_kind:
480         CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
481         CALL(fold_unaryop, expr_ty, node_);
482         break;
483     case Lambda_kind:
484         CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
485         CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
486         break;
487     case IfExp_kind:
488         CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
489         CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
490         CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
491         break;
492     case Dict_kind:
493         CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.keys);
494         CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.values);
495         break;
496     case Set_kind:
497         CALL_SEQ(astfold_expr, expr_ty, node_->v.Set.elts);
498         break;
499     case ListComp_kind:
500         CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
501         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.ListComp.generators);
502         break;
503     case SetComp_kind:
504         CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
505         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.SetComp.generators);
506         break;
507     case DictComp_kind:
508         CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
509         CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
510         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.DictComp.generators);
511         break;
512     case GeneratorExp_kind:
513         CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
514         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.GeneratorExp.generators);
515         break;
516     case Await_kind:
517         CALL(astfold_expr, expr_ty, node_->v.Await.value);
518         break;
519     case Yield_kind:
520         CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
521         break;
522     case YieldFrom_kind:
523         CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
524         break;
525     case Compare_kind:
526         CALL(astfold_expr, expr_ty, node_->v.Compare.left);
527         CALL_SEQ(astfold_expr, expr_ty, node_->v.Compare.comparators);
528         CALL(fold_compare, expr_ty, node_);
529         break;
530     case Call_kind:
531         CALL(astfold_expr, expr_ty, node_->v.Call.func);
532         CALL_SEQ(astfold_expr, expr_ty, node_->v.Call.args);
533         CALL_SEQ(astfold_keyword, keyword_ty, node_->v.Call.keywords);
534         break;
535     case FormattedValue_kind:
536         CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
537         CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
538         break;
539     case JoinedStr_kind:
540         CALL_SEQ(astfold_expr, expr_ty, node_->v.JoinedStr.values);
541         break;
542     case Attribute_kind:
543         CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
544         break;
545     case Subscript_kind:
546         CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
547         CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
548         CALL(fold_subscr, expr_ty, node_);
549         break;
550     case Starred_kind:
551         CALL(astfold_expr, expr_ty, node_->v.Starred.value);
552         break;
553     case Slice_kind:
554         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
555         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
556         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
557         break;
558     case List_kind:
559         CALL_SEQ(astfold_expr, expr_ty, node_->v.List.elts);
560         break;
561     case Tuple_kind:
562         CALL_SEQ(astfold_expr, expr_ty, node_->v.Tuple.elts);
563         CALL(fold_tuple, expr_ty, node_);
564         break;
565     case Name_kind:
566         if (node_->v.Name.ctx == Load &&
567                 _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
568             return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
569         }
570         break;
571     default:
572         break;
573     }
574     return 1;
575 }
576 
577 static int
astfold_keyword(keyword_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)578 astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
579 {
580     CALL(astfold_expr, expr_ty, node_->value);
581     return 1;
582 }
583 
584 static int
astfold_comprehension(comprehension_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)585 astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
586 {
587     CALL(astfold_expr, expr_ty, node_->target);
588     CALL(astfold_expr, expr_ty, node_->iter);
589     CALL_SEQ(astfold_expr, expr_ty, node_->ifs);
590 
591     CALL(fold_iter, expr_ty, node_->iter);
592     return 1;
593 }
594 
595 static int
astfold_arguments(arguments_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)596 astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
597 {
598     CALL_SEQ(astfold_arg, arg_ty, node_->posonlyargs);
599     CALL_SEQ(astfold_arg, arg_ty, node_->args);
600     CALL_OPT(astfold_arg, arg_ty, node_->vararg);
601     CALL_SEQ(astfold_arg, arg_ty, node_->kwonlyargs);
602     CALL_SEQ(astfold_expr, expr_ty, node_->kw_defaults);
603     CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
604     CALL_SEQ(astfold_expr, expr_ty, node_->defaults);
605     return 1;
606 }
607 
608 static int
astfold_arg(arg_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)609 astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
610 {
611     if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
612         CALL_OPT(astfold_expr, expr_ty, node_->annotation);
613     }
614     return 1;
615 }
616 
617 static int
astfold_stmt(stmt_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)618 astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
619 {
620     switch (node_->kind) {
621     case FunctionDef_kind:
622         CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
623         CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
624         CALL_SEQ(astfold_expr, expr_ty, node_->v.FunctionDef.decorator_list);
625         if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
626             CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
627         }
628         break;
629     case AsyncFunctionDef_kind:
630         CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
631         CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
632         CALL_SEQ(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.decorator_list);
633         if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
634             CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
635         }
636         break;
637     case ClassDef_kind:
638         CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.bases);
639         CALL_SEQ(astfold_keyword, keyword_ty, node_->v.ClassDef.keywords);
640         CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
641         CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.decorator_list);
642         break;
643     case Return_kind:
644         CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
645         break;
646     case Delete_kind:
647         CALL_SEQ(astfold_expr, expr_ty, node_->v.Delete.targets);
648         break;
649     case Assign_kind:
650         CALL_SEQ(astfold_expr, expr_ty, node_->v.Assign.targets);
651         CALL(astfold_expr, expr_ty, node_->v.Assign.value);
652         break;
653     case AugAssign_kind:
654         CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
655         CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
656         break;
657     case AnnAssign_kind:
658         CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
659         if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
660             CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
661         }
662         CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
663         break;
664     case For_kind:
665         CALL(astfold_expr, expr_ty, node_->v.For.target);
666         CALL(astfold_expr, expr_ty, node_->v.For.iter);
667         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.body);
668         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.orelse);
669 
670         CALL(fold_iter, expr_ty, node_->v.For.iter);
671         break;
672     case AsyncFor_kind:
673         CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
674         CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
675         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.body);
676         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.orelse);
677         break;
678     case While_kind:
679         CALL(astfold_expr, expr_ty, node_->v.While.test);
680         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.body);
681         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.orelse);
682         break;
683     case If_kind:
684         CALL(astfold_expr, expr_ty, node_->v.If.test);
685         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.body);
686         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.orelse);
687         break;
688     case With_kind:
689         CALL_SEQ(astfold_withitem, withitem_ty, node_->v.With.items);
690         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.With.body);
691         break;
692     case AsyncWith_kind:
693         CALL_SEQ(astfold_withitem, withitem_ty, node_->v.AsyncWith.items);
694         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncWith.body);
695         break;
696     case Raise_kind:
697         CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
698         CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
699         break;
700     case Try_kind:
701         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.body);
702         CALL_SEQ(astfold_excepthandler, excepthandler_ty, node_->v.Try.handlers);
703         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.orelse);
704         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.finalbody);
705         break;
706     case Assert_kind:
707         CALL(astfold_expr, expr_ty, node_->v.Assert.test);
708         CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
709         break;
710     case Expr_kind:
711         CALL(astfold_expr, expr_ty, node_->v.Expr.value);
712         break;
713     default:
714         break;
715     }
716     return 1;
717 }
718 
719 static int
astfold_excepthandler(excepthandler_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)720 astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
721 {
722     switch (node_->kind) {
723     case ExceptHandler_kind:
724         CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
725         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.ExceptHandler.body);
726         break;
727     default:
728         break;
729     }
730     return 1;
731 }
732 
733 static int
astfold_withitem(withitem_ty node_,PyArena * ctx_,_PyASTOptimizeState * state)734 astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
735 {
736     CALL(astfold_expr, expr_ty, node_->context_expr);
737     CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
738     return 1;
739 }
740 
/* The traversal macros are private to the astfold_* visitors. */
#undef CALL_INT_SEQ
#undef CALL_SEQ
#undef CALL_OPT
#undef CALL
745 
746 int
_PyAST_Optimize(mod_ty mod,PyArena * arena,_PyASTOptimizeState * state)747 _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
748 {
749     int ret = astfold_mod(mod, arena, state);
750     assert(ret || PyErr_Occurred());
751     return ret;
752 }
753