1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converter construction support.
16
17This module contains a base class for all converters, as well as supporting
18structures. These structures are referred to as contexts.
19
20The class hierarchy is as follows:
21
22    <your converter>
23      [extends] converter.Base
24        [extends] transformer.Base
            [extends] gast.NodeTransformer
          [uses] transformer.SourceInfo
        [uses] converter.EntityContext
          [uses] converter.ProgramContext
          [uses] transformer.SourceInfo
30
31converter.Base is a specialization of transformer.Base for AutoGraph. It's a
32very lightweight subclass that adds a `ctx` attribute holding the corresponding
33EntityContext object (see below). Note that converters are not reusable, and
34`visit` will raise an error if called more than once.
35
36converter.EntityContext contains mutable state associated with an entity that
37the converter processes.
38
39converter.ProgramContext contains mutable state across related entities. For
40example, when converting several functions that call one another, the
41ProgramContext should be shared across these entities.
42
43Below is the overall flow at conversion:
44
45    program_ctx = ProgramContext(<entities to convert>, <global settings>, ...)
46    while <program_ctx has more entities to convert>:
47      entity, source_info = <get next entity from program_ctx>
48      entity_ctx = EntityContext(program_ctx, source_info)
49      for <each ConverterClass>:
50        converter = ConverterClass(entity_ctx)
51
52        # May update entity_ctx and program_ctx
53        entity = converter.visit(entity)
54
55      <add entity's dependencies to program_ctx>
56
57Note that pyct contains a small number of transformers used for static analysis.
58These implement transformer.Base, rather than converter.Base, to avoid a
59dependency on AutoGraph.
60"""
61
62from __future__ import absolute_import
63from __future__ import division
64from __future__ import print_function
65
66import enum
67
68from tensorflow.python.autograph.core import config
69from tensorflow.python.autograph.pyct import anno
70from tensorflow.python.autograph.pyct import ast_util
71from tensorflow.python.autograph.pyct import cfg
72from tensorflow.python.autograph.pyct import compiler
73from tensorflow.python.autograph.pyct import parser
74from tensorflow.python.autograph.pyct import qual_names
75from tensorflow.python.autograph.pyct import templates
76from tensorflow.python.autograph.pyct import transformer
77from tensorflow.python.autograph.pyct.static_analysis import activity
78from tensorflow.python.autograph.pyct.static_analysis import live_values
79from tensorflow.python.autograph.pyct.static_analysis import liveness
80from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
81from tensorflow.python.autograph.pyct.static_analysis import type_info
82from tensorflow.python.util.tf_export import tf_export
83
84# TODO(mdan): These contexts can be refactored into first class objects.
85# For example, we could define Program and Entity abstractions that hold on
86# to the actual entity and have conversion methods.
87
88# TODO(mdan): Add a test specific to this converter.
89
90
@tf_export('autograph.experimental.Feature')
class Feature(enum.Enum):
  """Represents conversion options that can be toggled on or off.

  Attributes:
    ALL: Enable all features.
    AUTO_CONTROL_DEPS: Insertion of control dependencies in the generated code.
    ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
    BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to
      their TF counterparts.
    ERROR_REWRITING: Rewrite errors that occur in the generated code to
      indicate the source code to which the failing code corresponds.
    LISTS: Convert list idioms, like initializers, slices, append, etc.
    LOGICAL_EXPRESSIONS: Convert data-dependent logical expressions applied to
      Tensors to their TF counterparts.
    NAME_SCOPES: Insert name scopes that name ops according to context, like the
      function they were defined in.
  """

  ALL = 'ALL'

  AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
  ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
  BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
  ERROR_REWRITING = 'ERROR_REWRITING'
  LISTS = 'LISTS'
  LOGICAL_EXPRESSIONS = 'LOGICAL_EXPRESSIONS'
  NAME_SCOPES = 'NAME_SCOPES'

  @classmethod
  def all(cls):
    """Returns a tuple that enables all options."""
    return tuple(cls.__members__.values())

  @classmethod
  def all_but(cls, exclude):
    """Returns a tuple that enables all but the excluded options.

    Args:
      exclude: Union[Feature, List[Feature], Tuple[Feature, ...],
        Set[Feature], FrozenSet[Feature]], the option or options to leave
        out. Any other value is treated as a single option.

    Returns:
      Tuple[Feature, ...], the remaining options in declaration order. ALL is
      never part of the result, since it would re-enable the excluded options.
    """
    # frozenset is accepted alongside the mutable collections; wrapping it as
    # a single element would silently exclude nothing.
    if not isinstance(exclude, (list, tuple, set, frozenset)):
      exclude = (exclude,)
    excluded = set(exclude)
    excluded.add(cls.ALL)
    # Filtering the declaration-ordered tuple (rather than subtracting sets)
    # keeps the result order deterministic across runs.
    return tuple(feature for feature in cls.all() if feature not in excluded)
131
132
class ConversionOptions(object):
  """Immutable container for global conversion flags.

  Attributes:
    recursive: bool, whether to recursively convert any user functions or
      classes that the converted function may use.
    force_conversion: bool, whether to force converting the target entity. When
      force_conversion is turned off, the converter may decide to return the
      function as-is.
    internal_convert_user_code: bool, stored as given and forwarded by
      `to_ast`, which may override it.
      NOTE(review): its exact meaning is defined by the code consuming these
      options - confirm there.
    optional_features: FrozenSet[Feature], the optional features used in the
      conversion process. The constructor accepts None, a single Feature or
      any iterable of Feature. See Feature for available options.
  """

  def __init__(self,
               recursive=False,
               force_conversion=False,
               internal_convert_user_code=True,
               optional_features=Feature.ALL):
    self.recursive = recursive
    self.force_conversion = force_conversion
    # TODO(mdan): Rename to conversion_recursion_depth?
    self.internal_convert_user_code = internal_convert_user_code

    # Normalize None / a single Feature / any iterable into a frozenset.
    if optional_features is None:
      optional_features = ()
    elif isinstance(optional_features, Feature):
      optional_features = (optional_features,)
    self.optional_features = frozenset(optional_features)

  def uses(self, feature):
    """Returns True if `feature` is enabled, directly or through Feature.ALL."""
    return (Feature.ALL in self.optional_features or
            feature in self.optional_features)

  def to_ast(self, internal_convert_user_code=None):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents.

    Args:
      internal_convert_user_code: Optional[bool], allows overriding the
        corresponding value. When None, the value stored on this object is
        used.

    Returns:
      ast.Node
    """
    template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def list_of_features(values):
      # Sorting makes the generated code independent of frozenset iteration
      # order. The comma after every element keeps a single feature a real
      # tuple instead of a parenthesized expression; no elements gives `()`.
      rendered = sorted(str(v) for v in values)
      return parser.parse_expression('({})'.format(''.join(
          'ag__.{}, '.format(name) for name in rendered)))

    if internal_convert_user_code is None:
      internal_convert_user_code = self.internal_convert_user_code

    expr_ast = templates.replace(
        template,
        recursive_val=parser.parse_expression(str(self.recursive)),
        force_conversion_val=parser.parse_expression(
            str(self.force_conversion)),
        internal_convert_user_code_val=parser.parse_expression(
            str(internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    # The template expands to a single expression statement; return the
    # expression it wraps.
    return expr_ast[0].value
205
206
class ProgramContext(object):
  """Shared state used while converting a hierarchy of functions.

  Instances are mutable and get updated during conversion. Not thread safe.

  Attributes:
    options: ConversionOptions
    autograph_module: Module, a reference to the autograph module. This needs to
      be specified by the caller to avoid circular dependencies.
    required_imports: str, containing an import statement on each line. These
      are all the imports necessary for the compiled code to run, in addition to
      the closures of each entity, which are attached dynamically.
  """

  def __init__(self, options, autograph_module):
    self.autograph_module = autograph_module
    self.options = options

  @property
  def required_imports(self):
    """Returns the imports needed by the converted code, one per line."""
    # TODO(mdan): Check that these don't clobber one another.
    import_statements = config.COMPILED_IMPORT_STATEMENTS
    return '\n'.join(import_statements)
234
235
class EntityContext(transformer.Context):
  """Per-entity conversion state.

  Instances are mutable and get updated during conversion. Not thread safe.

  Attributes:
    namer: Namer
    info: transformer.EntityInfo
    program: ProgramContext
  """

  def __init__(self, namer, entity_info, program_ctx):
    # The entity info is handed to the transformer.Context initializer.
    super(EntityContext, self).__init__(entity_info)
    self.program = program_ctx
    self.namer = namer
251
252
class Base(transformer.Base):
  """Common parent class for all converters.

  A converter object is single-use: a second top-level call to `visit` raises
  ValueError.

  Attributes:
    ctx: EntityContext
  """

  def __init__(self, ctx):
    super(Base, self).__init__(ctx)

    # Number of `visit` calls currently in progress; zero means the next call
    # to `visit` is a top-level one.
    self._ast_depth = 0
    # Becomes True once the first top-level `visit` has started.
    self._used = False

  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive argument for a symbol.

    See lang/directives.py for details on directives.

    Example:
       # Given a directive in the code:
       ag.foo_directive(bar, baz=1)

       # One can write for an AST node Name(id='bar'):
       get_definition_directive(node, ag.foo_directive, 'baz', default=None)

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, the value returned when no definition of the symbol
        carries the requested directive argument.

    Returns:
      The directive argument value found on the symbol's definitions, or
      `default` if there is none.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    reaching_defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not reaching_defs:
      return default

    # Collect the argument value from every definition that carries it.
    candidates = [
        definition.directives[directive][arg]
        for definition in reaching_defs
        if directive in definition.directives and
        arg in definition.directives[directive]
    ]
    if not candidates:
      return default

    # If multiple annotations reach the symbol, they must all match. If they do,
    # return any of them. A single candidate trivially satisfies this.
    reference = candidates[0]
    for candidate in candidates[1:]:
      if ast_util.matches(reference, candidate):
        continue
      qn = anno.getanno(node, anno.Basic.QN)
      raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
                       (qn, directive.__name__, arg,
                        compiler.ast_to_source(candidate).strip(),
                        compiler.ast_to_source(reference).strip()))
    return reference

  def visit(self, node):
    """Visits `node`, rejecting any second top-level call on this object.

    Args:
      node: the node to visit; it is passed to the parent class `visit`.

    Returns:
      Whatever the parent class `visit` returns for `node`.

    Raises:
      ValueError: if this converter has already been used for a top-level
        visit.
    """
    is_top_level = self._ast_depth == 0
    if is_top_level and self._used:
      raise ValueError('converter objects cannot be reused')
    if is_top_level:
      self._used = True

    self._ast_depth += 1
    try:
      return super(Base, self).visit(node)
    finally:
      # Restore the depth even if the visit raised.
      self._ast_depth -= 1
326
327
class AnnotatedDef(reaching_definitions.Definition):
  """A reaching definition that additionally carries user directives.

  Attributes:
    directives: dict, initially empty. Base.get_definition_directive reads it
      as directive -> {argument name -> value}.
  """

  def __init__(self):
    super(AnnotatedDef, self).__init__()
    self.directives = dict()
333
334
class AgAnno(enum.Enum):
  """AutoGraph-specific annotation labels. See anno.py."""

  DIRECTIVES = 'User directives associated with the annotated statement.'

  def __repr__(self):
    # The bare member name is used as the representation.
    label = self.name
    return label
342
343
def standard_analysis(node, context, is_initial=False):
  """Runs the full set of static analysis passes over the given code.

  The passes run in a fixed order: qual_names, activity, reaching_definitions,
  liveness, live_values, type_info, and live_values once more.

  Args:
    node: ast.AST
    context: converter.EntityContext
    is_initial: bool, whether this is the initial analysis done on the input
      source code

  Returns:
    ast.AST, same as node, with the static analysis annotations added
  """
  # TODO(mdan): Clear static analysis here.
  # TODO(mdan): Consider not running all analyses every time.
  # TODO(mdan): Don't return a node because it's modified by reference.
  # The control flow graphs are built from the node as it was passed in.
  graphs = cfg.build(node)

  node = qual_names.resolve(node)
  node = activity.resolve(node, context, None)
  node = reaching_definitions.resolve(node, context, graphs, AnnotatedDef)
  node = liveness.resolve(node, context, graphs)

  literals = config.PYTHON_LITERALS
  node = live_values.resolve(node, context, literals)
  node = type_info.resolve(node, context)
  # This second call allows resolving first-order class attributes.
  node = live_values.resolve(node, context, literals)

  if not is_initial:
    return node

  # NOTE(review): presumably this keeps a copy of the DEFINITIONS annotations
  # under ORIG_DEFINITIONS - confirm against anno.dup.
  duplicated_keys = {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS}
  anno.dup(node, duplicated_keys)
  return node
376
377
def apply_(node, context, converter_module):
  """Runs the standard analysis on an AST, then applies a converter to it.

  Args:
    node: ast.AST
    context: converter.EntityContext
    converter_module: an object exposing `transform(node, context)`; it is
      called on the analyzed node.

  Returns:
    ast.AST, the result of applying converter to node
  """
  analyzed = standard_analysis(node, context)
  return converter_module.transform(analyzed, context)
392