# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Live variable analysis.

See https://en.wikipedia.org/wiki/Live_variable_analysis for a definition of
the following idioms: live variable, live in, live out, which are used
throughout this file.

This analysis attaches the following:
 * symbols that are live at the exit of control flow statements
 * symbols that are live at the entry of control flow statements

Requires activity analysis.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import annos


class Analyzer(cfg.GraphVisitor):
  """CFG visitor that performs liveness analysis at statement level.

  Liveness is a backward dataflow problem: for each CFG node,
    live_out = union of live_in over all successors
    live_in  = gen | (live_out - kill)
  where gen are the symbols the node reads and kill are the symbols it
  modifies or deletes.
  """

  def __init__(self, graph):
    super(Analyzer, self).__init__(graph)
    # This allows communicating that nodes generate extra symbols,
    # e.g. those that a function definition closes over.
    # Maps an AST node to a frozenset of symbols it additionally reads.
    self.extra_gen = {}

  def init_state(self, _):
    # Empty per-node state; presumably used by cfg.GraphVisitor to seed the
    # in_/out sets of every node before the fixed-point iteration.
    return set()

  def _successors_live_in(self, node):
    """Returns a new set with the union of live-in sets of node's successors."""
    live_out = set()
    for n in node.next:
      live_out |= self.in_[n]
    return live_out

  def visit_node(self, node):
    """Updates the live-in/live-out sets of a CFG node.

    Args:
      node: cfg.Node

    Returns:
      bool, whether the node's live-in set changed (so its predecessors need
      to be revisited).
    """
    prev_live_in = self.in_[node]
    live_out = self._successors_live_in(node)

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)

      gen = node_scope.read | self.extra_gen.get(node.ast_node, frozenset())
      # TODO(mdan): verify whether composites' parents need to be added.
      # E.g. whether x needs to be added if x.y is live. Theoretically the
      # activity analysis should have both so that wouldn't be needed.
      kill = node_scope.modified | node_scope.deleted

      live_in = gen | (live_out - kill)

    else:
      # Nodes that don't have a scope annotation are assumed not to touch any
      # symbols.
      # This Name node below is a literal name, e.g. False
      assert isinstance(node.ast_node,
                        (gast.Name, gast.Continue, gast.Break)), type(
                            node.ast_node)
      # With no gen and no kill, such nodes are pass-through: whatever is live
      # after them is live before them. (Keeping the previous, initially empty,
      # live-in set here would block liveness from flowing through e.g. a
      # `while True:` test or a `continue`.) A copy avoids aliasing in_/out.
      live_in = set(live_out)

    self.in_[node] = live_in
    self.out[node] = live_out

    # TODO(mdan): Move this to the superclass?
    return prev_live_in != live_in


class WholeTreeAnalyzer(transformer.Base):
  """Runs liveness analysis on each of the functions defined in the AST.

  If a function defines other local functions, those will have separate CFGs.
  However, dataflow analysis needs to tie up these CFGs to properly emulate the
  effect of closures. In the case of liveness, the parent function's live
  variables must account for the variables that are live at the entry of each
  subfunction. For example:

    def foo():
      # baz is live here
      def bar():
        print(baz)

  This analyzer runs liveness analysis on each individual function, accounting
  for the effect above. The per-function analyzers are kept in
  `self.analyzers`, keyed by FunctionDef node.
  """

  def __init__(self, source_info, graphs):
    super(WholeTreeAnalyzer, self).__init__(source_info)
    self.graphs = graphs
    self.current_analyzer = None
    self.analyzers = {}

  def visit_FunctionDef(self, node):
    parent_analyzer = self.current_analyzer
    subgraph = self.graphs[node]

    # Postorder tree processing makes this a bit complicated:
    #  1. construct an analyzer object and put it on stack
    #  2. recursively walk the subtree; this will initialize the analyzer's
    #     in_ state properly (done in a block below)
    #  3. run the final analysis
    analyzer = Analyzer(subgraph)
    self.current_analyzer = analyzer
    node = self.generic_visit(node)
    analyzer.visit_reverse()

    if parent_analyzer is not None:
      # Wire the state between the two subgraphs' analyzers.
      # Work on a copy: analyzer.in_ is read again later by Annotator, so
      # mutating it here would corrupt the child's own LIVE_VARS_IN results.
      child_in_state = set(analyzer.in_[subgraph.entry])
      # Exception: symbols modified in the child function are local to it.
      # Note: a function modifying the symbol doesn't make that symbol
      # live at the function's entry. In fact when that happens it is
      # probably a case of undefined assignment, like this:
      #
      #   bar = 0
      #   def foo():
      #     print(bar)  # bar is undefined here!
      #     bar = 1
      #
      # Hence modified symbols are dropped if present, with no error if absent.
      body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
      child_in_state.difference_update(body_scope.modified)
      parent_analyzer.extra_gen[node] = frozenset(child_in_state)

    self.analyzers[node] = analyzer
    self.current_analyzer = parent_analyzer
    return node

  def visit_Nonlocal(self, node):
    raise NotImplementedError()

  def visit_Global(self, node):
    raise NotImplementedError()


class Annotator(transformer.Base):
  """AST visitor that annotates each control flow block with live symbols.

  Uses the per-function analyzers computed by a WholeTreeAnalyzer.
  """

  # Note: additional nodes may be added as needed.

  def __init__(self, source_info, cross_function_analyzer):
    super(Annotator, self).__init__(source_info)
    self.cross_function_analyzer = cross_function_analyzer
    self.current_analyzer = None

  def _has_cfg_node(self, node):
    """Whether node is a CFG node of the function currently being visited."""
    return (self.current_analyzer is not None and
            node in self.current_analyzer.graph.index)

  def visit(self, node):
    node = super(Annotator, self).visit(node)
    # Every statement that is a CFG node gets its live-in set.
    if isinstance(node, gast.stmt) and self._has_cfg_node(node):
      cfg_node = self.current_analyzer.graph.index[node]
      anno.setanno(node, anno.Static.LIVE_VARS_IN,
                   frozenset(self.current_analyzer.in_[cfg_node]))
    return node

  def visit_FunctionDef(self, node):
    parent_analyzer = self.current_analyzer
    self.current_analyzer = self.cross_function_analyzer.analyzers[node]

    node = self.generic_visit(node)
    self.current_analyzer = parent_analyzer
    return node

  def _block_statement_live_out(self, node):
    # Live-out of a compound statement: union of the live-in sets of the CFG
    # nodes that follow the statement as a whole.
    successors = self.current_analyzer.graph.stmt_next[node]
    stmt_live_out = set()
    for s in successors:
      stmt_live_out.update(self.current_analyzer.in_[s])
    anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(stmt_live_out))
    return node

  def _block_statement_live_in(self, node, entry_node):
    # Live-in of a compound statement is the live-in of its entry CFG node
    # (e.g. the test of an `if`, the iterable of a `for`).
    cfg_node = self.current_analyzer.graph.index[entry_node]
    stmt_live_in = frozenset(self.current_analyzer.in_[cfg_node])
    anno.setanno(node, anno.Static.LIVE_VARS_IN, stmt_live_in)
    return node

  def visit_If(self, node):
    node = self.generic_visit(node)
    node = self._block_statement_live_out(node)
    return self._block_statement_live_in(node, node.test)

  def visit_For(self, node):
    node = self.generic_visit(node)
    node = self._block_statement_live_out(node)
    return self._block_statement_live_in(node, node.iter)

  def visit_While(self, node):
    node = self.generic_visit(node)
    node = self._block_statement_live_out(node)
    return self._block_statement_live_in(node, node.test)

  def visit_With(self, node):
    # Only LIVE_VARS_IN is attached to `with` statements.
    node = self.generic_visit(node)
    return self._block_statement_live_in(node, node.items[0])

  def visit_Expr(self, node):
    node = self.generic_visit(node)
    # Same guard as visit(): expressions outside any analyzed function have
    # no CFG node to read liveness from.
    if self._has_cfg_node(node):
      cfg_node = self.current_analyzer.graph.index[node]
      anno.setanno(node, anno.Static.LIVE_VARS_OUT,
                   frozenset(self.current_analyzer.out[cfg_node]))
    return node

  def visit_ExceptHandler(self, node):
    # TODO(b/123995141) Add Exception Handlers to the CFG
    # Children are deliberately not visited, so nothing inside an exception
    # handler is annotated.
    return node


def resolve(node, source_info, graphs):
  """Resolves the live symbols at the entry and exit of statements.

  Attaches anno.Static.LIVE_VARS_IN to statements that are CFG nodes and to
  if/for/while/with statements, and anno.Static.LIVE_VARS_OUT to if/for/while
  statements and expression statements.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
  Returns:
    ast.AST
  """
  cross_function_analyzer = WholeTreeAnalyzer(source_info, graphs)
  node = cross_function_analyzer.visit(node)
  visitor = Annotator(source_info, cross_function_analyzer)
  node = visitor.visit(node)
  return node