1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AST node annotation support.
16
17Adapted from Tangent.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import enum
25
26# pylint:disable=g-bad-import-order
27
28import gast
29# pylint:enable=g-bad-import-order
30
31
# TODO(mdan): Shorten the names.
# These names are heavily used, and spelling them out in full
# (anno.<KeyEnum>.<KEY>) is verbose.
# TODO(mdan): Replace the attr-dict mechanism with a more typed solution.
35
36
class NoValue(enum.Enum):
  """Base class for enums whose members serve as AST annotation keys.

  Each member is itself the annotation key; the convenience methods below
  simply forward to the module-level getanno/setanno/hasanno helpers with
  `self` as the key.
  """

  def __repr__(self):
    # Show just the member name (e.g. `QN`) rather than the long
    # documentation string stored as the value.
    return self.name

  def exists(self, node):
    """Returns whether `node` carries an annotation under this key."""
    return hasanno(node, self)

  def add_to(self, node, value):
    """Attaches `value` to `node` under this key."""
    setanno(node, self, value)

  def of(self, node, default=None):
    """Returns the annotation of `node` under this key, or `default`."""
    return getanno(node, self, default)
51
52
class Basic(NoValue):
  """Container for basic annotation keys.

  The enum values are used strictly for documentation purposes: each value is
  a human-readable description of what gets stored under that key. The members
  themselves are the annotation keys (NoValue.of/add_to/exists pass `self` as
  the key).
  """
  # NOTE: keep every value distinct. enum.Enum turns members with equal
  # values into aliases of one another, which would silently merge two keys.

  QN = 'Qualified name, as it appeared in the code. See qual_names.py.'
  SKIP_PROCESSING = (
      'This node should be preserved as is and not processed any further.')
  INDENT_BLOCK_REMAINDER = (
      'When a node is annotated with this, the remainder of the block should'
      ' be indented below it. The annotation contains a tuple'
      ' (new_body, name_map), where `new_body` is the new indented block and'
      ' `name_map` allows renaming symbols.')
  ORIGIN = ('Information about the source code that converted code originated'
            ' from. See origin_information.py.')
  DIRECTIVES = ('User directives associated with a statement or a variable.'
                ' Typically, they affect the immediately-enclosing statement.')

  EXTRA_LOOP_TEST = (
      'A special annotation containing additional test code to be executed in'
      ' for loops.')
75
76
class Static(NoValue):
  """Container for static analysis annotation keys.

  The enum values are used strictly for documentation purposes: each value is
  a human-readable description of what gets stored under that key. The members
  themselves are the annotation keys.
  """
  # NOTE: keep every value distinct. enum.Enum turns members with equal
  # values into aliases of one another, which would silently merge two keys.

  # Symbols
  # These flags are boolean.
  IS_PARAM = 'Symbol is a parameter to the function being analyzed.'

  # Scopes
  # Scopes are represented by objects of type activity.Scope.
  SCOPE = 'The scope for the annotated node. See activity.py.'
  # TODO(mdan): Drop these in favor of accessing the child's SCOPE.
  ARGS_SCOPE = 'The scope for the argument list of a function call.'
  COND_SCOPE = 'The scope for the test node of a conditional statement.'
  BODY_SCOPE = (
      'The scope for the main body of a statement (True branch for if '
      'statements, main body for loops).')
  ORELSE_SCOPE = (
      'The scope for the orelse body of a statement (False branch for if '
      'statements, orelse body for loops).')

  # Static analysis annotations.
  DEFINITIONS = (
      'Reaching definition information. See reaching_definitions.py.')
  ORIG_DEFINITIONS = (
      'The value of DEFINITIONS that applied to the original code before any'
      ' conversion.')
  # NOTE(review): the key name says "IN" (entering) but its description says
  # "when exiting the node"; confirm against reaching_fndefs.py which is meant.
  DEFINED_FNS_IN = (
      'Local function definitions that may exist when exiting the node. See'
      ' reaching_fndefs.py')
  DEFINED_VARS_IN = (
      'Symbols defined when entering the node. See reaching_definitions.py.')
  LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.')
  LIVE_VARS_IN = ('Symbols live when entering the node. See liveness.py.')
  TYPES = 'Static type information. See type_inference.py.'
  CLOSURE_TYPES = 'Types of closure symbols at each detected call site.'
  VALUE = 'Static value information. See type_inference.py.'
116
117
# Sentinel default for getanno meaning "no default supplied". When getanno
# sees it (compared by identity), a missing annotation raises instead of
# returning a fallback. Passing FAIL explicitly is equivalent to omitting it.
FAIL = object()
119
120
def keys(node, field_name='___pyct_anno'):
  """Returns the annotation keys attached to `node`.

  Args:
    node: the annotated object, typically an AST node.
    field_name: str, the attribute under which annotations are stored.

  Returns:
    A frozenset of annotation keys; empty when `node` has no `field_name`
    attribute at all.
  """
  try:
    annotations = getattr(node, field_name)
  except AttributeError:
    return frozenset()
  return frozenset(annotations.keys())
125
126
def getanno(node, key, default=FAIL, field_name='___pyct_anno'):
  """Returns the annotation stored on `node` under `key`.

  Args:
    node: the annotated object, typically an AST node.
    key: the annotation key.
    default: value returned when the annotation is missing. If omitted (i.e.
        left as FAIL), a missing annotation raises instead.
    field_name: str, the attribute under which annotations are stored.

  Raises:
    AttributeError: if no default was given and `node` has no annotations.
    KeyError: if no default was given and `key` is not annotated on `node`.
  """
  if default is not FAIL:
    # A fallback exists, so check for presence before looking it up.
    if not hasattr(node, field_name) or key not in getattr(node, field_name):
      return default
  return getattr(node, field_name)[key]
132
133
def hasanno(node, key, field_name='___pyct_anno'):
  """Returns True if `node` carries an annotation under `key`.

  Args:
    node: the annotated object, typically an AST node.
    key: the annotation key to look for.
    field_name: str, the attribute under which annotations are stored.
  """
  if not hasattr(node, field_name):
    return False
  return key in getattr(node, field_name)
136
137
def setanno(node, key, value, field_name='___pyct_anno'):
  """Stores `value` on `node` under `key`, creating the annotation dict if needed.

  Args:
    node: the node to annotate; must expose a `_fields` tuple (as AST nodes
        do).
    key: the annotation key.
    value: the annotation value.
    field_name: str, the attribute under which annotations are stored.
  """
  store = getattr(node, field_name, {})
  setattr(node, field_name, store)
  store[key] = value

  # Register the annotation attribute as a node field so that the annotations
  # survive gast_to_ast() and ast_to_gast().
  current_fields = node._fields
  if field_name not in current_fields:
    node._fields = current_fields + (field_name,)
146
147
def delanno(node, key, field_name='___pyct_anno'):
  """Removes the annotation stored on `node` under `key`.

  When the last annotation is removed, the annotation attribute itself is
  deleted and `field_name` is dropped from `node._fields`, undoing setanno's
  bookkeeping.

  Args:
    node: the annotated node.
    key: the annotation key to remove.
    field_name: str, the attribute under which annotations are stored.

  Raises:
    AttributeError: if `node` has no annotations.
    KeyError: if `key` is not annotated on `node`.
  """
  store = getattr(node, field_name)
  del store[key]
  if store:
    return
  delattr(node, field_name)
  node._fields = tuple(name for name in node._fields if name != field_name)
154
155
def copyanno(from_node, to_node, key, field_name='___pyct_anno'):
  """Copies the annotation under `key` from `from_node` to `to_node`.

  Does nothing if `from_node` has no annotation under `key`. The value object
  is shared, not duplicated.

  Args:
    from_node: the node to read the annotation from.
    to_node: the node to write the annotation to.
    key: the annotation key.
    field_name: str, the attribute under which annotations are stored.
  """
  if not hasanno(from_node, key, field_name=field_name):
    return
  anno_value = getanno(from_node, key, field_name=field_name)
  setanno(to_node, key, anno_value, field_name=field_name)
163
164
def dup(node, copy_map, field_name='___pyct_anno'):
  """Recursively duplicates annotations under new keys in an AST tree.

  For every node reachable from `node` via gast.walk, each annotation stored
  under a source key of `copy_map` is also stored, on that same node, under
  the corresponding destination key. Nodes lacking a source key are left
  unchanged for that key. The value object is shared, not deep-copied.

  Args:
    node: ast.AST
    copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination
        key. All annotations with the source key will be copied to identical
        annotations with the destination key.
    field_name: str, the attribute under which annotations are stored.
  """
  for n in gast.walk(node):
    for src_key, dst_key in copy_map.items():
      if hasanno(n, src_key, field_name=field_name):
        # field_name must be passed by keyword: getanno's third positional
        # parameter is `default`, so passing it positionally would read the
        # default annotation field instead of `field_name`.
        value = getanno(n, src_key, field_name=field_name)
        setanno(n, dst_key, value, field_name=field_name)
179