1 /* AST Optimizer */
2 #include "Python.h"
3 #include "Python-ast.h"
4 
5 
/* TODO: is_const and get_const_value are copied from Python/compile.c.
   They should be deduplicated in the future, perhaps by including this
   file from compile.c.
*/
10 static int
is_const(expr_ty e)11 is_const(expr_ty e)
12 {
13     switch (e->kind) {
14     case Constant_kind:
15     case Num_kind:
16     case Str_kind:
17     case Bytes_kind:
18     case Ellipsis_kind:
19     case NameConstant_kind:
20         return 1;
21     default:
22         return 0;
23     }
24 }
25 
26 static PyObject *
get_const_value(expr_ty e)27 get_const_value(expr_ty e)
28 {
29     switch (e->kind) {
30     case Constant_kind:
31         return e->v.Constant.value;
32     case Num_kind:
33         return e->v.Num.n;
34     case Str_kind:
35         return e->v.Str.s;
36     case Bytes_kind:
37         return e->v.Bytes.s;
38     case Ellipsis_kind:
39         return Py_Ellipsis;
40     case NameConstant_kind:
41         return e->v.NameConstant.value;
42     default:
43         Py_UNREACHABLE();
44     }
45 }
46 
47 static int
make_const(expr_ty node,PyObject * val,PyArena * arena)48 make_const(expr_ty node, PyObject *val, PyArena *arena)
49 {
50     if (val == NULL) {
51         if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
52             return 0;
53         }
54         PyErr_Clear();
55         return 1;
56     }
57     if (PyArena_AddPyObject(arena, val) < 0) {
58         Py_DECREF(val);
59         return 0;
60     }
61     node->kind = Constant_kind;
62     node->v.Constant.value = val;
63     return 1;
64 }
65 
/* Overwrite the expression node TO in place with a byte-for-byte copy of
   the expression node FROM (every field of struct _expr is copied).
   Both arguments must point to a struct _expr. */
#define COPY_NODE(TO, FROM) \
    (memcpy((TO), (FROM), sizeof(struct _expr)))
67 
68 static PyObject*
unary_not(PyObject * v)69 unary_not(PyObject *v)
70 {
71     int r = PyObject_IsTrue(v);
72     if (r < 0)
73         return NULL;
74     return PyBool_FromLong(!r);
75 }
76 
77 static int
fold_unaryop(expr_ty node,PyArena * arena,int optimize)78 fold_unaryop(expr_ty node, PyArena *arena, int optimize)
79 {
80     expr_ty arg = node->v.UnaryOp.operand;
81 
82     if (!is_const(arg)) {
83         /* Fold not into comparison */
84         if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
85                 asdl_seq_LEN(arg->v.Compare.ops) == 1) {
86             /* Eq and NotEq are often implemented in terms of one another, so
87                folding not (self == other) into self != other breaks implementation
88                of !=. Detecting such cases doesn't seem worthwhile.
89                Python uses </> for 'is subset'/'is superset' operations on sets.
90                They don't satisfy not folding laws. */
91             int op = asdl_seq_GET(arg->v.Compare.ops, 0);
92             switch (op) {
93             case Is:
94                 op = IsNot;
95                 break;
96             case IsNot:
97                 op = Is;
98                 break;
99             case In:
100                 op = NotIn;
101                 break;
102             case NotIn:
103                 op = In;
104                 break;
105             default:
106                 op = 0;
107             }
108             if (op) {
109                 asdl_seq_SET(arg->v.Compare.ops, 0, op);
110                 COPY_NODE(node, arg);
111                 return 1;
112             }
113         }
114         return 1;
115     }
116 
117     typedef PyObject *(*unary_op)(PyObject*);
118     static const unary_op ops[] = {
119         [Invert] = PyNumber_Invert,
120         [Not] = unary_not,
121         [UAdd] = PyNumber_Positive,
122         [USub] = PyNumber_Negative,
123     };
124     PyObject *newval = ops[node->v.UnaryOp.op](get_const_value(arg));
125     return make_const(node, newval, arena);
126 }
127 
128 /* Check whether a collection doesn't containing too much items (including
129    subcollections).  This protects from creating a constant that needs
130    too much time for calculating a hash.
131    "limit" is the maximal number of items.
132    Returns the negative number if the total number of items exceeds the
133    limit.  Otherwise returns the limit minus the total number of items.
134 */
135 
136 static Py_ssize_t
check_complexity(PyObject * obj,Py_ssize_t limit)137 check_complexity(PyObject *obj, Py_ssize_t limit)
138 {
139     if (PyTuple_Check(obj)) {
140         Py_ssize_t i;
141         limit -= PyTuple_GET_SIZE(obj);
142         for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
143             limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
144         }
145         return limit;
146     }
147     else if (PyFrozenSet_Check(obj)) {
148         Py_ssize_t i = 0;
149         PyObject *item;
150         Py_hash_t hash;
151         limit -= PySet_GET_SIZE(obj);
152         while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
153             limit = check_complexity(item, limit);
154         }
155     }
156     return limit;
157 }
158 
/* Upper bounds on the size of constants produced by folding. */
enum {
    MAX_INT_SIZE        = 128,   /* bits */
    MAX_COLLECTION_SIZE = 256,   /* items */
    MAX_STR_SIZE        = 4096,  /* characters */
    MAX_TOTAL_ITEMS     = 1024,  /* including nested collections */
};
163 
164 static PyObject *
safe_multiply(PyObject * v,PyObject * w)165 safe_multiply(PyObject *v, PyObject *w)
166 {
167     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
168         size_t vbits = _PyLong_NumBits(v);
169         size_t wbits = _PyLong_NumBits(w);
170         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
171             return NULL;
172         }
173         if (vbits + wbits > MAX_INT_SIZE) {
174             return NULL;
175         }
176     }
177     else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
178         Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
179                                              PySet_GET_SIZE(w);
180         if (size) {
181             long n = PyLong_AsLong(v);
182             if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
183                 return NULL;
184             }
185             if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
186                 return NULL;
187             }
188         }
189     }
190     else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
191         Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
192                                                PyBytes_GET_SIZE(w);
193         if (size) {
194             long n = PyLong_AsLong(v);
195             if (n < 0 || n > MAX_STR_SIZE / size) {
196                 return NULL;
197             }
198         }
199     }
200     else if (PyLong_Check(w) &&
201              (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
202               PyUnicode_Check(v) || PyBytes_Check(v)))
203     {
204         return safe_multiply(w, v);
205     }
206 
207     return PyNumber_Multiply(v, w);
208 }
209 
210 static PyObject *
safe_power(PyObject * v,PyObject * w)211 safe_power(PyObject *v, PyObject *w)
212 {
213     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w) > 0) {
214         size_t vbits = _PyLong_NumBits(v);
215         size_t wbits = PyLong_AsSize_t(w);
216         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
217             return NULL;
218         }
219         if (vbits > MAX_INT_SIZE / wbits) {
220             return NULL;
221         }
222     }
223 
224     return PyNumber_Power(v, w, Py_None);
225 }
226 
227 static PyObject *
safe_lshift(PyObject * v,PyObject * w)228 safe_lshift(PyObject *v, PyObject *w)
229 {
230     if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
231         size_t vbits = _PyLong_NumBits(v);
232         size_t wbits = PyLong_AsSize_t(w);
233         if (vbits == (size_t)-1 || wbits == (size_t)-1) {
234             return NULL;
235         }
236         if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
237             return NULL;
238         }
239     }
240 
241     return PyNumber_Lshift(v, w);
242 }
243 
244 static PyObject *
safe_mod(PyObject * v,PyObject * w)245 safe_mod(PyObject *v, PyObject *w)
246 {
247     if (PyUnicode_Check(v) || PyBytes_Check(v)) {
248         return NULL;
249     }
250 
251     return PyNumber_Remainder(v, w);
252 }
253 
254 static int
fold_binop(expr_ty node,PyArena * arena,int optimize)255 fold_binop(expr_ty node, PyArena *arena, int optimize)
256 {
257     expr_ty lhs, rhs;
258     lhs = node->v.BinOp.left;
259     rhs = node->v.BinOp.right;
260     if (!is_const(lhs) || !is_const(rhs)) {
261         return 1;
262     }
263 
264     PyObject *lv = get_const_value(lhs);
265     PyObject *rv = get_const_value(rhs);
266     PyObject *newval;
267 
268     switch (node->v.BinOp.op) {
269     case Add:
270         newval = PyNumber_Add(lv, rv);
271         break;
272     case Sub:
273         newval = PyNumber_Subtract(lv, rv);
274         break;
275     case Mult:
276         newval = safe_multiply(lv, rv);
277         break;
278     case Div:
279         newval = PyNumber_TrueDivide(lv, rv);
280         break;
281     case FloorDiv:
282         newval = PyNumber_FloorDivide(lv, rv);
283         break;
284     case Mod:
285         newval = safe_mod(lv, rv);
286         break;
287     case Pow:
288         newval = safe_power(lv, rv);
289         break;
290     case LShift:
291         newval = safe_lshift(lv, rv);
292         break;
293     case RShift:
294         newval = PyNumber_Rshift(lv, rv);
295         break;
296     case BitOr:
297         newval = PyNumber_Or(lv, rv);
298         break;
299     case BitXor:
300         newval = PyNumber_Xor(lv, rv);
301         break;
302     case BitAnd:
303         newval = PyNumber_And(lv, rv);
304         break;
305     default: // Unknown operator
306         return 1;
307     }
308 
309     return make_const(node, newval, arena);
310 }
311 
312 static PyObject*
make_const_tuple(asdl_seq * elts)313 make_const_tuple(asdl_seq *elts)
314 {
315     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
316         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
317         if (!is_const(e)) {
318             return NULL;
319         }
320     }
321 
322     PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
323     if (newval == NULL) {
324         return NULL;
325     }
326 
327     for (int i = 0; i < asdl_seq_LEN(elts); i++) {
328         expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
329         PyObject *v = get_const_value(e);
330         Py_INCREF(v);
331         PyTuple_SET_ITEM(newval, i, v);
332     }
333     return newval;
334 }
335 
336 static int
fold_tuple(expr_ty node,PyArena * arena,int optimize)337 fold_tuple(expr_ty node, PyArena *arena, int optimize)
338 {
339     PyObject *newval;
340 
341     if (node->v.Tuple.ctx != Load)
342         return 1;
343 
344     newval = make_const_tuple(node->v.Tuple.elts);
345     return make_const(node, newval, arena);
346 }
347 
348 static int
fold_subscr(expr_ty node,PyArena * arena,int optimize)349 fold_subscr(expr_ty node, PyArena *arena, int optimize)
350 {
351     PyObject *newval;
352     expr_ty arg, idx;
353     slice_ty slice;
354 
355     arg = node->v.Subscript.value;
356     slice = node->v.Subscript.slice;
357     if (node->v.Subscript.ctx != Load ||
358             !is_const(arg) ||
359             /* TODO: handle other types of slices */
360             slice->kind != Index_kind ||
361             !is_const(slice->v.Index.value))
362     {
363         return 1;
364     }
365 
366     idx = slice->v.Index.value;
367     newval = PyObject_GetItem(get_const_value(arg), get_const_value(idx));
368     return make_const(node, newval, arena);
369 }
370 
371 /* Change literal list or set of constants into constant
372    tuple or frozenset respectively.
373    Used for right operand of "in" and "not in" tests and for iterable
374    in "for" loop and comprehensions.
375 */
376 static int
fold_iter(expr_ty arg,PyArena * arena,int optimize)377 fold_iter(expr_ty arg, PyArena *arena, int optimize)
378 {
379     PyObject *newval;
380     if (arg->kind == List_kind) {
381         newval = make_const_tuple(arg->v.List.elts);
382     }
383     else if (arg->kind == Set_kind) {
384         newval = make_const_tuple(arg->v.Set.elts);
385         if (newval) {
386             Py_SETREF(newval, PyFrozenSet_New(newval));
387         }
388     }
389     else {
390         return 1;
391     }
392     return make_const(arg, newval, arena);
393 }
394 
395 static int
fold_compare(expr_ty node,PyArena * arena,int optimize)396 fold_compare(expr_ty node, PyArena *arena, int optimize)
397 {
398     asdl_int_seq *ops;
399     asdl_seq *args;
400     Py_ssize_t i;
401 
402     ops = node->v.Compare.ops;
403     args = node->v.Compare.comparators;
404     /* TODO: optimize cases with literal arguments. */
405     /* Change literal list or set in 'in' or 'not in' into
406        tuple or frozenset respectively. */
407     i = asdl_seq_LEN(ops) - 1;
408     int op = asdl_seq_GET(ops, i);
409     if (op == In || op == NotIn) {
410         if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, optimize)) {
411             return 0;
412         }
413     }
414     return 1;
415 }
416 
417 static int astfold_mod(mod_ty node_, PyArena *ctx_, int optimize_);
418 static int astfold_stmt(stmt_ty node_, PyArena *ctx_, int optimize_);
419 static int astfold_expr(expr_ty node_, PyArena *ctx_, int optimize_);
420 static int astfold_arguments(arguments_ty node_, PyArena *ctx_, int optimize_);
421 static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, int optimize_);
422 static int astfold_keyword(keyword_ty node_, PyArena *ctx_, int optimize_);
423 static int astfold_slice(slice_ty node_, PyArena *ctx_, int optimize_);
424 static int astfold_arg(arg_ty node_, PyArena *ctx_, int optimize_);
425 static int astfold_withitem(withitem_ty node_, PyArena *ctx_, int optimize_);
426 static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, int optimize_);
/* Helpers for the astfold_* visitors.  Each one expects the enclosing
   function to have parameters named ctx_ and optimize_.  It makes the
   enclosing function return 0 as soon as a visited child reports failure.
   TYPE is used only by the sequence variants, to cast each element.
   Every macro expands to exactly one statement (do { ... } while (0)), so
   it is safe in any statement position, including an unbraced if/else. */

/* Visit a required child. */
#define CALL(FUNC, TYPE, ARG) \
    do { \
        if (!FUNC((ARG), ctx_, optimize_)) { \
            return 0; \
        } \
    } while (0)

/* Visit an optional child; a NULL child is skipped. */
#define CALL_OPT(FUNC, TYPE, ARG) \
    do { \
        if ((ARG) != NULL && !FUNC((ARG), ctx_, optimize_)) { \
            return 0; \
        } \
    } while (0)

/* Visit every non-NULL element of an asdl_seq. */
#define CALL_SEQ(FUNC, TYPE, ARG) \
    do { \
        asdl_seq *seq_ = (ARG); /* avoid variable capture */ \
        for (Py_ssize_t i_ = 0; i_ < asdl_seq_LEN(seq_); i_++) { \
            TYPE elt_ = (TYPE)asdl_seq_GET(seq_, i_); \
            if (elt_ != NULL && !FUNC(elt_, ctx_, optimize_)) { \
                return 0; \
            } \
        } \
    } while (0)

/* Visit every element of an asdl_int_seq. */
#define CALL_INT_SEQ(FUNC, TYPE, ARG) \
    do { \
        asdl_int_seq *seq_ = (ARG); /* avoid variable capture */ \
        for (Py_ssize_t i_ = 0; i_ < asdl_seq_LEN(seq_); i_++) { \
            TYPE elt_ = (TYPE)asdl_seq_GET(seq_, i_); \
            if (!FUNC(elt_, ctx_, optimize_)) { \
                return 0; \
            } \
        } \
    } while (0)
454 
455 static int
isdocstring(stmt_ty s)456 isdocstring(stmt_ty s)
457 {
458     if (s->kind != Expr_kind)
459         return 0;
460     if (s->v.Expr.value->kind == Str_kind)
461         return 1;
462     if (s->v.Expr.value->kind == Constant_kind)
463         return PyUnicode_CheckExact(s->v.Expr.value->v.Constant.value);
464     return 0;
465 }
466 
467 static int
astfold_body(asdl_seq * stmts,PyArena * ctx_,int optimize_)468 astfold_body(asdl_seq *stmts, PyArena *ctx_, int optimize_)
469 {
470     if (!asdl_seq_LEN(stmts)) {
471         return 1;
472     }
473     int docstring = isdocstring((stmt_ty)asdl_seq_GET(stmts, 0));
474     CALL_SEQ(astfold_stmt, stmt_ty, stmts);
475     if (docstring) {
476         return 1;
477     }
478     stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
479     if (isdocstring(st)) {
480         asdl_seq *values = _Py_asdl_seq_new(1, ctx_);
481         if (!values) {
482             return 0;
483         }
484         asdl_seq_SET(values, 0, st->v.Expr.value);
485         expr_ty expr = _Py_JoinedStr(values, st->lineno, st->col_offset, ctx_);
486         if (!expr) {
487             return 0;
488         }
489         st->v.Expr.value = expr;
490     }
491     return 1;
492 }
493 
494 static int
astfold_mod(mod_ty node_,PyArena * ctx_,int optimize_)495 astfold_mod(mod_ty node_, PyArena *ctx_, int optimize_)
496 {
497     switch (node_->kind) {
498     case Module_kind:
499         CALL(astfold_body, asdl_seq, node_->v.Module.body);
500         break;
501     case Interactive_kind:
502         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Interactive.body);
503         break;
504     case Expression_kind:
505         CALL(astfold_expr, expr_ty, node_->v.Expression.body);
506         break;
507     case Suite_kind:
508         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Suite.body);
509         break;
510     default:
511         break;
512     }
513     return 1;
514 }
515 
516 static int
astfold_expr(expr_ty node_,PyArena * ctx_,int optimize_)517 astfold_expr(expr_ty node_, PyArena *ctx_, int optimize_)
518 {
519     switch (node_->kind) {
520     case BoolOp_kind:
521         CALL_SEQ(astfold_expr, expr_ty, node_->v.BoolOp.values);
522         break;
523     case BinOp_kind:
524         CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
525         CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
526         CALL(fold_binop, expr_ty, node_);
527         break;
528     case UnaryOp_kind:
529         CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
530         CALL(fold_unaryop, expr_ty, node_);
531         break;
532     case Lambda_kind:
533         CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
534         CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
535         break;
536     case IfExp_kind:
537         CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
538         CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
539         CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
540         break;
541     case Dict_kind:
542         CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.keys);
543         CALL_SEQ(astfold_expr, expr_ty, node_->v.Dict.values);
544         break;
545     case Set_kind:
546         CALL_SEQ(astfold_expr, expr_ty, node_->v.Set.elts);
547         break;
548     case ListComp_kind:
549         CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
550         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.ListComp.generators);
551         break;
552     case SetComp_kind:
553         CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
554         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.SetComp.generators);
555         break;
556     case DictComp_kind:
557         CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
558         CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
559         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.DictComp.generators);
560         break;
561     case GeneratorExp_kind:
562         CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
563         CALL_SEQ(astfold_comprehension, comprehension_ty, node_->v.GeneratorExp.generators);
564         break;
565     case Await_kind:
566         CALL(astfold_expr, expr_ty, node_->v.Await.value);
567         break;
568     case Yield_kind:
569         CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
570         break;
571     case YieldFrom_kind:
572         CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
573         break;
574     case Compare_kind:
575         CALL(astfold_expr, expr_ty, node_->v.Compare.left);
576         CALL_SEQ(astfold_expr, expr_ty, node_->v.Compare.comparators);
577         CALL(fold_compare, expr_ty, node_);
578         break;
579     case Call_kind:
580         CALL(astfold_expr, expr_ty, node_->v.Call.func);
581         CALL_SEQ(astfold_expr, expr_ty, node_->v.Call.args);
582         CALL_SEQ(astfold_keyword, keyword_ty, node_->v.Call.keywords);
583         break;
584     case FormattedValue_kind:
585         CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
586         CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
587         break;
588     case JoinedStr_kind:
589         CALL_SEQ(astfold_expr, expr_ty, node_->v.JoinedStr.values);
590         break;
591     case Attribute_kind:
592         CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
593         break;
594     case Subscript_kind:
595         CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
596         CALL(astfold_slice, slice_ty, node_->v.Subscript.slice);
597         CALL(fold_subscr, expr_ty, node_);
598         break;
599     case Starred_kind:
600         CALL(astfold_expr, expr_ty, node_->v.Starred.value);
601         break;
602     case List_kind:
603         CALL_SEQ(astfold_expr, expr_ty, node_->v.List.elts);
604         break;
605     case Tuple_kind:
606         CALL_SEQ(astfold_expr, expr_ty, node_->v.Tuple.elts);
607         CALL(fold_tuple, expr_ty, node_);
608         break;
609     case Name_kind:
610         if (_PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
611             return make_const(node_, PyBool_FromLong(!optimize_), ctx_);
612         }
613         break;
614     default:
615         break;
616     }
617     return 1;
618 }
619 
620 static int
astfold_slice(slice_ty node_,PyArena * ctx_,int optimize_)621 astfold_slice(slice_ty node_, PyArena *ctx_, int optimize_)
622 {
623     switch (node_->kind) {
624     case Slice_kind:
625         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
626         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
627         CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
628         break;
629     case ExtSlice_kind:
630         CALL_SEQ(astfold_slice, slice_ty, node_->v.ExtSlice.dims);
631         break;
632     case Index_kind:
633         CALL(astfold_expr, expr_ty, node_->v.Index.value);
634         break;
635     default:
636         break;
637     }
638     return 1;
639 }
640 
641 static int
astfold_keyword(keyword_ty node_,PyArena * ctx_,int optimize_)642 astfold_keyword(keyword_ty node_, PyArena *ctx_, int optimize_)
643 {
644     CALL(astfold_expr, expr_ty, node_->value);
645     return 1;
646 }
647 
648 static int
astfold_comprehension(comprehension_ty node_,PyArena * ctx_,int optimize_)649 astfold_comprehension(comprehension_ty node_, PyArena *ctx_, int optimize_)
650 {
651     CALL(astfold_expr, expr_ty, node_->target);
652     CALL(astfold_expr, expr_ty, node_->iter);
653     CALL_SEQ(astfold_expr, expr_ty, node_->ifs);
654 
655     CALL(fold_iter, expr_ty, node_->iter);
656     return 1;
657 }
658 
659 static int
astfold_arguments(arguments_ty node_,PyArena * ctx_,int optimize_)660 astfold_arguments(arguments_ty node_, PyArena *ctx_, int optimize_)
661 {
662     CALL_SEQ(astfold_arg, arg_ty, node_->args);
663     CALL_OPT(astfold_arg, arg_ty, node_->vararg);
664     CALL_SEQ(astfold_arg, arg_ty, node_->kwonlyargs);
665     CALL_SEQ(astfold_expr, expr_ty, node_->kw_defaults);
666     CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
667     CALL_SEQ(astfold_expr, expr_ty, node_->defaults);
668     return 1;
669 }
670 
671 static int
astfold_arg(arg_ty node_,PyArena * ctx_,int optimize_)672 astfold_arg(arg_ty node_, PyArena *ctx_, int optimize_)
673 {
674     CALL_OPT(astfold_expr, expr_ty, node_->annotation);
675     return 1;
676 }
677 
678 static int
astfold_stmt(stmt_ty node_,PyArena * ctx_,int optimize_)679 astfold_stmt(stmt_ty node_, PyArena *ctx_, int optimize_)
680 {
681     switch (node_->kind) {
682     case FunctionDef_kind:
683         CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
684         CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
685         CALL_SEQ(astfold_expr, expr_ty, node_->v.FunctionDef.decorator_list);
686         CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
687         break;
688     case AsyncFunctionDef_kind:
689         CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
690         CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
691         CALL_SEQ(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.decorator_list);
692         CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
693         break;
694     case ClassDef_kind:
695         CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.bases);
696         CALL_SEQ(astfold_keyword, keyword_ty, node_->v.ClassDef.keywords);
697         CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
698         CALL_SEQ(astfold_expr, expr_ty, node_->v.ClassDef.decorator_list);
699         break;
700     case Return_kind:
701         CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
702         break;
703     case Delete_kind:
704         CALL_SEQ(astfold_expr, expr_ty, node_->v.Delete.targets);
705         break;
706     case Assign_kind:
707         CALL_SEQ(astfold_expr, expr_ty, node_->v.Assign.targets);
708         CALL(astfold_expr, expr_ty, node_->v.Assign.value);
709         break;
710     case AugAssign_kind:
711         CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
712         CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
713         break;
714     case AnnAssign_kind:
715         CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
716         CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
717         CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
718         break;
719     case For_kind:
720         CALL(astfold_expr, expr_ty, node_->v.For.target);
721         CALL(astfold_expr, expr_ty, node_->v.For.iter);
722         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.body);
723         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.For.orelse);
724 
725         CALL(fold_iter, expr_ty, node_->v.For.iter);
726         break;
727     case AsyncFor_kind:
728         CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
729         CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
730         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.body);
731         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncFor.orelse);
732         break;
733     case While_kind:
734         CALL(astfold_expr, expr_ty, node_->v.While.test);
735         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.body);
736         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.While.orelse);
737         break;
738     case If_kind:
739         CALL(astfold_expr, expr_ty, node_->v.If.test);
740         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.body);
741         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.If.orelse);
742         break;
743     case With_kind:
744         CALL_SEQ(astfold_withitem, withitem_ty, node_->v.With.items);
745         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.With.body);
746         break;
747     case AsyncWith_kind:
748         CALL_SEQ(astfold_withitem, withitem_ty, node_->v.AsyncWith.items);
749         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.AsyncWith.body);
750         break;
751     case Raise_kind:
752         CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
753         CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
754         break;
755     case Try_kind:
756         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.body);
757         CALL_SEQ(astfold_excepthandler, excepthandler_ty, node_->v.Try.handlers);
758         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.orelse);
759         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.Try.finalbody);
760         break;
761     case Assert_kind:
762         CALL(astfold_expr, expr_ty, node_->v.Assert.test);
763         CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
764         break;
765     case Expr_kind:
766         CALL(astfold_expr, expr_ty, node_->v.Expr.value);
767         break;
768     default:
769         break;
770     }
771     return 1;
772 }
773 
774 static int
astfold_excepthandler(excepthandler_ty node_,PyArena * ctx_,int optimize_)775 astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, int optimize_)
776 {
777     switch (node_->kind) {
778     case ExceptHandler_kind:
779         CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
780         CALL_SEQ(astfold_stmt, stmt_ty, node_->v.ExceptHandler.body);
781         break;
782     default:
783         break;
784     }
785     return 1;
786 }
787 
788 static int
astfold_withitem(withitem_ty node_,PyArena * ctx_,int optimize_)789 astfold_withitem(withitem_ty node_, PyArena *ctx_, int optimize_)
790 {
791     CALL(astfold_expr, expr_ty, node_->context_expr);
792     CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
793     return 1;
794 }
795 
/* The visitor helper macros are private to the astfold_* functions above. */
#undef CALL
#undef CALL_INT_SEQ
#undef CALL_OPT
#undef CALL_SEQ
800 
801 int
_PyAST_Optimize(mod_ty mod,PyArena * arena,int optimize)802 _PyAST_Optimize(mod_ty mod, PyArena *arena, int optimize)
803 {
804     int ret = astfold_mod(mod, arena, optimize);
805     assert(ret || PyErr_Occurred());
806     return ret;
807 }
808