1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Lowers break statements to conditionals."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.autograph.core import converter
22from tensorflow.python.autograph.pyct import anno
23from tensorflow.python.autograph.pyct import templates
24from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
25
26
27class _Break(object):
28
29  def __init__(self):
30    self.used = False
31    self.control_var_name = None
32
33  def __repr__(self):
34    return 'used: %s, var: %s' % (self.used, self.control_var_name)
35
36
class BreakTransformer(converter.Base):
  """Lowers `break` statements into a control flag plus loop conditions.

  Each loop whose body contains a `break` is given a fresh control variable.
  Every such `break` becomes `<var> = True; continue`. The loop is then
  rewritten so that it stops iterating once the variable is set. Its `else`
  clause, if any, is guarded so it only runs when no `break` fired.
  """

  def visit_Break(self, node):
    """Replaces `break` with an assignment to the active control variable."""
    break_state = self.state[_Break]
    break_state.used = True
    # TODO(mdan): This will fail when expanded inside a top-level else block.
    template = """
      var_name = True
      continue
    """
    return templates.replace(
        template, var_name=break_state.control_var_name)

  def _guard_if_present(self, block, var_name):
    """Wraps `block` in `if ag__.not_(var_name):`.

    Args:
      block: list of statements to guard; may be empty.
      var_name: the control variable to test.

    Returns:
      `block` unchanged when it is empty, otherwise the guarded statements.
    """
    if not block:
      return block

    template = """
        if ag__.not_(var_name):
          block
      """
    return templates.replace(template, var_name=var_name, block=block)

  def _process_body(self, nodes, break_var):
    """Visits a loop body with `break_var` as the active control variable.

    A fresh `_Break` frame is pushed for the duration of the visit, so breaks
    in nested loops are tracked in their own frames.

    Returns:
      A tuple of (visited nodes, whether any `break` belonging to this loop was
      rewritten).
    """
    break_state = self.state[_Break]
    break_state.enter()
    break_state.control_var_name = break_var
    visited = self.visit_block(nodes)
    found_break = break_state.used
    break_state.exit()
    return visited, found_break

  def _new_break_var(self, node):
    """Allocates a control variable name unused in the body of loop `node`."""
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    return self.ctx.namer.new_symbol('break_', body_scope.referenced)

  def visit_While(self, node):
    """Lowers breaks in a `while` loop by adding the flag to its test."""
    break_var = self._new_break_var(node)

    node.test = self.visit(node.test)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break inside the else clause exits the enclosing loop, not this one,
    # so the clause is visited with the outer state.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # Python runs a loop's else clause only when the loop was not exited via
      # break (i.e. the control variable is still False), so guard it.
      # The test and the flag check are passed as lambdas, presumably so that
      # ag__.and_ can evaluate them lazily.
      template = """
        var_name = False
        while ag__.and_(lambda: test, lambda: ag__.not_(var_name)):
          body
        else:
          orelse
      """
      return templates.replace(
          template,
          var_name=break_var,
          test=node.test,
          body=node.body,
          orelse=self._guard_if_present(node.orelse, break_var))

    return node

  def visit_For(self, node):
    """Lowers breaks in a `for` loop via an `extra_test` annotation."""
    break_var = self._new_break_var(node)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break inside the else clause exits the enclosing loop, not this one,
    # so the clause is visited with the outer state.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # Python runs a loop's else clause only when the loop was not exited via
      # break (i.e. the control variable is still False), so guard it.
      guarded_orelse = self._guard_if_present(node.orelse, break_var)
      loop_condition = templates.replace_as_expression(
          'ag__.not_(var_name)', var_name=break_var)

      # The extra test is hidden in the AST, which will confuse the static
      # analysis. To mitigate that, we insert a no-op statement that ensures
      # the control variable is marked as used.
      # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
      template = """
        var_name = False
        for target in iter_:
          (var_name,)
          body
        else:
          orelse
      """
      replacement = templates.replace(
          template,
          var_name=break_var,
          iter_=replacement_iter(node) if False else node.iter,
          target=node.target,
          body=node.body,
          orelse=guarded_orelse)

      # The template yields [flag initialization, for loop]. The condition is
      # attached to the loop itself, presumably for a later converter to apply.
      anno.setanno(replacement[1], 'extra_test', loop_condition)
      return replacement

    return node
143
144
def transform(node, ctx):
  """Lowers every `break` statement under `node`.

  Args:
    node: the AST to transform.
    ctx: the converter context passed to BreakTransformer.

  Returns:
    The result of visiting `node` with a BreakTransformer.
  """
  transformer = BreakTransformer(ctx)
  return transformer.visit(node)
147