1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AST node annotation support.
16
17Adapted from Tangent.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import enum
25
26# pylint:disable=g-bad-import-order
27import gast
28# pylint:enable=g-bad-import-order
29
30
31# TODO(mdan): Shorten the names.
# These names are heavily used, so fully qualified references like anno.<name> become verbose.
33# TODO(mdan): Replace the attr-dict mechanism with a more typed solution.
34
35
class NoValue(enum.Enum):
  """Enum base class whose members repr as their bare name.

  The default Enum repr is `<Cls.NAME: value>`. The subclasses in this module
  use long documentation strings as values, so showing only the member name
  keeps reprs of annotated nodes short.
  """

  def __repr__(self):
    member_name = self.name
    return member_name
40
41
class Basic(NoValue):
  """Container for basic annotation keys.

  Members are the keys under which annotations are stored on a node (see
  setanno / getanno); each member's value is its human-readable description.
  The enum values are used strictly for documentation purposes.
  """

  # NOTE: member values must stay distinct. enum.Enum turns a member whose
  # value equals an earlier member's value into an alias of that earlier member,
  # which would silently merge the two annotation keys.
  QN = 'Qualified name, as it appeared in the code. See qual_names.py.'
  SKIP_PROCESSING = (
      'This node should be preserved as is and not processed any further.')
  INDENT_BLOCK_REMAINDER = (
      'When a node is annotated with this, the remainder of the block should'
      ' be indented below it. The annotation contains a tuple'
      ' (new_body, name_map), where `new_body` is the new indented block and'
      ' `name_map` allows renaming symbols.')
  ORIGIN = ('Information about the source code that converted code originated'
            ' from. See origin_information.py.')
58
59
class Static(NoValue):
  """Container for static analysis annotation keys.

  Members are the keys under which static analysis results are stored on a
  node (see setanno / getanno); each member's value is its human-readable
  description.
  The enum values are used strictly for documentation purposes.
  """

  # NOTE: member values must stay distinct. enum.Enum turns a member whose
  # value equals an earlier member's value into an alias of that earlier member,
  # which would silently merge the two annotation keys.

  # Symbols
  # These flags are boolean.
  IS_PARAM = 'Symbol is a parameter to the function being analyzed.'

  # Scopes
  # Scopes are represented by objects of type activity.Scope.
  SCOPE = 'The scope for the annotated node. See activity.py.'
  # TODO(mdan): Drop these in favor of accessing the child's SCOPE.
  ARGS_SCOPE = 'The scope for the argument list of a function call.'
  COND_SCOPE = 'The scope for the test node of a conditional statement.'
  BODY_SCOPE = (
      'The scope for the main body of a statement (True branch for if '
      'statements, main body for loops).')
  ORELSE_SCOPE = (
      'The scope for the orelse body of a statement (False branch for if '
      'statements, orelse body for loops).')

  # Static analysis annotations.
  DEFINITIONS = (
      'Reaching definition information. See reaching_definitions.py.')
  ORIG_DEFINITIONS = (
      'The value of DEFINITIONS that applied to the original code before any'
      ' conversion.')
  DEFINED_VARS_IN = (
      'Symbols defined when entering the node. See reaching_definitions.py.')
  # Liveness results, for the exit and entry points of the node respectively.
  LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.')
  LIVE_VARS_IN = ('Symbols live when entering the node. See liveness.py.')
93
94
# Sentinel default for getanno's `default` argument. When a caller leaves
# `default` unset, a missing annotation raises instead of returning a fallback.
# getanno compares against it by identity (`is`), so no caller-supplied default,
# including None, can be mistaken for it.
FAIL = object()
96
97
def keys(node, field_name='___pyct_anno'):
  """Returns the keys of every annotation attached to `node`.

  Args:
    node: object that may carry annotations, typically an AST node.
    field_name: str, name of the attribute holding the annotation dict.

  Returns:
    frozenset of annotation keys; empty when `node` has no annotation
    attribute.
  """
  if hasattr(node, field_name):
    anno_map = getattr(node, field_name)
    return frozenset(anno_map.keys())
  return frozenset()
102
103
def getanno(node, key, default=FAIL, field_name='___pyct_anno'):
  """Returns the value of annotation `key` on `node`.

  Args:
    node: object carrying annotations, typically an AST node.
    key: Hashable, the annotation key to look up.
    default: value to return when the annotation is missing. When left as the
      FAIL sentinel, a missing annotation is an error instead.
    field_name: str, name of the attribute holding the annotation dict.

  Returns:
    The stored annotation value, or `default` if one was given and the
    annotation is missing.

  Raises:
    AttributeError: if no default was given and `node` has no annotation
      attribute at all.
    KeyError: if no default was given and the annotation dict lacks `key`.
  """
  if default is not FAIL:
    # Only probe for the annotation when there is a fallback to return.
    present = (hasattr(node, field_name) and
               key in getattr(node, field_name))
    if not present:
      return default
  return getattr(node, field_name)[key]
110
111
def hasanno(node, key, field_name='___pyct_anno'):
  """Tells whether `node` carries an annotation under `key`.

  Args:
    node: object that may carry annotations, typically an AST node.
    key: Hashable, the annotation key to test.
    field_name: str, name of the attribute holding the annotation dict.

  Returns:
    bool, False when `node` has no annotation attribute or lacks `key`.
  """
  if not hasattr(node, field_name):
    return False
  return key in getattr(node, field_name)
114
115
def setanno(node, key, value, field_name='___pyct_anno'):
  """Stores `value` under annotation `key` on `node`, replacing any old value.

  The annotation dict is created on first use. `field_name` is also appended
  to `node._fields` (assigned on the instance), once per node.

  Args:
    node: object to annotate; must expose a `_fields` tuple, as AST nodes do.
    key: Hashable, the annotation key.
    value: the annotation value; stored by reference, not copied.
    field_name: str, name of the attribute holding the annotation dict.
  """
  store = getattr(node, field_name, {})
  setattr(node, field_name, store)
  store[key] = value

  # The annotation attribute is listed among the node's fields so that the
  # annotations survive gast_to_ast() and ast_to_gast().
  already_registered = field_name in node._fields
  if not already_registered:
    node._fields += (field_name,)
124
125
def delanno(node, key, field_name='___pyct_anno'):
  """Removes annotation `key` from `node`.

  When the last annotation is removed, the annotation attribute itself is
  deleted and `field_name` is dropped from `node._fields`. This undoes the
  registration that setanno performs.

  Args:
    node: annotated object, typically an AST node.
    key: Hashable, the annotation key to remove.
    field_name: str, name of the attribute holding the annotation dict.

  Raises:
    AttributeError: if `node` has no annotation attribute.
    KeyError: if the annotation dict lacks `key`.
  """
  store = getattr(node, field_name)
  del store[key]
  if store:
    return
  delattr(node, field_name)
  node._fields = tuple([name for name in node._fields if name != field_name])
132
133
def copyanno(from_node, to_node, key, field_name='___pyct_anno'):
  """Copies annotation `key` from `from_node` to `to_node`, if present.

  Does nothing when `from_node` lacks the annotation. The value object is
  shared between the two nodes, not duplicated.

  Args:
    from_node: source node.
    to_node: destination node.
    key: Hashable, the annotation key to copy.
    field_name: str, name of the attribute holding the annotation dict.
  """
  if not hasanno(from_node, key, field_name=field_name):
    return
  value = getanno(from_node, key, field_name=field_name)
  setanno(to_node, key, value, field_name=field_name)
141
142
def dup(node, copy_map, field_name='___pyct_anno'):
  """Recursively copies annotations in an AST tree.

  For every node in the tree rooted at `node`, each annotation stored under a
  source key of `copy_map` is also stored under the matching destination key.
  Values are shared, not deep-copied.

  On each node, all source values are read before any destination is written.
  This makes the result independent of the iteration order of `copy_map`, even
  when a destination key is also a source key (e.g. a chain `{a: b, b: c}` or a
  swap `{a: b, b: a}`).

  Args:
    node: ast.AST
    copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination
        key. All annotations with the source key will be copied to identical
        annotations with the destination key.
    field_name: str
  """
  for n in gast.walk(node):
    # Snapshot first. Writing while reading would let a value copied earlier
    # in this pass be picked up again as a "source" value.
    pending = [(copy_map[src], getanno(n, src, field_name=field_name))
               for src in copy_map
               if hasanno(n, src, field_name=field_name)]
    for dst, value in pending:
      setanno(n, dst, value, field_name=field_name)
157