1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Type inference.
16
This analysis annotates all symbol nodes of an AST with type information
18extracted from static sources:
19 * type annotations
20 * global and local symbols visible to the function at analysis time
21 * literals
22
23Important: This analysis is static, and does not detect dynamic type changes.
24The analysis attempts to use the values of external symbols, if available. These
25values are also considered static for the purpose of analysis.
26
27Requires reaching function definitions analysis.
28"""
29
30from __future__ import absolute_import
31from __future__ import division
32from __future__ import print_function
33
34import itertools
35
36from typing import Any, Callable, Dict, Set
37
38import gast
39
40from tensorflow.python.autograph.pyct import anno
41from tensorflow.python.autograph.pyct import cfg
42from tensorflow.python.autograph.pyct import qual_names
43from tensorflow.python.autograph.pyct import transformer
44from tensorflow.python.autograph.pyct.static_analysis import activity
45from tensorflow.python.autograph.pyct.static_analysis import annos
46
47
class Resolver(object):
  """Resolver objects handle the process of looking up actual names and types.

  Unless noted otherwise, all res_* methods:
    * have a first namespace argument (ns), mapping string to actual values
    * have a second types namespace argument (types_ns), mapping string to
      actual inferred types
    * specify names as QN objects
    * specify types as a Set of inferred types

  Unless noted otherwise, all res_* methods must return either:
    * a set of `type` objects
    * None
  """

  def res_name(self, ns, types_ns, name):
    """Resolves the type/value of an external (e.g. closure, global) variable.

    Args:
      ns: namespace
      types_ns: types namespace
      name: symbol name
    Returns:
      Tuple (type, static_value). The first element is the type to use for
      inference. The second is the static value to use. Use None for an
      element to treat it as unknown. Note: the callers in this module always
      unpack the result into two values.
    """
    raise NotImplementedError('subclasses must implement')

  def res_value(self, ns, value):
    """Resolves the type of a literal or static value.

    Args:
      ns: namespace
      value: the literal or static value whose type is requested
    Returns:
      Set of types for the value, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
    """Resolves the type of a (possibly annotated) function argument.

    Args:
      ns: namespace
      types_ns: types namespace
      f_name: str, the function name
      name: str, the argument name
      type_anno: the type annotating the argument, if any
      f_is_local: bool, whether the function is a local function
    Returns:
      Set of the argument types.
    """
    raise NotImplementedError('subclasses must implement')

  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    """Resolves the return type of an external function or method call.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the AST node of the call being resolved
      f_type: types of the actual function being called, if known
      args: types of each respective argument in node.args
      keywords: types of each respective argument in node.keywords

    Returns:
      Tuple (return_type, side_effect_types). The first element is just the
      return types of the function. The second element is a map from
      names (QN objects) to sets of types, and allows modelling side effects
      of functions (for example via global or nonlocal). It may be None when
      there are no side effects to report.
    """
    raise NotImplementedError('subclasses must implement')

  # TODO(mdan): Clean this up.
  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
    """Resolves the return type of a slice operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node_or_slice: either the subscript AST node, or the integer position of
        an element when the call models sequence unpacking
      value: types of the value being sliced
      slice_: types of the slice (index) expression
    Returns:
      Set of types of the sliced result, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_compare(self, ns, types_ns, node, left, right):
    """Resolves the return type of a comparison operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the comparison AST node
      left: types of the left operand
      right: list with the types of each respective comparator
    Returns:
      Set of types of the comparison result, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_unop(self, ns, types_ns, node, opnd):
    """Resolves the return type of a unary operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the unary operation AST node
      opnd: types of the operand
    Returns:
      Set of types of the result, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_binop(self, ns, types_ns, node, left, right):
    """Resolves the return type of a binary operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the binary operation AST node
      left: types of the left operand
      right: types of the right operand
    Returns:
      Set of types of the result, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_list_literal(self, ns, elt_types):
    """Resolves the type of a list literal from its elements.

    Args:
      ns: namespace
      elt_types: tuple with the types of each respective element; an entry is
        None when the type of that element is unknown
    Returns:
      Set of types of the list, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')
135
136
137class _TypeMap(object):
138  """Abstraction for the state of the CFG walk for type inference.
139
140  This is a value type. Only implements the strictly necessary operators.
141
142  Attributes:
143    types: Dict[qual_names.QN, Set[Type]], mapping symbols to the set of
144      possible types.
145  """
146
147  def __init__(self, init_from=None):
148    if init_from:
149      assert isinstance(init_from, _TypeMap)
150      self.types = {
151          s: set(other_types) for s, other_types in init_from.types.items()
152      }
153    else:
154      self.types = {}
155
156  def __eq__(self, other):
157    if frozenset(self.types.keys()) != frozenset(other.types.keys()):
158      return False
159    ret = all(self.types[s] == other.types[s] for s in self.types)
160    return ret
161
162  def __ne__(self, other):
163    return not self.__eq__(other)
164
165  def __or__(self, other):
166    assert isinstance(other, _TypeMap)
167    result = _TypeMap(self)
168    for s, other_types in other.types.items():
169      if s not in result.types:
170        self_types = set()
171        result.types[s] = self_types
172      else:
173        self_types = result.types[s]
174      self_types.update(other_types)
175    return result
176
177  def __repr__(self):
178    return 'SymbolTable {}'.format(self.types)
179
180
# Sentinel used as a getattr default, to tell "attribute is missing" apart from
# "attribute exists and its value is None".
NO_VALUE = object()
182
183
class StmtInferrer(gast.NodeVisitor):
  """Runs type inference on a single AST statement.

  This visitor annotates most nodes with type information. It also collects
  the types of the symbols modified by this statement in its new_symbols
  property.

  Note: this inferrer is able to capture side effects of functions, however,
  these side effects will not be applied to the current expression. Doing so
  would create too much of a dependence on the runtime's internal rules about
  execution order.
  Example:

    def f():
      nonlocal a
      a = 1
      return a

    a = 0.0
    b = f() + a  # a = float; side effect of f() ignored
    print(a)  # a = int; side effect of f() accounted for
  """

  def __init__(self,
               resolver: Resolver,
               scope: activity.Scope,
               namespace: Dict[qual_names.QN, Any],
               closure_types: Dict[qual_names.QN, Set[Any]],
               types_in: _TypeMap):
    """Creates a new statement inferrer.

    Args:
      resolver: Resolver used for everything that cannot be inferred locally
      scope: activity.Scope of the function being analyzed
      namespace: mapping of names to actual values, passed to the resolver
      closure_types: types of the symbols the function closes over
      types_in: _TypeMap with the types known before this statement runs
    """
    self.resolver = resolver
    self.scope = scope
    self.namespace = namespace
    self.closure_types = closure_types
    self.types_in = types_in
    # Symbols (re)defined by the statement, mapped to their inferred types.
    self.new_symbols = {}

    # rvalue type. This property is set when encountering an assign operation,
    # so that visiting nodes with Store ctx (typically found on left side of
    # assignments) can infer the type they should receive.
    self.rtype = None

  def visit(self, node):
    """Visits a node and annotates it with the types inferred for it, if any."""
    types = super().visit(node)
    if __debug__:
      self._check_set(types)
    if types is not None:
      # TODO(mdan): Normalize by removing subtypes.
      anno.setanno(node, anno.Static.TYPES, tuple(types))
    return types

  def _check_set(self, value):
    """Raises ValueError unless value is None or a set."""
    if value is not None and not isinstance(value, set):
      raise ValueError('{} method expected to return set, got {}'.format(
          self.resolver, value))

  def visit_Constant(self, node):
    types = self.resolver.res_value(self.namespace, node.value)
    if __debug__:
      self._check_set(types)
    return types

  def _apply_unpacking(self, node):
    """Distributes the current rvalue type over the elements of a target.

    Each element is visited with rtype temporarily set to the type the
    resolver reports for slicing the rvalue at the element's position.

    Args:
      node: a tuple or list node in Store context
    Returns:
      The rvalue types, or None if they are not known.
    """
    assert isinstance(node.ctx, gast.Store)
    if self.rtype is not None:
      original_stype = self.rtype
      # TODO(mdan): Find a better way to express unpacking.
      i_type = self.resolver.res_value(self.namespace, 0)
      for i, elt in enumerate(node.elts):
        self.rtype = self.resolver.res_slice(
            self.namespace, self.types_in.types, i, original_stype, i_type)
        self.visit(elt)
      self.rtype = original_stype
      return original_stype
    return None

  def visit_Tuple(self, node):
    if isinstance(node.ctx, gast.Load):
      # Visit every element before deciding, even when an earlier one is
      # unknown: otherwise the remaining elements would be left without
      # annotations and the side effects of their calls would be lost.
      elt_types = tuple(self.visit(elt) for elt in node.elts)
      if any(types_ is None for types_ in elt_types):
        return None
      # One tuple type for each combination of the element types.
      return set(itertools.product(*elt_types))
    return self._apply_unpacking(node)

  def visit_List(self, node):
    if isinstance(node.ctx, gast.Load):
      elt_types = tuple(self.visit(elt) for elt in node.elts)
      return self.resolver.res_list_literal(self.namespace, elt_types)
    return self._apply_unpacking(node)

  def visit_Set(self, node):
    raise NotImplementedError()

  def visit_Name(self, node):
    name = anno.getanno(node, anno.Basic.QN)

    if isinstance(node.ctx, gast.Load):
      types = self.types_in.types.get(name, None)
      if types is None:
        # Only symbols that are not local to the function are looked up
        # outside of it.
        if (name not in self.scope.bound) or (name in self.scope.nonlocals):
          # TODO(mdan): Test with global variables.
          if name in self.closure_types:
            types = self.closure_types[name]
          else:
            types, value = self.resolver.res_name(
                self.namespace, self.types_in.types, name)
            if value is not None:
              anno.setanno(node, anno.Static.VALUE, value)

    elif isinstance(node.ctx, gast.Param):
      # The direct parent is the whole function scope. See activity.py.
      f_is_local = self.scope.parent.parent is not None

      type_name = anno.getanno(node.annotation, anno.Basic.QN, None)
      types = self.resolver.res_arg(self.namespace, self.types_in.types,
                                    self.scope.function_name, name, type_name,
                                    f_is_local)
      if types is not None:
        self.new_symbols[name] = types

    elif isinstance(node.ctx, gast.Store):
      # The symbol receives the type of the value being assigned, if known.
      if self.rtype is not None:
        self.new_symbols[name] = self.rtype
      types = self.rtype

    else:
      assert False, 'unknown ctx'

    if __debug__:
      self._check_set(types)

    return types

  def visit_Attribute(self, node):
    parent_types = self.visit(node.value)

    # Attempt to use the static value if known.
    parent_value = anno.Static.VALUE.of(node.value, None)
    if parent_value is not None:
      static_value = getattr(parent_value, node.attr, NO_VALUE)

      if static_value is NO_VALUE:
        # Unexpected failure to resolve attribute. Ask the resolver about the
        # full name instead. The resolver receives the plain types dict, same
        # as in every other resolver call of this class.
        types, static_value = self.resolver.res_name(
            self.namespace, self.types_in.types, anno.Basic.QN.of(node))
        anno.setanno(node, anno.Static.VALUE, static_value)
        if __debug__:
          self._check_set(types)
        return types

    else:
      # Fall back to the type if that is known.
      if parent_types is None:
        return None

      inferred_values = [getattr(t, node.attr, None) for t in parent_types]
      if not inferred_values:
        return None

      static_value = inferred_values[0]
      if static_value is None:
        return None

      if any(v is not static_value for v in inferred_values[1:]):
        # Static value not stable, assume it's dynamic.
        return None

    types = self.resolver.res_value(self.namespace, static_value)
    anno.setanno(node, anno.Static.VALUE, static_value)

    if __debug__:
      self._check_set(types)

    return types

  def visit_FunctionDef(self, node):
    f_name = qual_names.QN(node.name)

    if node.decorator_list:
      raise NotImplementedError('decorators: {}'.format(node.decorator_list))

    ret_types = None
    if node.returns:
      ret_types, _ = self.resolver.res_name(
          self.namespace, self.types_in.types, anno.Basic.QN.of(node.returns))
      if __debug__:
        self._check_set(ret_types)

    if ret_types is None:
      ret_types = {Any}

    # The function symbol gets one callable type for each possible return type.
    f_types = set()
    for rt in ret_types:
      f_types.add(Callable[[Any], rt])

    self.new_symbols[f_name] = f_types
    # The definition of a function is a statement, hence has no value (and no
    # type) of its own.
    return None

  def _resolve_typed_callable(self, f_types, arg_types, keyword_types):
    """Infers the return types of a call from the callable's own types.

    Args:
      f_types: set of types of the function being called
      arg_types: types of the positional arguments (currently unused)
      keyword_types: types of the keyword arguments (currently unused)
    Returns:
      Tuple (return types, side effects). Side effects are always None.
    Raises:
      NotImplementedError: if one of f_types is not a callable type.
    """
    ret_types = set()
    for t in f_types:

      if isinstance(t, Callable):
        # Note: these are undocumented - may be version-specific!
        # Callable[[x], y]: __args__ are (x, y)
        args = t.__args__
        if args:
          ret_types.add(args[-1])
        else:
          ret_types.add(Any)
      else:
        raise NotImplementedError('callable type {}'.format(type(t)))

    # Side effects can not be inferred based on type alone.
    side_effects = None
    return ret_types, side_effects

  def visit_Call(self, node):
    self.visit(node.func)

    f_name = anno.Basic.QN.of(node.func)
    arg_types = [self.visit(a) for a in node.args]
    keyword_types = [self.visit(kw.value) for kw in node.keywords]

    if f_name in self.scope.bound:
      # Local function, use local type definitions, if available.
      f_type = self.types_in.types.get(f_name, None)
      if f_type is None:
        # No static type info available, nothing more to do.
        ret_type, side_effects = None, None
      else:
        ret_type, side_effects = self._resolve_typed_callable(
            f_type, arg_types, keyword_types)

    else:
      # Nonlocal function, resolve externally.
      f_type = anno.Static.TYPES.of(node.func, None)
      ret_type, side_effects = self.resolver.res_call(self.namespace,
                                                      self.types_in.types, node,
                                                      f_type, arg_types,
                                                      keyword_types)

    if __debug__:
      self._check_set(ret_type)
      if side_effects:
        if not isinstance(side_effects, dict):
          raise ValueError(
              'side effects must be dict, got {}'.format(side_effects))
        for k, v in side_effects.items():
          if not isinstance(k, qual_names.QN):
            raise ValueError('side effect keys must be QNs, got {}'.format(k))
          self._check_set(v)

    # Side effects only become visible to the statements that follow.
    if side_effects:
      self.new_symbols.update(side_effects)
    return ret_type

  def visit_Expr(self, node):
    return self.visit(node.value)

  def visit_Assign(self, node):
    # Expose the type of the right hand side while the targets are visited.
    self.rtype = self.visit(node.value)

    for t in node.targets:
      self.visit(t)

    self.rtype = None

  def visit_Subscript(self, node):
    val_types = self.visit(node.value)
    slice_types = self.visit(node.slice)

    if val_types is None or slice_types is None:
      return None

    types = self.resolver.res_slice(
        self.namespace, self.types_in.types, node, val_types, slice_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_Compare(self, node):
    left_types = self.visit(node.left)
    right_types = [self.visit(c) for c in node.comparators]

    if left_types is None or any(t is None for t in right_types):
      return None

    types = self.resolver.res_compare(
        self.namespace, self.types_in.types, node, left_types, right_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_BinOp(self, node):
    left_types = self.visit(node.left)
    right_types = self.visit(node.right)

    if left_types is None or right_types is None:
      return None

    types = self.resolver.res_binop(
        self.namespace, self.types_in.types, node, left_types, right_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_UnaryOp(self, node):
    opnd_types = self.visit(node.operand)

    if opnd_types is None:
      return None

    types = self.resolver.res_unop(
        self.namespace, self.types_in.types, node, opnd_types)

    if __debug__:
      self._check_set(types)

    return types
513
514
class Analyzer(cfg.GraphVisitor):
  """Forward CFG pass that carries a _TypeMap from statement to statement."""

  def __init__(self, graph, resolver, namespace, scope, closure_types):
    """Creates a new analyzer.

    Args:
      graph: cfg.Graph
      resolver: Resolver
      namespace: Dict[str, Any]
      scope: activity.Scope
      closure_types: Dict[QN, Set]
    """
    super().__init__(graph)
    self.resolver = resolver
    self.namespace = namespace
    self.scope = scope
    self.closure_types = closure_types

    # Closure types of symbols that the function does not bind itself. These
    # seed the state at the graph entry; None when there are none.
    external_types = {
        qn: qn_types
        for qn, qn_types in closure_types.items()
        if qn not in scope.bound
    }
    self.context_types = None
    if external_types:
      seed = _TypeMap()
      seed.types = external_types
      self.context_types = seed

  def init_state(self, _):
    # Every node starts out with no type information.
    return _TypeMap()

  def _update_closure_types(self, ast_node, types):
    """Merges the given types into the CLOSURE_TYPES annotation of ast_node."""
    recorded = anno.Static.CLOSURE_TYPES.of(ast_node, None)

    if recorded is None:
      # First update for this node: create and attach the annotation.
      recorded = {}
      anno.Static.CLOSURE_TYPES.add_to(ast_node, recorded)

    for symbol, symbol_types in types.types.items():
      recorded.setdefault(symbol, set()).update(symbol_types)

  def visit_node(self, node):
    """Recomputes the in/out state of a node; returns whether the out changed."""
    previous_out = self.out[node]

    # The input state is the union of the output of all predecessors, plus
    # the closure-provided types when at the entry node.
    incoming = _TypeMap()
    for pred in node.prev:
      incoming = incoming | self.out[pred]
    if (self.context_types is not None) and (node is self.graph.entry):
      incoming = incoming | self.context_types

    outgoing = _TypeMap(incoming)
    stmt = node.ast_node

    stmt_inferrer = StmtInferrer(self.resolver, self.scope, self.namespace,
                                 self.closure_types, incoming)
    stmt_inferrer.visit(stmt)
    # Symbols written by the statement replace whatever was known before.
    outgoing.types.update(stmt_inferrer.new_symbols)

    fndefs_in = anno.Static.DEFINED_FNS_IN.of(stmt)
    stmt_scope = anno.Static.SCOPE.of(stmt, None)
    if stmt_scope is not None:
      # TODO(mdan): Check that it's actually safe to skip nodes without scope.
      names_read = {str(qn) for qn in stmt_scope.read}
      # Reaching function definitions whose name is read by this statement
      # get the current types recorded as their closure types.
      for fndef in fndefs_in:
        if fndef.name in names_read:
          self._update_closure_types(fndef, outgoing)

    self.in_[node] = incoming
    self.out[node] = outgoing

    return previous_out != outgoing
589
590
class FunctionVisitor(transformer.Base):
  """AST visitor that runs the type analysis on every function it encounters."""

  def __init__(self, source_info, graphs, resolver):
    """Creates the visitor.

    Args:
      source_info: passed on to transformer.Base
      graphs: mapping indexed by function definition node, yielding the graph
        to analyze for that function
      resolver: Resolver handed over to each Analyzer
    """
    super().__init__(source_info)
    self.graphs = graphs
    self.resolver = resolver

  def visit_FunctionDef(self, node):
    fn_graph = self.graphs[node]
    fn_scope = anno.getanno(node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
    # Functions without a closure types annotation start with no closure types.
    inherited_types = anno.getanno(node, anno.Static.CLOSURE_TYPES, {})

    Analyzer(fn_graph, self.resolver, self.ctx.info.namespace, fn_scope,
             inherited_types).visit_forward()

    # Recursively process any remaining subfunctions.
    node.body = self.visit_block(node.body)

    return node
612
613
def resolve(node, source_info, graphs, resolver):
  """Runs type inference over the functions found in an AST.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    resolver: Resolver

  Returns:
    ast.AST, the result of visiting node with a FunctionVisitor.
  """
  return FunctionVisitor(source_info, graphs, resolver).visit(node)
629