1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles control flow statements: while, for, if.
16
17Python 2 compatibility version. Not maintained.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import gast
25
26from tensorflow.python.autograph.core import converter
27from tensorflow.python.autograph.lang import directives
28from tensorflow.python.autograph.pyct import anno
29from tensorflow.python.autograph.pyct import ast_util
30from tensorflow.python.autograph.pyct import cfg
31from tensorflow.python.autograph.pyct import parser
32from tensorflow.python.autograph.pyct import qual_names
33from tensorflow.python.autograph.pyct import templates
34from tensorflow.python.autograph.pyct.static_analysis import activity
35from tensorflow.python.autograph.pyct.static_analysis import annos
36from tensorflow.python.autograph.pyct.static_analysis import liveness
37from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
38from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
39
40
41# TODO(mdan): Refactor functions to make them smaller.
42
43
class ControlFlowTransformer(converter.Base):
  """Transforms control flow structures like loops and conditionals.

  `if`, `while` and `for` statements are rewritten into calls to
  `ag__.if_stmt`, `ag__.while_stmt` and `ag__.for_stmt` respectively, with the
  statement bodies (and, for `while`, the loop test) moved into generated
  local functions.
  """

  def _create_cond_branch(self, body_name, aliased_orig_names,
                          aliased_new_names, body, returns):
    """Builds the local function implementing one branch of an `if`.

    Args:
      body_name: name of the generated function.
      aliased_orig_names: symbols modified in the branch that are defined on
        entry to the `if`; each is copied into a local alias at the top of the
        generated function.
      aliased_new_names: the alias names, parallel to `aliased_orig_names`.
      body: the branch statements (already renamed to use the aliases).
      returns: values the function returns; a single value is returned as-is,
        otherwise the values are returned as a tuple.

    Returns:
      The AST nodes of the generated function definition.
    """
    if len(returns) == 1:
      template = """
        return retval
      """
      return_stmt = templates.replace(template, retval=returns[0])
    else:
      template = """
        return (retvals,)
      """
      return_stmt = templates.replace(template, retvals=returns)

    if aliased_orig_names:
      alias_declarations = []
      for new_name, old_name in zip(aliased_new_names, aliased_orig_names):
        # If the original name is not bound at runtime, the alias becomes an
        # `ag__.Undefined` placeholder instead of raising NameError.
        template = """
          try:
            aliased_new_name = aliased_orig_name
          except NameError:
            aliased_new_name = ag__.Undefined(symbol_name)
        """

        alias_declarations.extend(
            templates.replace(
                template,
                aliased_new_name=new_name,
                aliased_orig_name=old_name,
                symbol_name=gast.Constant(str(old_name), kind=None)))

      template = """
        def body_name():
          alias_declarations
          body
          return_stmt
      """
      return templates.replace(
          template,
          alias_declarations=alias_declarations,
          body_name=body_name,
          body=body,
          return_stmt=return_stmt)
    else:
      template = """
        def body_name():
          body
          return_stmt
      """
      return templates.replace(
          template, body_name=body_name, body=body, return_stmt=return_stmt)

  def _create_cond_expr(self, results, test, body_name, orelse_name,
                        state_getter_name, state_setter_name,
                        basic_symbol_names, composite_symbol_names):
    """Builds the `ag__.if_stmt(...)` call statement.

    Args:
      results: assignment target for the call's return value, or None to emit
        the call as a bare expression statement.
      test: the condition expression (or name holding it).
      body_name: name of the generated true-branch function.
      orelse_name: name of the generated false-branch function.
      state_getter_name: name of the generated composite-state getter.
      state_setter_name: name of the generated composite-state setter.
      basic_symbol_names: string constants naming the returned symbols.
      composite_symbol_names: string constants naming the composite symbols.

    Returns:
      The AST nodes of the call statement.
    """
    if results is not None:
      template = """
        results = ag__.if_stmt(test, body_name, orelse_name,
                               state_getter_name, state_setter_name,
                               (basic_symbol_names,),
                               (composite_symbol_names,))
      """
      return templates.replace(
          template,
          test=test,
          results=results,
          body_name=body_name,
          orelse_name=orelse_name,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names)
    else:
      template = """
        ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name,
                     (basic_symbol_names,), (composite_symbol_names,))
      """
      return templates.replace(
          template,
          test=test,
          body_name=body_name,
          orelse_name=orelse_name,
          getter_name=state_getter_name,
          setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names)

  def _fmt_symbols(self, symbol_set):
    """Returns the symbols joined by ', ', or 'no variables' if empty."""
    if not symbol_set:
      return 'no variables'
    return ', '.join(map(str, symbol_set))

  def _determine_aliased_symbols(self, scope, node_defined_in):
    """Returns non-composite symbols modified in `scope` and defined on entry.

    Such symbols need a local alias inside the generated branch function.
    """
    modified_live = scope.modified & node_defined_in
    # Composite symbols are handled elsewhere; see _create_state_functions.
    return {s for s in modified_live if not s.is_composite()}

  def _create_state_functions(self, composites, state_getter_name,
                              state_setter_name):
    """Builds getter/setter functions over the given composite symbols.

    The getter returns the current values of `composites` as a tuple; the
    setter assigns a tuple of values back to them. With no composites, the
    getter returns `()` and the setter does nothing.

    Returns:
      The AST nodes of the two function definitions.
    """

    if composites:
      composite_tuple = tuple(composites)

      template = """
        def state_getter_name():
          return composite_tuple,
        def state_setter_name(vals):
          composite_tuple, = vals
      """
      node = templates.replace(
          template,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_tuple=composite_tuple)
    else:
      template = """
        def state_getter_name():
          return ()
        def state_setter_name(_):
          pass
        """
      node = templates.replace(
          template,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name)

    return node

  def _create_loop_options(self, node):
    """Converts `set_loop_options` directive arguments into a dict literal.

    Returns:
      A `gast.Dict` mapping option-name constants to their value expressions.
      It is empty when `node` has no `set_loop_options` directive, or when
      the directive carries no options.
    """
    if not anno.hasanno(node, anno.Basic.DIRECTIVES):
      return gast.Dict([], [])

    loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
    if directives.set_loop_options not in loop_directives:
      return gast.Dict([], [])

    opts_dict = loop_directives[directives.set_loop_options]
    if not opts_dict:
      # `zip(*{}.items())` yields nothing, so the unpacking below would raise
      # ValueError for a directive without options.
      return gast.Dict([], [])

    str_keys, values = zip(*opts_dict.items())
    keys = [gast.Constant(s, kind=None) for s in str_keys]
    values = list(values)  # ast and gast don't play well with tuples.
    return gast.Dict(keys, values)

  def _create_undefined_assigns(self, undefined_symbols):
    """Returns `sym = ag__.Undefined('<sym>')` statements for each symbol."""
    assignments = []
    for s in undefined_symbols:
      template = '''
        var = ag__.Undefined(symbol_name)
      '''
      assignments += templates.replace(
          template,
          var=s,
          symbol_name=gast.Constant(s.ssf(), kind=None))
    return assignments

  def visit_If(self, node):
    """Converts an `if` statement into a call to `ag__.if_stmt`.

    Each branch becomes a local function. These functions return the
    non-composite symbols that are modified in either branch and live after
    the `if`. Composite symbols modified in either branch are instead exposed
    through generated state getter/setter functions.

    Args:
      node: the `gast.If` node, carrying static analysis annotations.

    Returns:
      The list of statements that replaces `node`.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in)

    node = self.generic_visit(node)

    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    composites = set()
    for s in modified_in_cond:
      if s in live_out and not s.is_composite():
        returned_from_cond.add(s)
      if s.is_composite():
        # Special treatment for compound objects, always return them.
        # This allows special handling within the if_stmt itself.
        # For example, in TensorFlow we need to restore the state of composite
        # symbols to ensure that only effects from the executed branch are seen.
        composites.add(s)

    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # The names below are bound in the scope enclosing both branches, so they
    # must avoid every symbol referenced in either branch. Otherwise, e.g., a
    # user variable `cond` assigned in the else branch would shadow the test
    # value read by that branch's dummy return statement.
    all_referenced = body_scope.referenced | orelse_scope.referenced
    cond_var_name = self.ctx.namer.new_symbol('cond', all_referenced)
    body_name = self.ctx.namer.new_symbol('if_true', all_referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', all_referenced)
    state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
    state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)

    returned_from_cond = tuple(returned_from_cond)
    composites = tuple(composites)

    if returned_from_cond:
      if len(returned_from_cond) == 1:
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      # Inside each branch function, aliased symbols are returned through
      # their alias names.
      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): Replace with None once side_effect_guards is retired.
      returned_from_body = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)
      returned_from_orelse = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    composite_defs = self._create_state_functions(
        composites, state_getter_name, state_setter_name)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in returned_from_cond)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composites)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name, state_getter_name,
                                       state_setter_name, basic_symbol_names,
                                       composite_symbol_names)

    if_ast = (
        undefined_assigns + composite_defs + body_def + orelse_def +
        cond_assign + cond_expr)
    return if_ast

  def _get_basic_loop_vars(self, modified_symbols, live_in, live_out):
    """Returns non-composite modified symbols live into or out of the loop."""
    # The loop variables corresponding to simple symbols (e.g. `x`).
    basic_loop_vars = []
    for s in modified_symbols:
      if s.is_composite():
        # TODO(mdan): Raise an error when this happens for a TF loop.
        continue
      # Variables not live into or out of the loop are considered local to the
      # loop.
      if s not in live_in and s not in live_out:
        continue
      basic_loop_vars.append(s)
    return frozenset(basic_loop_vars)

  def _get_composite_loop_vars(self, modified_symbols, live_in):
    """Returns composite modified symbols whose symbol parents are live in."""
    # The loop variables corresponding to composite symbols (e.g. `self.x`).
    composite_loop_vars = []
    for s in modified_symbols:
      if not s.is_composite():
        continue
      # Mutations made to objects created inside the loop will appear as writes
      # to composite symbols. Because these mutations appear as modifications
      # made to composite symbols, we check whether the composite's parent is
      # actually live into the loop.
      # Example:
      #   while cond:
      #     x = Foo()
      #     x.foo = 2 * x.foo  # x.foo is live into the loop, but x is not.
      #
      # Note that some parents might not be symbols - for example, in x['foo'],
      # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
      # liveness of literals.
      support_set_symbols = tuple(
          sss for sss in s.support_set if sss.is_symbol())
      if not all(sss in live_in for sss in support_set_symbols):
        continue
      composite_loop_vars.append(s)
    return frozenset(composite_loop_vars)

  def _get_loop_vars(self, node, modified_symbols):
    """Classifies the symbols modified by a loop.

    Args:
      node: the loop node, carrying static analysis annotations.
      modified_symbols: symbols modified by the loop.

    Returns:
      A tuple `(basic_loop_vars, composite_loop_vars, reserved_symbols,
      undefined_lives)`, where `reserved_symbols` are the symbols referenced
      in the loop body and `undefined_lives` are the basic loop variables not
      defined on entry to the loop.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    reserved_symbols = body_scope.referenced

    basic_loop_vars = self._get_basic_loop_vars(
        modified_symbols, live_in, live_out)
    composite_loop_vars = self._get_composite_loop_vars(
        modified_symbols, live_in)

    # Variables that are used or defined inside the loop, but not defined
    # before entering the loop. Only simple variables must be defined. The
    # composite ones will be implicitly checked at runtime.
    undefined_lives = basic_loop_vars - defined_in

    return (basic_loop_vars, composite_loop_vars, reserved_symbols,
            undefined_lives)

  def _loop_var_constructs(self, basic_loop_vars):
    """Returns the loop vars in template-ready forms.

    Returns:
      A pair `(loop_vars, loop_vars_ast_tuple)`. `loop_vars` is the single
      symbol when there is exactly one, otherwise a tuple of symbols (possibly
      empty). `loop_vars_ast_tuple` is a `gast.Tuple` of their ASTs.
    """
    loop_vars = tuple(basic_loop_vars)
    loop_vars_ast_tuple = gast.Tuple([n.ast() for n in loop_vars], None)

    if len(loop_vars) == 1:
      loop_vars = loop_vars[0]

    return loop_vars, loop_vars_ast_tuple

  def visit_While(self, node):
    """Converts a `while` loop into a call to `ag__.while_stmt`.

    The loop test and body become local functions over the basic loop
    variables; composite loop variables go through generated state
    getter/setter functions.

    Returns:
      `ag__.Undefined` assignments for basic loop variables not defined on
      entry to the loop, followed by the converted loop statements.
    """
    node = self.generic_visit(node)

    (basic_loop_vars, composite_loop_vars, reserved_symbols,
     possibly_undefs) = self._get_loop_vars(
         node,
         anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified)
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)

    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
    state_functions = self._create_state_functions(
        composite_loop_vars, state_getter_name, state_setter_name)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars)

    opts = self._create_loop_options(node)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
      template = """
        state_functions
        def body_name(loop_vars):
          body
          return loop_vars,
        def test_name(loop_vars):
          return test
        loop_vars_ast_tuple = ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
      node = templates.replace(
          template,
          loop_vars=loop_vars,
          loop_vars_ast_tuple=loop_vars_ast_tuple,
          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
          test=node.test,
          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names,
          opts=opts)
    else:
      template = """
        state_functions
        def body_name():
          body
          return ()
        def test_name():
          return test
        ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
      node = templates.replace(
          template,
          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
          test=node.test,
          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_symbol_names=composite_symbol_names,
          opts=opts)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + node

  def visit_For(self, node):
    """Converts a `for` loop into a call to `ag__.for_stmt`.

    The body becomes a local function that receives the current iterate in a
    single parameter and unpacks it into the original loop target. If the
    node has an `EXTRA_LOOP_TEST` annotation, that expression is wrapped in a
    generated function; otherwise `None` is passed in its place.

    Returns:
      The statements replacing `node`: `ag__.Undefined` assignments for basic
      loop variables not defined on entry, the generated functions, and the
      `ag__.for_stmt` call.
    """
    node = self.generic_visit(node)

    (basic_loop_vars, composite_loop_vars,
     reserved_symbols, possibly_undefs) = self._get_loop_vars(
         node, (anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
                | anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE).modified))
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
    state_functions = self._create_state_functions(
        composite_loop_vars, state_getter_name, state_setter_name)

    if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
      extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
      extra_test_name = self.ctx.namer.new_symbol(
          'extra_test', reserved_symbols)
      template = """
        def extra_test_name(loop_vars):
          return extra_test_expr
      """
      extra_test_function = templates.replace(
          template,
          extra_test_name=extra_test_name,
          loop_vars=loop_vars,
          extra_test_expr=extra_test)
    else:
      extra_test_name = parser.parse_expression('None')
      extra_test_function = []

    # Workaround for PEP-3113
    # iterates_var holds a single variable with the iterates, which may be a
    # tuple.
    iterates_var_name = self.ctx.namer.new_symbol(
        'iterates', reserved_symbols)
    template = """
      iterates = iterates_var_name
    """
    iterate_expansion = templates.replace(
        template,
        iterates=node.target,
        iterates_var_name=iterates_var_name)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars)

    opts = self._create_loop_options(node)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
      template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name, loop_vars):
          iterate_expansion
          body
          return loop_vars,
        extra_test_function
        loop_vars_ast_tuple = ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
      return templates.replace(
          template,
          undefined_assigns=undefined_assigns,
          loop_vars=loop_vars,
          loop_vars_ast_tuple=loop_vars_ast_tuple,
          iter_=node.iter,
          iterate_expansion=iterate_expansion,
          iterates_var_name=iterates_var_name,
          extra_test_name=extra_test_name,
          extra_test_function=extra_test_function,
          body_name=body_name,
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names,
          opts=opts)
    else:
      template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name):
          iterate_expansion
          body
          return ()
        extra_test_function
        ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
      return templates.replace(
          template,
          undefined_assigns=undefined_assigns,
          iter_=node.iter,
          iterate_expansion=iterate_expansion,
          iterates_var_name=iterates_var_name,
          extra_test_name=extra_test_name,
          extra_test_function=extra_test_function,
          body_name=body_name,
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_symbol_names=composite_symbol_names,
          opts=opts)
621
622
class AnnotatedDef(reaching_definitions.Definition):
  """A reaching definition that can also hold directive annotations."""

  def __init__(self):
    super(AnnotatedDef, self).__init__()
    # Directive annotations for this definition; always starts out empty.
    self.directives = dict()
628
629
def transform(node, ctx):
  """Annotates `node` with static analysis results, then converts it.

  Args:
    node: the AST to convert.
    ctx: the converter context.

  Returns:
    The AST with its control flow statements converted.
  """
  cfg_graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  # These analyses all consume the CFGs built above; the order matters.
  for analysis in (reaching_definitions, reaching_fndefs, liveness):
    node = analysis.resolve(node, ctx, cfg_graphs)
  return ControlFlowTransformer(ctx).visit(node)
640