# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lowers break statements to conditionals."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno


class _Break(object):
  """State tracked for a loop body while its break statements are lowered."""

  def __init__(self):
    # Set to True by BreakTransformer.visit_Break when it rewrites a break.
    self.used = False
    # Name of the control variable that the rewritten break assigns to.
    self.control_var_name = None

  def __repr__(self):
    return 'used: %s, var: %s' % (self.used, self.control_var_name)


class BreakTransformer(converter.Base):
  """Canonicalizes break statements into additional conditionals.

  A `break` is replaced by an assignment to a per-loop control variable
  followed by `continue`. Loops whose body contained a break are rewritten so
  that the control variable is initialized before the loop and the `else`
  clause is guarded by it; `while` loops additionally test the variable, and
  `for` loops receive an 'extra_test' annotation carrying the same check.
  """

  def visit_Break(self, node):
    """Rewrites `break` as "set the control variable, then continue"."""
    control_var = self.state[_Break].control_var_name
    self.state[_Break].used = True
    # TODO(mdan): This will fail when expanded inside a top-level else block.
    template = """
      var_name = True
      continue
    """
    return templates.replace(template, var_name=control_var)

  def _guard_if_present(self, block, var_name):
    """Wraps a non-empty block so that it is skipped once var_name is set.

    Args:
      block: list of statements to guard; returned unchanged when empty.
      var_name: name of the break control variable.

    Returns:
      The guarded replacement produced by the template, or `block` itself
      when there is nothing to guard.
    """
    if block:
      template = """
        if ag__.not_(var_name):
          block
      """
      return templates.replace(template, var_name=var_name, block=block)
    # Nothing to guard; hand the empty block back untouched.
    return block

  def _process_body(self, nodes, break_var):
    """Visits a loop body under a fresh _Break state.

    Args:
      nodes: the statements of the loop body.
      break_var: control variable name that breaks in this body should set.

    Returns:
      Tuple (visited statements, whether visit_Break fired for this state).
    """
    self.state[_Break].enter()
    self.state[_Break].control_var_name = break_var
    visited = self.visit_block(nodes)
    saw_break = self.state[_Break].used
    self.state[_Break].exit()
    return visited, saw_break

  def _new_break_var(self, loop_node):
    """Allocates a 'break_' symbol avoiding names referenced in the body."""
    body_scope = anno.getanno(loop_node, NodeAnno.BODY_SCOPE)
    return self.ctx.namer.new_symbol('break_', body_scope.referenced)

  def visit_While(self, node):
    """Lowers breaks inside a while loop into an extra loop condition."""
    break_var = self._new_break_var(node)

    node.test = self.visit(node.test)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
      return node

    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    else_block = self._guard_if_present(node.orelse, break_var)

    template = """
      var_name = False
      while ag__.and_(lambda: test, lambda: ag__.not_(var_name)):
        body
      else:
        orelse
    """
    return templates.replace(
        template,
        var_name=break_var,
        test=node.test,
        body=node.body,
        orelse=else_block)

  def visit_For(self, node):
    """Lowers breaks inside a for loop, annotating it with an extra test."""
    break_var = self._new_break_var(node)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
      return node

    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    else_block = self._guard_if_present(node.orelse, break_var)
    extra_test = templates.replace_as_expression(
        'ag__.not_(var_name)', var_name=break_var)

    # The extra test is hidden in the AST, which will confuse the static
    # analysis. To mitigate that, we insert a no-op statement that ensures
    # the control variable is marked as used.
    # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
    template = """
      var_name = False
      for target in iter_:
        (var_name,)
        body
      else:
        orelse
    """
    replacement = templates.replace(
        template,
        var_name=break_var,
        iter_=node.iter,
        target=node.target,
        body=node.body,
        orelse=else_block)

    # The template expands to two statements: the control variable
    # initialization and the loop. The annotation goes on the second one.
    anno.setanno(replacement[1], 'extra_test', extra_test)

    return replacement


def transform(node, ctx):
  """Applies BreakTransformer, built from ctx, to node and returns the result."""
  transformer = BreakTransformer(ctx)
  return transformer.visit(node)