1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Reaching definition analysis.
16
This analysis attaches a set of Definition objects to each symbol, one
for each distinct definition that may reach it. The Definition objects are
mutable and may be used by subsequent analyses to further annotate data like
static type and value information.
The analysis also attaches, to each control flow statement, the set of symbols
defined at its entry.
23
24Requires activity analysis.
25"""
26
27from __future__ import absolute_import
28from __future__ import division
29from __future__ import print_function
30
31import weakref
32
33import gast
34
35from tensorflow.python.autograph.pyct import anno
36from tensorflow.python.autograph.pyct import cfg
37from tensorflow.python.autograph.pyct import transformer
38
39
class Definition(object):
  """Describes one unique definition of a variable.

  Custom subclasses can be used by passing a matching factory callable to
  `resolve`.

  Attributes:
    param_of: None, unless the definition originates from a function
        parameter; in that case the analysis stores a `weakref.ref` to the
        parameter's AST node (call it to dereference).
    directives: Dict, optional annotations attached to this definition.
  """

  def __init__(self):
    self.directives = {}
    self.param_of = None

  def __repr__(self):
    # Identity-based: two distinct definitions never print the same.
    return '%s[%d]' % (type(self).__name__, id(self))
57
58
59class _NodeState(object):
60  """Abstraction for the state of the CFG walk for reaching definition analysis.
61
62  This is a value type. Only implements the strictly necessary operators.
63
64  Attributes:
65    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
66        their possible definitions
67  """
68
69  def __init__(self, init_from=None):
70    if init_from:
71      if isinstance(init_from, _NodeState):
72        self.value = {
73            s: set(other_infos) for s, other_infos in init_from.value.items()
74        }
75      elif isinstance(init_from, dict):
76        self.value = {s: set((init_from[s],)) for s in init_from}
77      else:
78        assert False, init_from
79    else:
80      self.value = {}
81
82  def __eq__(self, other):
83    if frozenset(self.value.keys()) != frozenset(other.value.keys()):
84      return False
85    ret = all(self.value[s] == other.value[s] for s in self.value)
86    return ret
87
88  def __ne__(self, other):
89    return not self.__eq__(other)
90
91  def __or__(self, other):
92    assert isinstance(other, _NodeState)
93    result = _NodeState(self)
94    for s, other_infos in other.value.items():
95      if s in result.value:
96        result.value[s].update(other_infos)
97      else:
98        result.value[s] = set(other_infos)
99    return result
100
101  def __sub__(self, other):
102    assert isinstance(other, set)
103    result = _NodeState(self)
104    for s in other:
105      result.value.pop(s, None)
106    return result
107
108  def __repr__(self):
109    return 'NodeState[%s]=%s' % (id(self), repr(self.value))
110
111
class Analyzer(cfg.GraphVisitor):
  """CFG visitor that determines reaching definitions at statement level.

  For each CFG node it applies the transfer function
  `out = gen | (in - kill)`, where `in` is the union of the `out` states of
  the node's predecessors.

  Attributes:
    gen_map: Dict mapping each CFG node visited so far to the _NodeState of
        definitions it generates. An entry is created on the first visit of a
        node and reused on every later visit.
  """

  def __init__(self, graph, definition_factory):
    self._definition_factory = definition_factory
    super(Analyzer, self).__init__(graph)
    self.gen_map = {}

  def init_state(self, _):
    return _NodeState()

  def visit_node(self, node):
    """Recomputes the reaching definitions entering and leaving `node`.

    Stores the results in `self.in_[node]` and `self.out[node]`.

    Args:
      node: a CFG node of the analyzed graph.

    Returns:
      bool, whether the node's output state differs from the one previously
      stored for it. NOTE(review): presumably used by cfg.GraphVisitor to
      decide when the iteration has converged - confirm.
    """
    prev_defs_out = self.out[node]

    defs_in = _NodeState()
    for n in node.prev:
      defs_in |= self.out[n]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
      # The definition objects created by each node must be singletons because
      # their ids are used in equality checks.
      if node not in self.gen_map:
        node_symbols = {}
        # Every binding operation (assign, nonlocal, global, etc.) counts as a
        # definition, with the exception of del, which only deletes without
        # creating a new variable.
        newly_defined = ((node_scope.bound | node_scope.globals) -
                         node_scope.deleted)
        for s in newly_defined:
          node_symbols[s] = self._definition_factory()
        # Every param receives a definition. Params are not necessarily
        # considered as "modified". The parameter node is referenced weakly.
        for s, p in node_scope.params.items():
          def_ = self._definition_factory()
          def_.param_of = weakref.ref(p)
          node_symbols[s] = def_
        self.gen_map[node] = _NodeState(node_symbols)

      gen = self.gen_map[node]
      kill = node_scope.modified | node_scope.deleted
      defs_out = gen | (defs_in - kill)

    else:
      assert self.can_ignore(node), (node.ast_node, node)
      defs_out = defs_in

    self.in_[node] = defs_in
    self.out[node] = defs_out

    return prev_defs_out != defs_out
167
168
class TreeAnnotator(transformer.Base):
  """AST visitor that annotates each symbol name with its reaching definitions.

  Simultaneously, the visitor runs the dataflow analysis on each function
  node. Every FunctionDef is analyzed by a fresh Analyzer over its own CFG,
  taken from `graphs`, and that analyzer is then used to annotate the nodes
  inside the function.

  Annotations written:
    * anno.Static.DEFINITIONS on Name nodes inside functions. For loads, these
      are the definitions entering the enclosing CFG node. For any other
      context, they are the definitions leaving it.
    * anno.Static.DEFINED_VARS_IN on If, For, While, Try and ExceptHandler
      nodes inside functions: the symbols defined on exit from any of the
      statement's CFG predecessors.

  Nodes outside any function (e.g. in a class body) have no CFG and are left
  without these annotations.

  NOTE(review): an earlier version of this docstring claimed closures are
  accounted for, e.g. that in `def foo(): bar = 1; def baz(): ...` the
  definition `bar = 1` reaches `baz`. visit_FunctionDef analyzes each
  function's CFG independently, so that does not appear to be implemented
  here - confirm.
  """

  def __init__(self, source_info, graphs, definition_factory):
    super(TreeAnnotator, self).__init__(source_info)
    self.allow_skips = False
    self.definition_factory = definition_factory
    self.graphs = graphs
    # Analyzer of the innermost function being visited; None outside functions.
    self.current_analyzer = None
    # CFG node of the innermost statement being visited, if any.
    self.current_cfg_node = None

  def visit_FunctionDef(self, node):
    parent_analyzer = self.current_analyzer
    subgraph = self.graphs[node]

    analyzer = Analyzer(subgraph, self.definition_factory)
    analyzer.visit_forward()

    # Recursively process any remaining subfunctions. The parent analyzer is
    # restored even if visiting the function raises.
    self.current_analyzer = analyzer
    try:
      node.args = self.visit(node.args)
      node.body = self.visit_block(node.body)
    finally:
      self.current_analyzer = parent_analyzer

    return node

  def visit_Name(self, node):
    if self.current_analyzer is None:
      # Names may appear outside function defs - for example in class
      # definitions.
      return node

    analyzer = self.current_analyzer
    cfg_node = self.current_cfg_node

    assert cfg_node is not None, ('name node, %s, outside of any statement?'
                                  % node.id)

    qn = anno.getanno(node, anno.Basic.QN)
    if isinstance(node.ctx, gast.Load):
      # A read sees the definitions that reach its statement.
      anno.setanno(node, anno.Static.DEFINITIONS,
                   tuple(analyzer.in_[cfg_node].value.get(qn, ())))
    else:
      # Other contexts see the definitions after the statement, which include
      # any it creates.
      anno.setanno(node, anno.Static.DEFINITIONS,
                   tuple(analyzer.out[cfg_node].value.get(qn, ())))

    return node

  def _aggregate_predecessors_defined_in(self, node):
    """Sets DEFINED_VARS_IN on a control flow statement inside a function."""
    analyzer = self.current_analyzer
    if analyzer is None:
      # Control flow outside any function (e.g. in a class body) has no CFG,
      # consistent with visit_Name.
      return
    node_defined_in = set()
    for pred in analyzer.graph.stmt_prev[node]:
      # Iterating a dict yields its keys, i.e. the defined symbols.
      node_defined_in.update(analyzer.out[pred].value)
    anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in))

  def visit_If(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_For(self, node):
    self._aggregate_predecessors_defined_in(node)

    if self.current_analyzer is None:
      # Outside functions there is no CFG to account for.
      return self.generic_visit(node)

    # Manually accounting for the shortcoming described in
    # cfg.AstToCfg.visit_For: the loop target is annotated against the CFG
    # node of the loop's `iter` expression.
    parent = self.current_cfg_node
    self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
    try:
      node.target = self.visit(node.target)
    finally:
      self.current_cfg_node = parent

    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)

    return node

  def visit_While(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_Try(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_ExceptHandler(self, node):
    self._aggregate_predecessors_defined_in(node)
    # TODO(mdan): Also track the exception type / name symbols.
    node.body = self.visit_block(node.body)
    return node

  def visit(self, node):
    """Visits `node`, tracking the CFG node of the enclosing statement."""
    parent = self.current_cfg_node

    if (self.current_analyzer is not None and
        node in self.current_analyzer.graph.index):
      self.current_cfg_node = self.current_analyzer.graph.index[node]
    try:
      return super(TreeAnnotator, self).visit(node)
    finally:
      # Restore the enclosing statement's CFG node even if the visit raises.
      self.current_cfg_node = parent
277
278
def resolve(node, source_info, graphs, definition_factory=Definition):
  """Annotates the symbols in `node` with their reaching definitions.

  The nodes are expected to carry anno.Static.SCOPE annotations from activity
  analysis.

  Args:
    node: ast.AST, the tree to annotate.
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph], must contain an entry for every
        function definition reached while visiting `node`.
    definition_factory: Callable[[], Definition], creates the Definition
        objects attached to symbols.
  Returns:
    ast.AST, the result of visiting `node` with a TreeAnnotator.
  """
  annotator = TreeAnnotator(source_info, graphs, definition_factory)
  return annotator.visit(node)
293