1# Lint as: python2, python3
2# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Upgrader for Python scripts according to an API change specification."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import ast
23import collections
24import os
25import re
26import shutil
27import sys
28import tempfile
29import traceback
30
31import pasta
32import six
33from six.moves import range
34
# Regular expressions used for parsing.
# NOTE(review): neither pattern is referenced in the visible part of this
# file; presumably they are used further down -- confirm before removing.
# Matches a line whose first non-whitespace character is "[" (captured).
FIND_OPEN = re.compile(r"^\s*(\[).*$")
# Matches a single or a double quote character.
FIND_STRING_CHARS = re.compile(r"['\"]")


# Severity labels; they are stored as the first element of every log tuple.
INFO = "INFO"
WARNING = "WARNING"
ERROR = "ERROR"


# Describes one import rename: `new_name` replaces the matched import prefix,
# and modules starting with any of `excluded_prefixes` are left untouched.
ImportRename = collections.namedtuple(
    "ImportRename", ["new_name", "excluded_prefixes"])
47
48
def full_name_node(name, ctx=ast.Load()):
  """Build the AST node for a (possibly dotted) name.

  A bare name such as "foo" becomes a single `ast.Name`. A qualified name such
  as "a.b.c" becomes `ast.Attribute` nodes nested around the `ast.Name` of the
  leftmost component.

  Args:
    name: The name to translate to a node, e.g. "tf.foo.bar".
    ctx: The expression context to set on the outermost node. Defaults to a
      Load context.

  Returns:
    The outermost Name or Attribute node.
  """
  components = six.ensure_str(name).split(".")

  # The leftmost component is the only real Name; every component to its right
  # is an attribute access on whatever has been built so far.
  node = ast.Name(id=components[0], ctx=ast.Load())
  for attr in components[1:]:
    node = ast.Attribute(value=node, attr=attr, ctx=ast.Load())

  # Inner nodes are always loaded; only the outermost one gets the given ctx.
  node.ctx = ctx
  return node
70
71
def get_arg_value(node, arg_name, arg_pos=None):
  """Look up the value node of one argument of an ast.Call node.

  Keyword arguments are searched first (by `arg_name`), then positional
  arguments (by `arg_pos`).

  The contents of *args or **kwargs cannot be inspected. On Python 3.5+,
  starred positional entries are skipped and do not consume a position index,
  so `arg_pos` counts only the explicitly written positional arguments.

  Args:
    node: The ast.Call node to extract arg values from.
    arg_name: The name of the argument to extract, or None to skip the keyword
      lookup.
    arg_pos: The position of the argument (in case it's passed as a positional
      argument), or None to skip the positional lookup.

  Returns:
    A tuple (arg_present, arg_value): a boolean telling whether the argument
    was found, and the node holding its value (None when it was not found).
  """
  # Keyword arguments take precedence over positional ones.
  if arg_name is not None:
    for keyword in node.keywords:
      if keyword.arg == arg_name:
        return (True, keyword.value)

  if arg_pos is not None:
    # ast.Starred only appears inside Call.args on Python 3.5+.
    has_starred_nodes = sys.version_info[:2] >= (3, 5)
    explicit_args = (
        arg for arg in node.args
        if not (has_starred_nodes and isinstance(arg, ast.Starred)))
    for position, arg in enumerate(explicit_args):
      if position == arg_pos:
        return (True, arg)

  return (False, None)
109
110
def uses_star_args_in_call(node):
  """Tell whether an ast.Call node passes variable-length positional *args.

  Handles both the Python 3.5+ call node layout (starred entries inside
  `node.args`) and the older layout (a separate `node.starargs` field).

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the call uses starred variadic positional args, False otherwise.
  """
  if sys.version_info[:2] >= (3, 5):
    # Python 3.5+: *args shows up as an ast.Starred entry among the args.
    return any(isinstance(arg, ast.Starred) for arg in node.args)
  # Older versions keep *args in a dedicated field.
  return bool(node.starargs)
133
134
def uses_star_kwargs_in_call(node):
  """Tell whether an ast.Call node passes variable-length **kwargs.

  Handles both the Python 3.5+ call node layout (a keyword whose `arg` is
  None) and the older layout (a separate `node.kwargs` field).

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the call uses starred variadic keyword args, False otherwise.
  """
  if sys.version_info[:2] >= (3, 5):
    # Python 3.5+: **kwargs shows up as a keyword entry without a name.
    return any(keyword.arg is None for keyword in node.keywords)
  # Older versions keep **kwargs in a dedicated field.
  return bool(node.kwargs)
157
158
def uses_star_args_or_kwargs_in_call(node):
  """Tell whether an ast.Call node passes either *args or **kwargs.

  Works with the AST call node format of Python 3.5+ as well as the different
  format of earlier versions, by delegating to the two dedicated checks.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the node uses starred variadic positional args or keyword args.
    False if it does not.
  """
  if uses_star_args_in_call(node):
    return True
  return uses_star_kwargs_in_call(node)
173
174
def excluded_from_module_rename(module, import_rename_spec):
  """Tell whether an import of `module` must be left un-renamed.

  Args:
    module: (string) module name.
    import_rename_spec: ImportRename instance.

  Returns:
    True if `module` starts with one of the spec's `excluded_prefixes`, i.e.
    this import should not be renamed according to the import_rename_spec.
  """
  return any(
      module.startswith(prefix)
      for prefix in import_rename_spec.excluded_prefixes)
190
191
class APIChangeSpec(object):
  """Base class describing the transformations an upgrade has to apply.

  Subclasses must provide the following fields:

  * `function_keyword_renames`: maps function names to a map of old -> new
    argument names
  * `symbol_renames`: maps function names to new function names
  * `change_to_function`: a set of function names that have changed (for
    notifications)
  * `function_reorders`: maps functions whose argument order has changed to the
    list of arguments in the new order
  * `function_warnings`: maps full names of functions to warnings that will be
    printed out if the function is used. (e.g. tf.nn.convolution())
  * `function_transformers`: maps function names to custom handlers
  * `module_deprecations`: maps module names to warnings that will be printed
    if the module is still used after all other transformations have run
  * `import_renames`: maps import name (must be a short name without '.')
    to ImportRename instance.

  The edit visitor also reads a few optional fields with `getattr`, using a
  default when they are missing:

  * `function_arg_warnings`: maps function names to a map of
    `(keyword name, position)` -> `(level, warning)` (default: empty)
  * `max_submodule_depth`: how many leading import components are tried when
    looking up `import_renames` (default: 1)
  * `inserts_after_imports`: maps `(import name, alias)` to source lines to
    insert after a renamed import (default: empty)

  For an example, see `TFAPIChangeSpec`.
  """

  def preprocess(self, root_node):  # pylint: disable=unused-argument
    """Preprocess a parse tree. Return a preprocessed node, logs and errors.

    The base implementation changes nothing: it hands back `root_node` along
    with an empty log list and an empty error list.
    """
    no_logs = []
    no_errors = []
    return root_node, no_logs, no_errors

  def clear_preprocessing(self):
    """Restore this APIChangeSpec to before it preprocessed a file.

    This is needed if preprocessing a file changed any rewriting rules. The
    base implementation has nothing to undo.
    """
    return None
225
226
class NoUpdateSpec(APIChangeSpec):
  """An API change specification in which every rule table is empty."""

  def __init__(self):
    # Symbol- and module-level rules.
    self.symbol_renames = {}
    self.import_renames = {}
    self.module_deprecations = {}
    self.change_to_function = {}

    # Call-level rules.
    self.function_reorders = {}
    self.function_keyword_renames = {}
    self.function_warnings = {}
    self.function_transformers = {}

    # Not read anywhere in the visible part of this file; kept so the set of
    # attributes on instances stays the same.
    self.function_handle = {}
240
241
242class _PastaEditVisitor(ast.NodeVisitor):
243  """AST Visitor that processes function calls.
244
245  Updates function calls from old API version to new API version using a given
246  change spec.
247  """
248
249  def __init__(self, api_change_spec):
250    self._api_change_spec = api_change_spec
251    self._log = []   # Holds 4-tuples: severity, line, col, msg.
252    self._stack = []  # Allow easy access to parents.
253
254  # Overridden to maintain a stack of nodes to allow for parent access
255  def visit(self, node):
256    self._stack.append(node)
257    super(_PastaEditVisitor, self).visit(node)
258    self._stack.pop()
259
260  @property
261  def errors(self):
262    return [log for log in self._log if log[0] == ERROR]
263
264  @property
265  def warnings(self):
266    return [log for log in self._log if log[0] == WARNING]
267
268  @property
269  def warnings_and_errors(self):
270    return [log for log in self._log if log[0] in (WARNING, ERROR)]
271
272  @property
273  def info(self):
274    return [log for log in self._log if log[0] == INFO]
275
276  @property
277  def log(self):
278    return self._log
279
280  def add_log(self, severity, lineno, col, msg):
281    self._log.append((severity, lineno, col, msg))
282    print("%s line %d:%d: %s" % (severity, lineno, col, msg))
283
284  def add_logs(self, logs):
285    """Record a log and print it.
286
287    The log should be a tuple `(severity, lineno, col_offset, msg)`, which will
288    be printed and recorded. It is part of the log available in the `self.log`
289    property.
290
291    Args:
292      logs: The logs to add. Must be a list of tuples
293        `(severity, lineno, col_offset, msg)`.
294    """
295    self._log.extend(logs)
296    for log in logs:
297      print("%s line %d:%d: %s" % log)
298
299  def _get_applicable_entries(self, transformer_field, full_name, name):
300    """Get all list entries indexed by name that apply to full_name or name."""
301    # Transformers are indexed to full name, name, or no name
302    # as a performance optimization.
303    function_transformers = getattr(self._api_change_spec,
304                                    transformer_field, {})
305
306    glob_name = "*." + six.ensure_str(name) if name else None
307    transformers = []
308    if full_name in function_transformers:
309      transformers.append(function_transformers[full_name])
310    if glob_name in function_transformers:
311      transformers.append(function_transformers[glob_name])
312    if "*" in function_transformers:
313      transformers.append(function_transformers["*"])
314    return transformers
315
316  def _get_applicable_dict(self, transformer_field, full_name, name):
317    """Get all dict entries indexed by name that apply to full_name or name."""
318    # Transformers are indexed to full name, name, or no name
319    # as a performance optimization.
320    function_transformers = getattr(self._api_change_spec,
321                                    transformer_field, {})
322
323    glob_name = "*." + six.ensure_str(name) if name else None
324    transformers = function_transformers.get("*", {}).copy()
325    transformers.update(function_transformers.get(glob_name, {}))
326    transformers.update(function_transformers.get(full_name, {}))
327    return transformers
328
329  def _get_full_name(self, node):
330    """Traverse an Attribute node to generate a full name, e.g., "tf.foo.bar".
331
332    This is the inverse of `full_name_node`.
333
334    Args:
335      node: A Node of type Attribute.
336
337    Returns:
338      a '.'-delimited full-name or None if node was not Attribute or Name.
339      i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
340    """
341    curr = node
342    items = []
343    while not isinstance(curr, ast.Name):
344      if not isinstance(curr, ast.Attribute):
345        return None
346      items.append(curr.attr)
347      curr = curr.value
348    items.append(curr.id)
349    return ".".join(reversed(items))
350
351  def _maybe_add_warning(self, node, full_name):
352    """Adds an error to be printed about full_name at node."""
353    function_warnings = self._api_change_spec.function_warnings
354    if full_name in function_warnings:
355      level, message = function_warnings[full_name]
356      message = six.ensure_str(message).replace("<function name>", full_name)
357      self.add_log(level, node.lineno, node.col_offset,
358                   "%s requires manual check. %s" % (full_name, message))
359      return True
360    else:
361      return False
362
363  def _maybe_add_module_deprecation_warning(self, node, full_name, whole_name):
364    """Adds a warning if full_name is a deprecated module."""
365    warnings = self._api_change_spec.module_deprecations
366    if full_name in warnings:
367      level, message = warnings[full_name]
368      message = six.ensure_str(message).replace("<function name>",
369                                                six.ensure_str(whole_name))
370      self.add_log(level, node.lineno, node.col_offset,
371                   "Using member %s in deprecated module %s. %s" % (whole_name,
372                                                                    full_name,
373                                                                    message))
374      return True
375    else:
376      return False
377
378  def _maybe_add_call_warning(self, node, full_name, name):
379    """Print a warning when specific functions are called with selected args.
380
381    The function _print_warning_for_function matches the full name of the called
382    function, e.g., tf.foo.bar(). This function matches the function name that
383    is called, as long as the function is an attribute. For example,
384    `tf.foo.bar()` and `foo.bar()` are matched, but not `bar()`.
385
386    Args:
387      node: ast.Call object
388      full_name: The precomputed full name of the callable, if one exists, None
389        otherwise.
390      name: The precomputed name of the callable, if one exists, None otherwise.
391
392    Returns:
393      Whether an error was recorded.
394    """
395    # Only look for *.-warnings here, the other will be handled by the Attribute
396    # visitor. Also, do not warn for bare functions, only if the call func is
397    # an attribute.
398    warned = False
399    if isinstance(node.func, ast.Attribute):
400      warned = self._maybe_add_warning(node, "*." + six.ensure_str(name))
401
402    # All arg warnings are handled here, since only we have the args
403    arg_warnings = self._get_applicable_dict("function_arg_warnings",
404                                             full_name, name)
405
406    variadic_args = uses_star_args_or_kwargs_in_call(node)
407
408    for (kwarg, arg), (level, warning) in sorted(arg_warnings.items()):
409      present, _ = get_arg_value(node, kwarg, arg) or variadic_args
410      if present:
411        warned = True
412        warning_message = six.ensure_str(warning).replace(
413            "<function name>", six.ensure_str(full_name or name))
414        template = "%s called with %s argument, requires manual check: %s"
415        if variadic_args:
416          template = ("%s called with *args or **kwargs that may include %s, "
417                      "requires manual check: %s")
418        self.add_log(level, node.lineno, node.col_offset,
419                     template % (full_name or name, kwarg, warning_message))
420
421    return warned
422
423  def _maybe_rename(self, parent, node, full_name):
424    """Replace node (Attribute or Name) with a node representing full_name."""
425    new_name = self._api_change_spec.symbol_renames.get(full_name, None)
426    if new_name:
427      self.add_log(INFO, node.lineno, node.col_offset,
428                   "Renamed %r to %r" % (full_name, new_name))
429      new_node = full_name_node(new_name, node.ctx)
430      ast.copy_location(new_node, node)
431      pasta.ast_utils.replace_child(parent, node, new_node)
432      return True
433    else:
434      return False
435
436  def _maybe_change_to_function_call(self, parent, node, full_name):
437    """Wraps node (typically, an Attribute or Expr) in a Call."""
438    if full_name in self._api_change_spec.change_to_function:
439      if not isinstance(parent, ast.Call):
440        # ast.Call's constructor is really picky about how many arguments it
441        # wants, and also, it changed between Py2 and Py3.
442        if six.PY2:
443          new_node = ast.Call(node, [], [], None, None)
444        else:
445          new_node = ast.Call(node, [], [])
446        pasta.ast_utils.replace_child(parent, node, new_node)
447        ast.copy_location(new_node, node)
448        self.add_log(INFO, node.lineno, node.col_offset,
449                     "Changed %r to a function call" % full_name)
450        return True
451    return False
452
453  def _maybe_add_arg_names(self, node, full_name):
454    """Make args into keyword args if function called full_name requires it."""
455    function_reorders = self._api_change_spec.function_reorders
456
457    if full_name in function_reorders:
458      if uses_star_args_in_call(node):
459        self.add_log(WARNING, node.lineno, node.col_offset,
460                     "(Manual check required) upgrading %s may require "
461                     "re-ordering the call arguments, but it was passed "
462                     "variable-length positional *args. The upgrade "
463                     "script cannot handle these automatically." % full_name)
464
465      reordered = function_reorders[full_name]
466      new_keywords = []
467      idx = 0
468      for arg in node.args:
469        if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred):
470          continue  # Can't move Starred to keywords
471        keyword_arg = reordered[idx]
472        keyword = ast.keyword(arg=keyword_arg, value=arg)
473        new_keywords.append(keyword)
474        idx += 1
475
476      if new_keywords:
477        self.add_log(INFO, node.lineno, node.col_offset,
478                     "Added keywords to args of function %r" % full_name)
479        node.args = []
480        node.keywords = new_keywords + (node.keywords or [])
481        return True
482    return False
483
484  def _maybe_modify_args(self, node, full_name, name):
485    """Rename keyword args if the function called full_name requires it."""
486    renamed_keywords = self._get_applicable_dict("function_keyword_renames",
487                                                 full_name, name)
488
489    if not renamed_keywords:
490      return False
491
492    if uses_star_kwargs_in_call(node):
493      self.add_log(WARNING, node.lineno, node.col_offset,
494                   "(Manual check required) upgrading %s may require "
495                   "renaming or removing call arguments, but it was passed "
496                   "variable-length *args or **kwargs. The upgrade "
497                   "script cannot handle these automatically." %
498                   (full_name or name))
499    modified = False
500    new_keywords = []
501    for keyword in node.keywords:
502      argkey = keyword.arg
503      if argkey in renamed_keywords:
504        modified = True
505        if renamed_keywords[argkey] is None:
506          lineno = getattr(keyword, "lineno", node.lineno)
507          col_offset = getattr(keyword, "col_offset", node.col_offset)
508          self.add_log(INFO, lineno, col_offset,
509                       "Removed argument %s for function %s" % (
510                           argkey, full_name or name))
511        else:
512          keyword.arg = renamed_keywords[argkey]
513          lineno = getattr(keyword, "lineno", node.lineno)
514          col_offset = getattr(keyword, "col_offset", node.col_offset)
515          self.add_log(INFO, lineno, col_offset,
516                       "Renamed keyword argument for %s from %s to %s" % (
517                           full_name, argkey, renamed_keywords[argkey]))
518          new_keywords.append(keyword)
519      else:
520        new_keywords.append(keyword)
521
522    if modified:
523      node.keywords = new_keywords
524    return modified
525
526  def visit_Call(self, node):  # pylint: disable=invalid-name
527    """Handle visiting a call node in the AST.
528
529    Args:
530      node: Current Node
531    """
532    assert self._stack[-1] is node
533
534    # Get the name for this call, so we can index stuff with it.
535    full_name = self._get_full_name(node.func)
536    if full_name:
537      name = full_name.split(".")[-1]
538    elif isinstance(node.func, ast.Name):
539      name = node.func.id
540    elif isinstance(node.func, ast.Attribute):
541      name = node.func.attr
542    else:
543      name = None
544
545    # Call standard transformers for this node.
546    # Make sure warnings come first, since args or names triggering warnings
547    # may be removed by the other transformations.
548    self._maybe_add_call_warning(node, full_name, name)
549    # Make all args into kwargs
550    self._maybe_add_arg_names(node, full_name)
551    # Argument name changes or deletions
552    self._maybe_modify_args(node, full_name, name)
553
554    # Call transformers. These have the ability to modify the node, and if they
555    # do, will return the new node they created (or the same node if they just
556    # changed it). The are given the parent, but we will take care of
557    # integrating their changes into the parent if they return a new node.
558    #
559    # These are matched on the old name, since renaming is performed by the
560    # Attribute visitor, which happens later.
561    transformers = self._get_applicable_entries("function_transformers",
562                                                full_name, name)
563
564    parent = self._stack[-2]
565
566    if transformers:
567      if uses_star_args_or_kwargs_in_call(node):
568        self.add_log(WARNING, node.lineno, node.col_offset,
569                     "(Manual check required) upgrading %s may require "
570                     "modifying call arguments, but it was passed "
571                     "variable-length *args or **kwargs. The upgrade "
572                     "script cannot handle these automatically." %
573                     (full_name or name))
574
575    for transformer in transformers:
576      logs = []
577      new_node = transformer(parent, node, full_name, name, logs)
578      self.add_logs(logs)
579      if new_node and new_node is not node:
580        pasta.ast_utils.replace_child(parent, node, new_node)
581        node = new_node
582        self._stack[-1] = node
583
584    self.generic_visit(node)
585
586  def visit_Attribute(self, node):  # pylint: disable=invalid-name
587    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
588    assert self._stack[-1] is node
589
590    full_name = self._get_full_name(node)
591    if full_name:
592      parent = self._stack[-2]
593
594      # Make sure the warning comes first, otherwise the name may have changed
595      self._maybe_add_warning(node, full_name)
596
597      # Once we did a modification, node is invalid and not worth inspecting
598      # further. Also, we only perform modifications for simple nodes, so
599      # There'd be no point in descending further.
600      if self._maybe_rename(parent, node, full_name):
601        return
602      if self._maybe_change_to_function_call(parent, node, full_name):
603        return
604
605      # The isinstance check is enough -- a bare Attribute is never root.
606      i = 2
607      while isinstance(self._stack[-i], ast.Attribute):
608        i += 1
609      whole_name = pasta.dump(self._stack[-(i-1)])
610
611      self._maybe_add_module_deprecation_warning(node, full_name, whole_name)
612
613    self.generic_visit(node)
614
615  def visit_Import(self, node):  # pylint: disable=invalid-name
616    """Handle visiting an import node in the AST.
617
618    Args:
619      node: Current Node
620    """
621    new_aliases = []
622    import_updated = False
623    import_renames = getattr(self._api_change_spec, "import_renames", {})
624    max_submodule_depth = getattr(self._api_change_spec, "max_submodule_depth",
625                                  1)
626    inserts_after_imports = getattr(self._api_change_spec,
627                                    "inserts_after_imports", {})
628
629    # This loop processes imports in the format
630    # import foo as f, bar as b
631    for import_alias in node.names:
632      all_import_components = six.ensure_str(import_alias.name).split(".")
633      # Look for rename, starting with longest import levels.
634      found_update = False
635      for i in reversed(list(range(1, max_submodule_depth + 1))):
636        import_component = all_import_components[0]
637        for j in range(1, min(i, len(all_import_components))):
638          import_component += "." + six.ensure_str(all_import_components[j])
639        import_rename_spec = import_renames.get(import_component, None)
640
641        if not import_rename_spec or excluded_from_module_rename(
642            import_alias.name, import_rename_spec):
643          continue
644
645        new_name = (
646            import_rename_spec.new_name +
647            import_alias.name[len(import_component):])
648
649        # If current import is
650        #   import foo
651        # then new import should preserve imported name:
652        #   import new_foo as foo
653        # This happens when module has just one component.
654        new_asname = import_alias.asname
655        if not new_asname and "." not in import_alias.name:
656          new_asname = import_alias.name
657
658        new_alias = ast.alias(name=new_name, asname=new_asname)
659        new_aliases.append(new_alias)
660        import_updated = True
661        found_update = True
662
663        # Insert any followup lines that should happen after this import.
664        full_import = (import_alias.name, import_alias.asname)
665        insert_offset = 1
666        for line_to_insert in inserts_after_imports.get(full_import, []):
667          assert self._stack[-1] is node
668          parent = self._stack[-2]
669
670          new_line_node = pasta.parse(line_to_insert)
671          ast.copy_location(new_line_node, node)
672          parent.body.insert(
673              parent.body.index(node) + insert_offset, new_line_node)
674          insert_offset += 1
675
676          # Insert a newline after the import if necessary
677          old_suffix = pasta.base.formatting.get(node, "suffix")
678          if old_suffix is None:
679            old_suffix = os.linesep
680          if os.linesep not in old_suffix:
681            pasta.base.formatting.set(node, "suffix",
682                                      six.ensure_str(old_suffix) + os.linesep)
683
684          # Apply indentation to new node.
685          pasta.base.formatting.set(new_line_node, "prefix",
686                                    pasta.base.formatting.get(node, "prefix"))
687          pasta.base.formatting.set(new_line_node, "suffix", os.linesep)
688          self.add_log(
689              INFO, node.lineno, node.col_offset,
690              "Adding `%s` after import of %s" %
691              (new_line_node, import_alias.name))
692        # Find one match, break
693        if found_update:
694          break
695      # No rename is found for all levels
696      if not found_update:
697        new_aliases.append(import_alias)  # no change needed
698
699    # Replace the node if at least one import needs to be updated.
700    if import_updated:
701      assert self._stack[-1] is node
702      parent = self._stack[-2]
703
704      new_node = ast.Import(new_aliases)
705      ast.copy_location(new_node, node)
706      pasta.ast_utils.replace_child(parent, node, new_node)
707      self.add_log(
708          INFO, node.lineno, node.col_offset,
709          "Changed import from %r to %r." %
710          (pasta.dump(node), pasta.dump(new_node)))
711
712    self.generic_visit(node)
713
714  def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
715    """Handle visiting an import-from node in the AST.
716
717    Args:
718      node: Current Node
719    """
720    if not node.module:
721      self.generic_visit(node)
722      return
723
724    from_import = node.module
725
726    # Look for rename based on first component of from-import.
727    # i.e. based on foo in foo.bar.
728    from_import_first_component = six.ensure_str(from_import).split(".")[0]
729    import_renames = getattr(self._api_change_spec, "import_renames", {})
730    import_rename_spec = import_renames.get(from_import_first_component, None)
731    if not import_rename_spec:
732      self.generic_visit(node)
733      return
734
735    # Split module aliases into the ones that require import update
736    # and those that don't. For e.g. if we want to rename "a" to "b"
737    # unless we import "a.c" in the following:
738    # from a import c, d
739    # we want to update import for "d" but not for "c".
740    updated_aliases = []
741    same_aliases = []
742    for import_alias in node.names:
743      full_module_name = "%s.%s" % (from_import, import_alias.name)
744      if excluded_from_module_rename(full_module_name, import_rename_spec):
745        same_aliases.append(import_alias)
746      else:
747        updated_aliases.append(import_alias)
748
749    if not updated_aliases:
750      self.generic_visit(node)
751      return
752
753    assert self._stack[-1] is node
754    parent = self._stack[-2]
755
756    # Replace first component of from-import with new name.
757    new_from_import = (
758        import_rename_spec.new_name +
759        from_import[len(from_import_first_component):])
760    updated_node = ast.ImportFrom(new_from_import, updated_aliases, node.level)
761    ast.copy_location(updated_node, node)
762    pasta.ast_utils.replace_child(parent, node, updated_node)
763
764    # If some imports had to stay the same, add another import for them.
765    additional_import_log = ""
766    if same_aliases:
767      same_node = ast.ImportFrom(from_import, same_aliases, node.level,
768                                 col_offset=node.col_offset, lineno=node.lineno)
769      ast.copy_location(same_node, node)
770      parent.body.insert(parent.body.index(updated_node), same_node)
771      # Apply indentation to new node.
772      pasta.base.formatting.set(
773          same_node, "prefix",
774          pasta.base.formatting.get(updated_node, "prefix"))
775      additional_import_log = " and %r" % pasta.dump(same_node)
776
777    self.add_log(
778        INFO, node.lineno, node.col_offset,
779        "Changed import from %r to %r%s." %
780        (pasta.dump(node),
781         pasta.dump(updated_node),
782         additional_import_log))
783
784    self.generic_visit(node)
785
786
class AnalysisResult(object):
  """This class represents an analysis result and how it should be logged.

  This class must provide the following fields:

  * `log_level`: The log level at which this detection should be logged
  * `log_message`: The message that should be logged for this detection

  `PastaAnalyzeVisitor` records both fields as a log entry whenever a
  detection is made.

  For an example, see `VersionedTFImport`.
  """
797
798
class APIAnalysisSpec(object):
  """This class defines how `AnalysisResult`s should be generated.

  It specifies how to map imports and symbols to `AnalysisResult`s.

  This class must provide the following fields:

  * `symbols_to_detect`: maps function names to `AnalysisResult`s
  * `imports_to_detect`: maps imports represented as (full module name, alias)
    tuples to `AnalysisResult`s

  For an example, see `TFAPIImportAnalysisSpec`.
  """
813
814
class PastaAnalyzeVisitor(_PastaEditVisitor):
  """Read-only AST visitor that reports usages matching an analysis spec.

  It runs before any rewriting so that callers can find out whether a file
  uses symbols or imports that require changing imports or disabling rewriting
  altogether. Every match is collected in `results` and also logged.
  """

  def __init__(self, api_analysis_spec):
    """Creates the visitor.

    Args:
      api_analysis_spec: Object providing `symbols_to_detect` and
        `imports_to_detect` mappings (see `APIAnalysisSpec`).
    """
    # The base visitor is configured with a spec that performs no updates.
    super(PastaAnalyzeVisitor, self).__init__(NoUpdateSpec())
    self._api_analysis_spec = api_analysis_spec
    self._results = []   # Holds AnalysisResult objects

  @property
  def results(self):
    """The list of `AnalysisResult`s recorded so far."""
    return self._results

  def add_result(self, analysis_result):
    """Appends `analysis_result` to the recorded results."""
    self._results.append(analysis_result)

  def _record_detection(self, detection, node):
    """Stores and logs `detection` for `node`; does nothing if it is falsy."""
    if not detection:
      return
    self.add_result(detection)
    self.add_log(
        detection.log_level, node.lineno, node.col_offset,
        detection.log_message)

  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
    full_name = self._get_full_name(node)
    if full_name:
      self._record_detection(
          self._api_analysis_spec.symbols_to_detect.get(full_name, None),
          node)

    self.generic_visit(node)

  def visit_Import(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import node in the AST.

    Args:
      node: Current Node
    """
    for alias in node.names:
      # The lookup key is the imported module name paired with its alias.
      lookup_key = (alias.name, alias.asname)
      self._record_detection(
          self._api_analysis_spec.imports_to_detect.get(lookup_key, None),
          node)

    self.generic_visit(node)

  def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import-from node in the AST.

    Args:
      node: Current Node
    """
    module_name = node.module
    # Imports without a module name (e.g. `from . import x`) cannot be
    # matched against the spec, so only the children are visited for them.
    if module_name:
      for alias in node.names:
        # The lookup key is "<module>.<name>" paired with the alias.
        lookup_key = ("%s.%s" % (module_name, alias.name), alias.asname)
        self._record_detection(
            self._api_analysis_spec.imports_to_detect.get(lookup_key, None),
            node)

    self.generic_visit(node)
891
892
class ASTCodeUpgrader(object):
  """Handles upgrading a set of Python files using a given API change spec."""

  def __init__(self, api_change_spec):
    """Creates an upgrader driven by `api_change_spec`.

    Args:
      api_change_spec: The `APIChangeSpec` describing the changes to apply.

    Raises:
      TypeError: If `api_change_spec` is not an `APIChangeSpec`.
    """
    if not isinstance(api_change_spec, APIChangeSpec):
      raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
                      type(api_change_spec))
    self._api_change_spec = api_change_spec

  def process_file(self,
                   in_filename,
                   out_filename,
                   no_change_to_outfile_on_error=False):
    """Process the given python file for incompatible changes.

    The result is written to a temporary file first, so `out_filename` may be
    the same path as `in_filename`. The temporary file is removed if anything
    goes wrong before it has been moved into place.

    Args:
      in_filename: filename to parse
      out_filename: output file to write to
      no_change_to_outfile_on_error: if true, do not touch the output file
        when the input could not be processed
    Returns:
      A tuple representing number of files processed, log of actions, errors
    """
    temp_path = None
    try:
      # Write to a temporary file, just in case we are doing an in-place
      # modify.
      # pylint: disable=g-backslash-continuation
      with open(in_filename, "r") as in_file, \
          tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
        temp_path = temp_file.name
        ret = self.process_opened_file(in_filename, in_file, out_filename,
                                       temp_file)
      # pylint: enable=g-backslash-continuation

      if no_change_to_outfile_on_error and ret[0] == 0:
        os.remove(temp_path)
      else:
        shutil.move(temp_path, out_filename)
      temp_path = None  # The temporary file has been consumed.
    finally:
      # Only reached with a path when an exception interrupted the steps
      # above; the file was created with delete=False, so clean it up here.
      if temp_path is not None and os.path.exists(temp_path):
        os.remove(temp_path)
    return ret

  def format_log(self, log, in_filename):
    """Formats one log entry as `[<file>:]<line>:<col>: <level>: <message>`.

    Args:
      log: Indexable entry; items 1 and 2 are formatted as integers (line and
        column), item 0 is the level and item 3 the message.
      in_filename: Optional filename used as a prefix; omitted when falsy.

    Returns:
      The formatted string.
    """
    log_string = "%d:%d: %s: %s" % (log[1], log[2], log[0], log[3])
    if in_filename:
      return six.ensure_str(in_filename) + ":" + log_string
    return log_string

  def update_string_pasta(self, text, in_filename):
    """Updates a file using pasta.

    Args:
      text: Source code to upgrade.
      in_filename: Name used to prefix the formatted errors.

    Returns:
      A tuple `(processed, new_text, logs, errors)`. `processed` is 1 on
      success and 0 if `text` could not be parsed, in which case `new_text` is
      empty and `logs` holds a single entry with the traceback.
    """
    try:
      t = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
      log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
      return 0, "", log, []

    t, preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)

    visitor = _PastaEditVisitor(self._api_change_spec)
    visitor.visit(t)

    self._api_change_spec.clear_preprocessing()

    # Logs are formatted without a filename; errors are prefixed with it.
    logs = [self.format_log(log, None) for log in (preprocess_logs +
                                                   visitor.log)]
    errors = [self.format_log(error, in_filename)
              for error in (preprocess_errors +
                            visitor.warnings_and_errors)]
    return 1, pasta.dump(t), logs, errors

  def _format_log(self, log, in_filename, out_filename):
    """Builds the per-file report section from the list of log strings."""
    separator = six.ensure_str("-" * 80)
    text = separator + "\n"
    text += "Processing file %r\n outputting to %r\n" % (in_filename,
                                                         out_filename)
    text += separator + "\n\n"
    text += "\n".join(log) + "\n"
    text += separator + "\n\n"
    return text

  def process_opened_file(self, in_filename, in_file, out_filename, out_file):
    """Process the given python file for incompatible changes.

    This function is split out to facilitate StringIO testing from
    tf_upgrade_test.py.

    Args:
      in_filename: filename to parse
      in_file: opened file (or StringIO)
      out_filename: output file to write to
      out_file: opened file (or StringIO)
    Returns:
      A tuple representing number of files processed, log of actions, errors
    """
    processed_file, new_file_content, log, process_errors = (
        self.update_string_pasta(in_file.read(), in_filename))

    # Nothing is written when processing failed or no output was requested.
    if out_file and processed_file:
      out_file.write(new_file_content)

    return (processed_file,
            self._format_log(log, in_filename, out_filename),
            process_errors)

  def _tree_report_header(self, root_directory):
    """Returns the banner that starts the report for a directory tree."""
    rule = six.ensure_str("=" * 80) + "\n"
    return rule + ("Input tree: %r\n" % root_directory) + rule

  def _ensure_parent_directory(self, path):
    """Creates the directory that will contain `path` if it is missing."""
    parent = os.path.dirname(path)
    if not os.path.isdir(parent):
      os.makedirs(parent)

  def process_tree(self, root_directory, output_root_directory,
                   copy_other_files):
    """Processes upgrades on an entire tree of python files.

    If `output_root_directory` equals `root_directory` the tree is upgraded in
    place. Otherwise the upgraded files are written to a new tree rooted at
    `output_root_directory`, which must not exist yet; the process exits with
    status 1 if it does, or if it normalizes to the same path as the input.

    Note that only Python files (`*.py`) are upgraded. If you have custom code
    in other languages, you will need to manually upgrade those.

    Args:
      root_directory: Directory to walk and process.
      output_root_directory: Directory to use as base.
      copy_other_files: Copy files that are not touched by this converter.

    Returns:
      A tuple of files processed, the report string for all files, and a dict
        mapping filenames to errors encountered in that file.
    """

    if output_root_directory == root_directory:
      return self.process_tree_inplace(root_directory)

    # make sure output directory doesn't exist
    if output_root_directory and os.path.exists(output_root_directory):
      print("Output directory %r must not already exist." %
            (output_root_directory))
      sys.exit(1)

    # make sure output directory does not overlap with root_directory
    norm_root = os.path.split(os.path.normpath(root_directory))
    norm_output = os.path.split(os.path.normpath(output_root_directory))
    if norm_root == norm_output:
      print("Output directory %r same as input directory %r" %
            (output_root_directory, root_directory))
      sys.exit(1)

    def output_path_for(path):
      """Maps a path under the input root to its place in the output tree."""
      return os.path.join(output_root_directory,
                          os.path.relpath(path, root_directory))

    # Collect list of files to process (we do this to correctly handle if the
    # user puts the output directory in some sub directory of the input dir)
    files_to_process = []
    files_to_copy = []
    for dir_name, _, file_list in os.walk(root_directory):
      for filename in file_list:
        fullpath = os.path.join(dir_name, filename)
        if six.ensure_str(filename).endswith(".py"):
          files_to_process.append((fullpath, output_path_for(fullpath)))
        elif copy_other_files:
          files_to_copy.append((fullpath, output_path_for(fullpath)))

    # Set copy of the work list so each symlink lookup below is O(1).
    pairs_to_process = set(files_to_process)

    file_count = 0
    tree_errors = {}
    report = self._tree_report_header(root_directory)

    for input_path, output_path in files_to_process:
      self._ensure_parent_directory(output_path)

      if os.path.islink(input_path):
        # NOTE(review): os.readlink may return a path relative to the link's
        # own directory, while relpath resolves it against the current
        # working directory - confirm relative links are handled as intended.
        link_target = os.readlink(input_path)
        link_target_output = output_path_for(link_target)
        if (link_target, link_target_output) in pairs_to_process:
          # Create a link to the new location of the target file
          os.symlink(link_target_output, output_path)
        else:
          report += "Copying symlink %s without modifying its target %s\n" % (
              input_path, link_target)
          os.symlink(link_target, output_path)
        continue

      file_count += 1
      _, l_report, l_errors = self.process_file(input_path, output_path)
      tree_errors[input_path] = l_errors
      report += l_report

    for input_path, output_path in files_to_copy:
      self._ensure_parent_directory(output_path)
      shutil.copy(input_path, output_path)
    return file_count, report, tree_errors

  def process_tree_inplace(self, root_directory):
    """Process a directory of python files in place.

    Symlinked files are skipped and mentioned in the report. A file that
    cannot be processed is left unchanged on disk instead of being replaced
    by empty output.

    Args:
      root_directory: Directory to walk and process.

    Returns:
      A tuple of files processed, the report string for all files, and a dict
        mapping filenames to errors encountered in that file.
    """
    files_to_process = []
    for dir_name, _, file_list in os.walk(root_directory):
      files_to_process.extend(
          os.path.join(dir_name, f)
          for f in file_list
          if six.ensure_str(f).endswith(".py"))

    file_count = 0
    tree_errors = {}
    report = self._tree_report_header(root_directory)

    for path in files_to_process:
      if os.path.islink(path):
        report += "Skipping symlink %s.\n" % path
        continue
      file_count += 1
      # Input and output are the same file here, so never overwrite it with
      # the (empty) result of a failed run.
      _, l_report, l_errors = self.process_file(
          path, path, no_change_to_outfile_on_error=True)
      tree_errors[path] = l_errors
      report += l_report

    return file_count, report, tree_errors
1115