1# Lint as: python2, python3
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for tf 2.0 upgrader."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import inspect
23import os
24import tempfile
25
26from absl.testing import parameterized
27import six
28import tensorflow.compat.v1 as tf
29# OSS TF V2 import placeholder.
30
31from tensorflow.python.framework import test_util
32from tensorflow.python.platform import test as test_lib
33from tensorflow.python.util import tf_decorator
34from tensorflow.python.util import tf_export
35from tensorflow.python.util import tf_inspect
36from tensorflow.tools.common import public_api
37from tensorflow.tools.common import traverse
38from tensorflow.tools.compatibility import ast_edits
39from tensorflow.tools.compatibility import tf_upgrade_v2
40
41
def get_symbol_for_name(root, name):
  """Resolves a dotted name such as "tf.foo.bar" against `root`.

  The first dotted component (the "tf" prefix) is taken to denote `root`
  itself, so only the remaining components are looked up with `getattr`.

  Args:
    root: Object that the leading component of `name` refers to.
    name: Dotted name (str or bytes).

  Returns:
    The object reached by following the attribute path from `root`.
  """
  # Drop the leading component: `root` already stands for it.
  attribute_path = six.ensure_str(name).split(".")[1:]
  resolved = root
  for attribute_name in attribute_path:
    resolved = getattr(resolved, attribute_name)
  return resolved
49
50
def get_args(symbol):
  """Returns the names of `symbol`'s positional-or-keyword parameters.

  When `inspect.signature` is available, only parameters whose kind is
  POSITIONAL_OR_KEYWORD are kept. *args, **kwargs, keyword-only and
  positional-only parameters are all left out. Otherwise this falls back to
  the argument list reported by `tf_inspect.getargspec`.

  Args:
    symbol: A callable to inspect.

  Returns:
    List of parameter names, in declaration order.
  """
  if not hasattr(inspect, "signature"):
    return tf_inspect.getargspec(symbol)[0]
  parameters = inspect.signature(symbol).parameters.values()
  # Ignore *args and **kwargs (and every other non positional-or-keyword kind).
  keep = inspect.Parameter.POSITIONAL_OR_KEYWORD
  return [parameter.name for parameter in parameters if parameter.kind == keep]
58
59
def _split_top_level_args(args_str):
  """Splits an argument-list string on commas that are not nested.

  Commas inside (), [] or {} pairs, and inside single- or double-quoted string
  literals, are treated as part of the surrounding argument.

  Args:
    args_str: Text found between the outer parentheses of a call.

  Returns:
    List of raw (unstripped) argument strings, in order.
  """
  pieces = []
  current = []
  depth = 0
  quote = None
  escaped = False
  for char in args_str:
    if quote:
      # Inside a string literal: only watch for its unescaped terminator.
      if escaped:
        escaped = False
      elif char == "\\":
        escaped = True
      elif char == quote:
        quote = None
    elif char in "\"'":
      quote = char
    elif char in "([{":
      depth += 1
    elif char in ")]}":
      # Clamp at zero so stray closers in malformed input cannot disable
      # splitting for the rest of the string.
      depth = max(depth - 1, 0)
    elif char == "," and depth == 0:
      pieces.append("".join(current))
      current = []
      continue
    current.append(char)
  pieces.append("".join(current))
  return pieces


def get_func_and_args_from_str(call_str):
  """Parse call string to get function and argument names.

  Only top-level commas separate arguments. Values that themselves contain
  commas (e.g. `shape=(1, 2)` or `name="a,b"`) therefore do not produce
  spurious argument names.

  Args:
    call_str: Call string must be in the form:
              `tf.foo(arg1=val1, arg2=val2, ...)`. May be str or bytes.

  Returns:
    (function_name, list of arg names) tuple. For each argument, the text
    before its first "=" (stripped) is returned, and empty entries are
    dropped. If `call_str` contains no "(", it is returned unchanged as the
    function name with an empty argument list.
  """
  call_str = six.ensure_str(call_str)
  open_paren_index = call_str.find("(")
  if open_paren_index == -1:
    # Not a call. Slicing with the -1 sentinel would drop the last character.
    return call_str, []
  close_paren_index = call_str.rfind(")")
  if close_paren_index < open_paren_index:
    # Unterminated call: treat everything after "(" as the argument list.
    close_paren_index = len(call_str)

  function_name = call_str[:open_paren_index]
  raw_args = _split_top_level_args(
      call_str[open_paren_index + 1:close_paren_index])
  args = [arg.split("=")[0].strip() for arg in raw_args]
  args = [arg for arg in args if arg]  # filter out empty strings
  return function_name, args
79
80
81class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
82  """Test various APIs that have been changed in 2.0.
83
84  We also test whether a converted file is executable. test_file_v1_10.py
85  aims to exhaustively test that API changes are convertible and actually
86  work when run with current TensorFlow.
87  """
88
89  @classmethod
90  def setUpClass(cls):
91    super(TestUpgrade, cls).setUpClass()
92    cls.v2_symbols = {}
93    cls.v1_symbols = {}
94    if hasattr(tf.compat, "v2"):
95
96      def symbol_collector(unused_path, unused_parent, children):
97        for child in children:
98          _, attr = tf_decorator.unwrap(child[1])
99          api_names_v2 = tf_export.get_v2_names(attr)
100          for name in api_names_v2:
101            cls.v2_symbols["tf." + six.ensure_str(name)] = attr
102
103      visitor = public_api.PublicAPIVisitor(symbol_collector)
104      visitor.private_map["tf.compat"] = ["v1", "v2"]
105      traverse.traverse(tf.compat.v2, visitor)
106
107    if hasattr(tf.compat, "v1"):
108
109      def symbol_collector_v1(unused_path, unused_parent, children):
110        for child in children:
111          _, attr = tf_decorator.unwrap(child[1])
112          api_names_v1 = tf_export.get_v1_names(attr)
113          for name in api_names_v1:
114            cls.v1_symbols["tf." + six.ensure_str(name)] = attr
115
116      visitor = public_api.PublicAPIVisitor(symbol_collector_v1)
117      visitor.private_map["tf.compat"] = ["v1", "v2"]
118      traverse.traverse(tf.compat.v1, visitor)
119
120  def _upgrade(self,
121               old_file_text,
122               import_rename=False,
123               upgrade_compat_v1_import=False):
124    in_file = six.StringIO(old_file_text)
125    out_file = six.StringIO()
126    upgrader = ast_edits.ASTCodeUpgrader(
127        tf_upgrade_v2.TFAPIChangeSpec(
128            import_rename, upgrade_compat_v1_import=upgrade_compat_v1_import))
129    count, report, errors = (
130        upgrader.process_opened_file("test.py", in_file,
131                                     "test_out.py", out_file))
132    return count, report, errors, out_file.getvalue()
133
134  def _upgrade_multiple(self, old_file_texts):
135    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
136    results = []
137    for old_file_text in old_file_texts:
138      in_file = six.StringIO(old_file_text)
139      out_file = six.StringIO()
140      count, report, errors = (
141          upgrader.process_opened_file("test.py", in_file,
142                                       "test_out.py", out_file))
143      results.append([count, report, errors, out_file.getvalue()])
144    return results
145
146  def testParseError(self):
147    _, report, unused_errors, unused_new_text = self._upgrade(
148        "import tensorflow as tf\na + \n")
149    self.assertNotEqual(six.ensure_str(report).find("Failed to parse"), -1)
150
151  def testReport(self):
152    text = "tf.angle(a)\n"
153    _, report, unused_errors, unused_new_text = self._upgrade(text)
154    # This is not a complete test, but it is a sanity test that a report
155    # is generating information.
156    self.assertTrue(
157        six.ensure_str(report).find("Renamed function `tf.angle` to "
158                                    "`tf.math.angle`"))
159
160  def testRename(self):
161    text = "tf.conj(a)\n"
162    _, unused_report, unused_errors, new_text = self._upgrade(text)
163    self.assertEqual(new_text, "tf.math.conj(a)\n")
164    text = "tf.rsqrt(tf.log_sigmoid(3.8))\n"
165    _, unused_report, unused_errors, new_text = self._upgrade(text)
166    self.assertEqual(new_text, "tf.math.rsqrt(tf.math.log_sigmoid(3.8))\n")
167
168  def testAllAPI(self):
169    if not hasattr(tf.compat, "v2"):
170      return
171
172    # Converts all symbols in the v1 namespace to the v2 namespace, raising
173    # an error if the target of the conversion is not in the v2 namespace.
174    # Please regenerate the renames file or edit any manual renames if this
175    # test fails.
176    def conversion_visitor(unused_path, unused_parent, children):
177      for child in children:
178        _, attr = tf_decorator.unwrap(child[1])
179        api_names = tf_export.get_v1_names(attr)
180        for name in api_names:
181          _, _, _, text = self._upgrade("tf." + six.ensure_str(name))
182          if (text and
183              not text.startswith("tf.compat.v1") and
184              not text.startswith("tf.compat.v2") and
185              text not in self.v2_symbols and
186              # Builds currently install old version of estimator that doesn't
187              # have some 2.0 symbols.
188              not text.startswith("tf.estimator")):
189            self.assertFalse(
190                True, "Symbol %s generated from %s not in v2 API" % (
191                    text, name))
192
193    visitor = public_api.PublicAPIVisitor(conversion_visitor)
194    visitor.do_not_descend_map["tf"].append("contrib")
195    visitor.private_map["tf.compat"] = ["v1", "v2"]
196    traverse.traverse(tf.compat.v1, visitor)
197
198  def testAllAPIV1(self):
199    collect = True
200    v1_symbols = set([])
201
202    # Converts all symbols in the v1 namespace to the v2 namespace, raising
203    # an error if the target of the conversion is not in the v1 namespace.
204    def conversion_visitor(unused_path, unused_parent, children):
205      for child in children:
206        _, attr = tf_decorator.unwrap(child[1])
207        api_names = tf_export.get_v1_names(attr)
208        for name in api_names:
209          if collect:
210            v1_symbols.add("tf." + six.ensure_str(name))
211          else:
212            _, _, _, text = self._upgrade("tf." + six.ensure_str(name))
213            if (text and
214                not text.startswith("tf.compat.v1") and
215                not text.startswith("tf.compat.v2") and
216                not text.startswith("tf.estimator") and
217                text not in v1_symbols):
218              self.assertFalse(
219                  True, "Symbol %s generated from %s not in v1 API" % (
220                      text, name))
221
222    visitor = public_api.PublicAPIVisitor(conversion_visitor)
223    visitor.do_not_descend_map["tf"].append("contrib")
224    visitor.private_map["tf.compat"] = ["v1", "v2"]
225    traverse.traverse(tf.compat.v1, visitor)
226    collect = False
227    traverse.traverse(tf.compat.v1, visitor)
228
229  def testV1KeywordArgNames(self):
230    all_keyword_renames = (
231        tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)
232
233    # Visitor that verifies V1 argument names.
234    def arg_test_visitor(unused_path, unused_parent, children):
235      for child in children:
236        _, attr = tf_decorator.unwrap(child[1])
237        names_v1 = tf_export.get_v1_names(attr)
238
239        for name in names_v1:
240          name = "tf.%s" % name
241          if name not in all_keyword_renames:
242            continue
243          arg_names_v1 = tf_inspect.getargspec(attr)[0]
244          keyword_renames = all_keyword_renames[name]
245          self.assertEqual(type(keyword_renames), dict)
246
247          # Assert that v1 function has valid v1 argument names.
248          for from_name, _ in keyword_renames.items():
249            self.assertIn(
250                from_name, arg_names_v1,
251                "%s not found in %s arguments: %s" %
252                (from_name, name, str(arg_names_v1)))
253
254    visitor = public_api.PublicAPIVisitor(arg_test_visitor)
255    visitor.do_not_descend_map["tf"].append("contrib")
256    visitor.private_map["tf.compat"] = ["v1", "v2"]
257    traverse.traverse(tf.compat.v1, visitor)
258
259  def testV2KeywordArgNames(self):
260    # This test converts a call of the form:
261    # tf.foo(arg1=0, arg2=1, ...)
262    # to 2.0. Then, checks that converted function has valid argument names.
263    if not hasattr(tf.compat, "v2"):
264      return
265    v2_arg_exceptions = {
266        "verify_shape_is_now_always_true",
267        # These arguments should not be used, they just specify
268        # that a function takes named arguments.
269        "keyword_required",
270        "_sentinel",
271    }
272    v1_name_exceptions = {
273        "tf.print",  # requires print_function import
274    }
275    function_warnings = (
276        tf_upgrade_v2.TFAPIChangeSpec().function_warnings)
277    function_transformers = (
278        tf_upgrade_v2.TFAPIChangeSpec().function_transformers)
279    keyword_renames = (
280        tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)
281
282    # Visitor that converts to V2 and checks V2 argument names.
283    def conversion_visitor(unused_path, unused_parent, children):
284      for child in children:
285        _, attr = tf_decorator.unwrap(child[1])
286        if not tf_inspect.isfunction(attr):
287          continue
288        names_v1 = tf_export.get_v1_names(attr)
289        arg_names_v1 = get_args(attr)
290
291        for name in names_v1:
292          tf_name = "tf.%s" % name
293          if tf_name in function_warnings or tf_name in function_transformers:
294            continue  # These require manual change
295          if tf_name in v1_name_exceptions:
296            continue
297          # Assert that arg names after converting to v2 are present in
298          # v2 function.
299          # 1. First, create an input of the form:
300          #    tf.foo(arg1=val1, arg2=val2, ...)
301          args = ",".join(
302              ["%s=%d" % (from_name, from_index)
303               for from_index, from_name in enumerate(arg_names_v1)])
304          text_input = "%s(%s)" % (tf_name, args)
305          # 2. Convert the input to V2.
306          _, _, _, text = self._upgrade(text_input)
307          new_function_name, new_args = get_func_and_args_from_str(text)
308          if new_function_name == "tf.compat.v1.%s" % name:
309            if tf_name in keyword_renames:
310              # If we rename arguments, new function must be available in 2.0.
311              # We should not be using compat.v1 in this case.
312              self.fail(
313                  "Function '%s' is not in 2.0 when converting\n%s\nto\n%s" %
314                  (new_function_name, text_input, text))
315            continue
316          if new_function_name.startswith("tf.compat.v2"):
317            self.assertIn(new_function_name.replace("tf.compat.v2.", "tf."),
318                          self.v2_symbols)
319            continue
320          # 3. Verify V2 function and arguments.
321          args_v2 = get_args(self.v2_symbols[new_function_name])
322          args_v2.extend(v2_arg_exceptions)
323          for new_arg in new_args:
324            self.assertIn(
325                new_arg, args_v2,
326                "Invalid argument '%s' in 2.0 when converting\n%s\nto\n%s.\n"
327                "Supported arguments: %s" % (
328                    new_arg, text_input, text, str(args_v2)))
329          # 4. Verify that the argument exists in v1 as well.
330          if new_function_name in set(["tf.nn.ctc_loss",
331                                       "tf.saved_model.save"]):
332            continue
333          args_v1 = get_args(self.v1_symbols[new_function_name])
334          args_v1.extend(v2_arg_exceptions)
335          for new_arg in new_args:
336            self.assertIn(
337                new_arg, args_v1,
338                "Invalid argument '%s' in 1.0 when converting\n%s\nto\n%s.\n"
339                "Supported arguments: %s" % (
340                    new_arg, text_input, text, str(args_v1)))
341
342    visitor = public_api.PublicAPIVisitor(conversion_visitor)
343    visitor.do_not_descend_map["tf"].append("contrib")
344    visitor.private_map["tf.compat"] = ["v1", "v2"]
345    traverse.traverse(tf.compat.v1, visitor)
346
347  def testPositionsMatchArgGiven(self):
348    full_dict = tf_upgrade_v2.TFAPIChangeSpec().function_arg_warnings
349    method_names = list(full_dict.keys())
350    for method_name in method_names:
351      args = list(full_dict[method_name].keys())
352      if "contrib" in method_name:
353        # Skip descending and fetching contrib methods during test. These are
354        # not available in the repo anymore.
355        continue
356      elif six.ensure_str(method_name).startswith("*."):
357        # special case for optimizer methods
358        method = six.ensure_str(method_name).replace("*", "tf.train.Optimizer")
359      else:
360        method = method_name
361
362      method = get_symbol_for_name(tf, method)
363      arg_spec = tf_inspect.getfullargspec(method)
364      for (arg, pos) in args:
365        # to deal with the self argument on methods on objects
366        if six.ensure_str(method_name).startswith("*."):
367          pos += 1
368        self.assertEqual(arg_spec[0][pos], arg)
369
370  def testReorderFileNeedsUpdate(self):
371    reordered_function_names = (
372        tf_upgrade_v2.TFAPIChangeSpec().reordered_function_names)
373    function_reorders = (
374        tf_upgrade_v2.TFAPIChangeSpec().function_reorders)
375    manual_function_reorders = (
376        tf_upgrade_v2.TFAPIChangeSpec().manual_function_reorders)
377
378    added_names_message = """Some function names in
379self.reordered_function_names are not in reorders_v2.py.
380Please run the following commands to update reorders_v2.py:
381bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
382bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
383"""
384    removed_names_message = """%s in self.reorders_v2 does not match
385any name in self.reordered_function_names.
386Please run the following commands to update reorders_v2.py:
387bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
388bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
389"""
390    self.assertTrue(
391        reordered_function_names.issubset(function_reorders),
392        added_names_message)
393    # function_reorders should contain reordered_function_names
394    # and their TensorFlow V1 aliases.
395    for name in function_reorders:
396      if name in manual_function_reorders:
397        continue
398      # get other names for this function
399      attr = get_symbol_for_name(tf.compat.v1, name)
400      _, attr = tf_decorator.unwrap(attr)
401      v1_names = tf_export.get_v1_names(attr)
402      self.assertTrue(v1_names)
403      v1_names = ["tf.%s" % n for n in v1_names]
404      # check if any other name is in
405      self.assertTrue(
406          any(n in reordered_function_names for n in v1_names),
407          removed_names_message % name)
408
409  def testRenameConstant(self):
410    text = "tf.MONOLITHIC_BUILD\n"
411    _, unused_report, unused_errors, new_text = self._upgrade(text)
412    self.assertEqual(new_text, "tf.sysconfig.MONOLITHIC_BUILD\n")
413    text = "some_call(tf.MONOLITHIC_BUILD)\n"
414    _, unused_report, unused_errors, new_text = self._upgrade(text)
415    self.assertEqual(new_text, "some_call(tf.sysconfig.MONOLITHIC_BUILD)\n")
416
417  def testRenameArgs(self):
418    text = ("tf.nn.pool(input_a, window_shape_a, pooling_type_a, padding_a, "
419            "dilation_rate_a, strides_a, name_a, data_format_a)\n")
420    _, unused_report, unused_errors, new_text = self._upgrade(text)
421    self.assertEqual(new_text,
422                     ("tf.nn.pool(input=input_a, window_shape=window_shape_a,"
423                      " pooling_type=pooling_type_a, padding=padding_a, "
424                      "dilations=dilation_rate_a, strides=strides_a, "
425                      "name=name_a, data_format=data_format_a)\n"))
426
427  def testReorder(self):
428    text = "tf.boolean_mask(a, b, c, d)\n"
429    _, unused_report, unused_errors, new_text = self._upgrade(text)
430    self.assertEqual(new_text,
431                     "tf.boolean_mask(tensor=a, mask=b, name=c, axis=d)\n")
432
433  def testLearningRateDecay(self):
434    for decay in ["tf.train.exponential_decay",
435                  "tf.train.polynomial_decay", "tf.train.natural_exp_decay",
436                  "tf.train.inverse_time_decay", "tf.train.cosine_decay",
437                  "tf.train.cosine_decay_restarts",
438                  "tf.train.linear_cosine_decay",
439                  "tf.train.noisy_linear_cosine_decay",
440                  "tf.train.piecewise_constant_decay",
441                 ]:
442
443      text = "%s(a, b)\n" % decay
444      _, report, unused_errors, _ = self._upgrade(text)
445      self.assertIn("switch to the schedules in "
446                    "`tf.keras.optimizers.schedules`", report)
447
448  def verify_compat_v1_rename_correctness(self, values, ns_prefix=""):
449    if ns_prefix:
450      ns_prefix += "."
451    for v in values:
452      text = "tf." + ns_prefix + v + "(a, b)"
453      _, _, _, new_text = self._upgrade(text)
454      self.assertEqual("tf.compat.v1." + ns_prefix + v + "(a, b)", new_text)
455
456  def testInitializers(self):
457    initializers = [
458        "zeros",
459        "ones",
460        "constant",
461        "random_uniform",
462        "random_normal",
463        "truncated_normal",
464        "variance_scaling",
465        "orthogonal",
466        "glorot_uniform",
467        "glorot_normal",
468        "identity",
469        "lecun_normal",
470        "lecun_uniform",
471        "he_normal",
472        "he_uniform",
473    ]
474    self.verify_compat_v1_rename_correctness(
475        initializers, ns_prefix="initializers")
476
477    initializers = [
478        "zeros_initializer",
479        "ones_initializer",
480        "constant_initializer",
481        "random_uniform_initializer",
482        "random_normal_initializer",
483        "truncated_normal_initializer",
484        "variance_scaling_initializer",
485        "orthogonal_initializer",
486        "glorot_uniform_initializer",
487        "glorot_normal_initializer",
488    ]
489    self.verify_compat_v1_rename_correctness(initializers)
490
491    initializers = [
492        "zeros",
493        "ones",
494        "Ones",
495        "Zeros",
496        "constant",
497        "Constant",
498        "VarianceScaling",
499        "Orthogonal",
500        "orthogonal",
501        "Identity",
502        "identity",
503        "glorot_uniform",
504        "glorot_normal",
505        "lecun_normal",
506        "lecun_uniform",
507        "he_normal",
508        "he_uniform",
509        "TruncatedNormal",
510        "truncated_normal",
511        "RandomUniform",
512        "uniform",
513        "random_uniform",
514        "RandomNormal",
515        "normal",
516        "random_normal",
517    ]
518    self.verify_compat_v1_rename_correctness(
519        initializers, ns_prefix="keras.initializers")
520
521  def testContribXavierInitializer(self):
522    for contrib_alias in ["tf.contrib.", "contrib_"]:
523      text = contrib_alias + "layers.xavier_initializer()\n"
524      _, unused_report, unused_errors, new_text = self._upgrade(text)
525      self.assertEqual(
526          new_text,
527          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
528          "mode=\"fan_avg\", "
529          "distribution=\"uniform\")\n",
530      )
531
532      text = "slim.xavier_initializer(True or False)\n"
533      _, unused_report, unused_errors, new_text = self._upgrade(text)
534      self.assertEqual(
535          new_text,
536          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
537          "mode=\"fan_avg\", "
538          "distribution=(\"uniform\" if True or False else "
539          "\"truncated_normal\"))\n",
540      )
541
542      text = "slim.xavier_initializer(uniform=(True or False))\n"
543      _, unused_report, unused_errors, new_text = self._upgrade(text)
544      self.assertEqual(
545          new_text,
546          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
547          "mode=\"fan_avg\", "
548          "distribution=(\"uniform\" if True or False else "
549          "\"truncated_normal\"))\n",
550      )
551
552      text = contrib_alias + "layers.xavier_initializer_conv2d(False, 12)\n"
553      _, unused_report, unused_errors, new_text = self._upgrade(text)
554      self.assertEqual(
555          new_text,
556          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
557          "mode=\"fan_avg\", "
558          "distribution=(\"uniform\" if False else \"truncated_normal\"), "
559          "seed=12)\n",
560      )
561
562      text = (contrib_alias + "layers.xavier_initializer_conv2d("
563              "False, 12, tf.float32)\n")
564      _, unused_report, unused_errors, new_text = self._upgrade(text)
565      self.assertEqual(
566          new_text,
567          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
568          "mode=\"fan_avg\", "
569          "distribution=(\"uniform\" if False else \"truncated_normal\"), "
570          "seed=12, "
571          "dtype=tf.float32)\n",
572      )
573
574      text = (contrib_alias + "layers.xavier_initializer("
575              "False, 12, dtypes=tf.float32)\n")
576      _, unused_report, unused_errors, new_text = self._upgrade(text)
577      self.assertEqual(
578          new_text,
579          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
580          "mode=\"fan_avg\", "
581          "distribution=(\"uniform\" if False else \"truncated_normal\"), "
582          "seed=12, "
583          "dtypes=tf.float32)\n",
584      )
585
586  def testVarianceScalingInitializer(self):
587    text = ("tf.contrib.layers.variance_scaling_initializer("
588            "mode=(\"FAN\" + \"_AVG\"))\n")
589    _, unused_report, unused_errors, new_text = self._upgrade(text)
590    self.assertEqual(
591        new_text,
592        "tf.compat.v1.keras.initializers.VarianceScaling(scale=2.0, "
593        "mode=(\"FAN\" + \"_AVG\").lower())\n",
594    )
595
596    text = ("slim.variance_scaling_initializer("
597            "uniform=(True or False), mode=(\"FAN\" + \"_AVG\"))\n")
598    _, unused_report, unused_errors, new_text = self._upgrade(text)
599    self.assertEqual(
600        new_text,
601        "tf.compat.v1.keras.initializers.VarianceScaling(scale=2.0, "
602        "distribution=(\"uniform\" if True or False else \"truncated_normal\"),"
603        " mode=(\"FAN\" + \"_AVG\").lower())\n",
604    )
605
606    text = "tf.contrib.layers.variance_scaling_initializer(factor=1.0)\n"
607    _, unused_report, unused_errors, new_text = self._upgrade(text)
608    self.assertEqual(
609        new_text,
610        "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0)\n",
611    )
612
613    text = ("tf.contrib.layers.variance_scaling_initializer("
614            "12.0, \"FAN_AVG\", True, dtypes=tf.float32)\n")
615    _, unused_report, unused_errors, new_text = self._upgrade(text)
616    self.assertEqual(
617        new_text,
618        "tf.compat.v1.keras.initializers.VarianceScaling(12.0, "
619        "(\"FAN_AVG\").lower(), "
620        "(\"uniform\" if True else \"truncated_normal\"), "
621        "dtypes=tf.float32)\n",
622    )
623
624  def testMetrics(self):
625    metrics = [
626        "accuracy",
627        "auc",
628        "average_precision_at_k",
629        "false_negatives",
630        "false_negatives_at_thresholds",
631        "false_positives",
632        "false_positives_at_thresholds",
633        "mean",
634        "mean_absolute_error",
635        "mean_cosine_distance",
636        "mean_iou",
637        "mean_per_class_accuracy",
638        "mean_relative_error",
639        "mean_squared_error",
640        "mean_tensor",
641        "percentage_below",
642        "precision",
643        "precision_at_k",
644        "precision_at_thresholds",
645        "precision_at_top_k",
646        "recall",
647        "recall_at_k",
648        "recall_at_thresholds",
649        "recall_at_top_k",
650        "root_mean_squared_error",
651        "sensitivity_at_specificity",
652        "sparse_average_precision_at_k",
653        "sparse_precision_at_k",
654        "specificity_at_sensitivity",
655        "true_negatives",
656        "true_negatives_at_thresholds",
657        "true_positives",
658        "true_positives_at_thresholds",
659    ]
660    for m in metrics:
661      text = "tf.metrics." + m + "(a, b)"
662      _, report, unused_errors, new_text = self._upgrade(text)
663      self.assertEqual("tf.compat.v1.metrics." + m + "(a, b)", new_text)
664      self.assertIn(
665          "tf.metrics have been replaced with object oriented versions", report)
666
667  def testLosses(self):
668    losses = [
669        "absolute_difference",
670        "add_loss",
671        "compute_weighted_loss",
672        "cosine_distance",
673        "get_losses",
674        "get_regularization_loss",
675        "get_regularization_losses",
676        "get_total_loss",
677        "hinge_loss",
678        "huber_loss",
679        "log_loss",
680        "mean_pairwise_squared_error",
681        "mean_squared_error",
682        "sigmoid_cross_entropy",
683        "softmax_cross_entropy",
684        "sparse_softmax_cross_entropy",
685    ]
686    for l in losses:
687      text = "tf.losses." + l + "(a, b)"
688      _, report, unused_errors, new_text = self._upgrade(text)
689      self.assertEqual("tf.compat.v1.losses." + l + "(a, b)", new_text)
690      self.assertIn(
691          "tf.losses have been replaced with object oriented versions", report)
692
693  def testEstimatorLossReductionChange(self):
694    classes = [
695        "LinearClassifier", "LinearRegressor", "DNNLinearCombinedClassifier",
696        "DNNLinearCombinedRegressor", "DNNRegressor", "DNNClassifier",
697        "BaselineClassifier", "BaselineRegressor"
698    ]
699    for c in classes:
700      ns = "tf.estimator." + c
701      text = ns + "()"
702      expected_text = ns + "(loss_reduction=tf.keras.losses.Reduction.SUM)"
703      _, report, errors, new_text = self._upgrade(text)
704      self.assertEqual(expected_text, new_text)
705
706      text = ns + "(loss_reduction=TEST)"
707      expected_text = ns + "(loss_reduction=TEST)"
708      _, report, errors, new_text = self._upgrade(text)
709      self.assertEqual(text, new_text)
710    text = "tf.estimator.BaselineClassifier(m, c, w, v, o, c, lr)"
711    expected_text = (
712        "tf.compat.v1.estimator.BaselineClassifier("
713        "model_dir=m, n_classes=c, weight_column=w, label_vocabulary=v, "
714        "optimizer=o, config=c, loss_reduction=lr)")
715    _, report, errors, new_text = self._upgrade(text)
716    self.assertEqual(expected_text, new_text)
717
718    text = "tf.estimator.BaselineClassifier(model_dir=model_dir)"
719    expected_text = ("tf.estimator.BaselineClassifier(" +
720                     "model_dir=model_dir, "
721                     "loss_reduction=tf.keras.losses.Reduction.SUM)")
722    _, report, errors, new_text = self._upgrade(text)
723    self.assertEqual(expected_text, new_text)
724
725  def testBaseEstimatorPartitioner(self):
726    classes = ["LinearEstimator", "DNNLinearCombinedEstimator", "DNNEstimator"]
727    for c in classes:
728      ns = "tf.estimator." + c
729      suffix = "(input_layer_partitioner=TEST)"
730      text = ns + suffix
731      expected_text = "tf.compat.v1.estimator." + c + suffix
732      _, unused_report, unused_errors, new_text = self._upgrade(text)
733      self.assertEqual(new_text, expected_text)
734
735  def testCannedEstimatorPartitioner(self):
736    classes = [
737        "LinearClassifier", "LinearRegressor", "DNNLinearCombinedClassifier",
738        "DNNLinearCombinedRegressor", "DNNRegressor", "DNNClassifier"
739    ]
740
741    for c in classes:
742      ns = "tf.estimator." + c
743      suffix = "(input_layer_partitioner=TEST)"
744      text = ns + suffix
745      suffix = ("(input_layer_partitioner=TEST, "
746                "loss_reduction=tf.keras.losses.Reduction.SUM)")
747      expected_text = "tf.compat.v1.estimator." + c + suffix
748      _, unused_report, unused_errors, new_text = self._upgrade(text)
749      self.assertEqual(new_text, expected_text)
750
751  def testBaseEstimatorOptimizer(self):
752    classes = ["BaselineEstimator", "LinearEstimator", "DNNEstimator"]
753    for c in classes:
754      ns = "tf.estimator." + c
755      suffix = "(optimizer=TEST)"
756      text = ns + suffix
757      expected_text = "tf.compat.v1.estimator." + c + suffix
758      _, unused_report, unused_errors, new_text = self._upgrade(text)
759      self.assertEqual(new_text, expected_text)
760
761  def testDNNLinearCombinedEstimatorOptimizer(self):
762    classes = ["DNNLinearCombinedEstimator"]
763    for c in classes:
764      ns = "tf.estimator." + c
765      suffix = "(dnn_optimizer=TEST, linear_optimizer=Test)"
766      text = ns + suffix
767      expected_text = "tf.compat.v1.estimator." + c + suffix
768      _, unused_report, unused_errors, new_text = self._upgrade(text)
769      self.assertEqual(new_text, expected_text)
770
771  def testCannedEstimatorOptimizer(self):
772    classes = [
773        "BaselineClassifier", "BaselineRegressor", "LinearClassifier",
774        "LinearRegressor", "DNNRegressor", "DNNClassifier"
775    ]
776
777    for c in classes:
778      ns = "tf.estimator." + c
779      suffix = "(optimizer=TEST)"
780      text = ns + suffix
781      suffix = ("(optimizer=TEST, "
782                "loss_reduction=tf.keras.losses.Reduction.SUM)")
783      expected_text = "tf.compat.v1.estimator." + c + suffix
784      _, unused_report, unused_errors, new_text = self._upgrade(text)
785      self.assertEqual(new_text, expected_text)
786
787  def testDNNLinearCombinedOptimizer(self):
788    classes = [
789        "DNNLinearCombinedClassifier",
790        "DNNLinearCombinedRegressor",
791    ]
792    for c in classes:
793      ns = "tf.estimator." + c
794      suffix = "(dnn_optimizer=TEST, linear_optimizer=Test)"
795      text = ns + suffix
796      suffix = ("(dnn_optimizer=TEST, linear_optimizer=Test, "
797                "loss_reduction=tf.keras.losses.Reduction.SUM)")
798      expected_text = "tf.compat.v1.estimator." + c + suffix
799      _, unused_report, unused_errors, new_text = self._upgrade(text)
800      self.assertEqual(new_text, expected_text)
801
802  def testBaseEstimatorPartitionerAndOptimizer(self):
803    classes = ["LinearEstimator", "DNNEstimator"]
804    for c in classes:
805      ns = "tf.estimator." + c
806      suffix = "(input_layer_partitioner=TEST, optimizer=TEST)"
807      text = ns + suffix
808      expected_text = "tf.compat.v1.estimator." + c + suffix
809      _, unused_report, unused_errors, new_text = self._upgrade(text)
810      self.assertEqual(new_text, expected_text)
811
812  def testDNNLinearCombinedEstimatorPartitionerAndOptimizer(self):
813    classes = ["DNNLinearCombinedEstimator"]
814    for c in classes:
815      ns = "tf.estimator." + c
816      suffix = ("(input_layer_partitioner=TEST, dnn_optimizer=TEST, "
817                "linear_optimizer=TEST)")
818      text = ns + suffix
819      expected_text = "tf.compat.v1.estimator." + c + suffix
820      _, unused_report, unused_errors, new_text = self._upgrade(text)
821      self.assertEqual(new_text, expected_text)
822
823  def testCannedEstimatorPartitionerAndOptimizer(self):
824    classes = [
825        "LinearClassifier", "LinearRegressor", "DNNRegressor", "DNNClassifier"
826    ]
827
828    for c in classes:
829      ns = "tf.estimator." + c
830      suffix = "(input_layer_partitioner=TEST, optimizer=TEST)"
831      text = ns + suffix
832      suffix = ("(input_layer_partitioner=TEST, optimizer=TEST, "
833                "loss_reduction=tf.keras.losses.Reduction.SUM)")
834      expected_text = "tf.compat.v1.estimator." + c + suffix
835      _, unused_report, unused_errors, new_text = self._upgrade(text)
836      self.assertEqual(new_text, expected_text)
837
838  def testDNNLinearCombinedPartitionerAndOptimizer(self):
839    classes = [
840        "DNNLinearCombinedClassifier",
841        "DNNLinearCombinedRegressor",
842    ]
843
844    for c in classes:
845      ns = "tf.estimator." + c
846      suffix = ("(input_layer_partitioner=TEST, dnn_optimizer=TEST, "
847                "linear_optimizer=TEST)")
848      text = ns + suffix
849      suffix = ("(input_layer_partitioner=TEST, dnn_optimizer=TEST, "
850                "linear_optimizer=TEST, "
851                "loss_reduction=tf.keras.losses.Reduction.SUM)")
852      expected_text = "tf.compat.v1.estimator." + c + suffix
853      _, unused_report, unused_errors, new_text = self._upgrade(text)
854      self.assertEqual(new_text, expected_text)
855
856  def testExtractGlimpse(self):
857    text = ("tf.image.extract_glimpse(x, size, off, False, "
858            "False, False, name=\"foo\")\n")
859    _, unused_report, unused_errors, new_text = self._upgrade(text)
860    self.assertEqual(
861        new_text,
862        "tf.image.extract_glimpse(x, size, off, False, "
863        "False, 'uniform' if (False) else 'gaussian', name=\"foo\")\n",
864    )
865
866    text = ("tf.image.extract_glimpse(x, size, off, centered=False, "
867            "normalized=False, uniform_noise=True if uniform_noise else "
868            "False, name=\"foo\")\n")
869    _, unused_report, unused_errors, new_text = self._upgrade(text)
870    self.assertEqual(
871        new_text,
872        "tf.image.extract_glimpse(x, size, off, centered=False, "
873        "normalized=False, noise='uniform' if (True if uniform_noise else "
874        "False) else 'gaussian', name=\"foo\")\n",
875    )
876
877    text = ("tf.image.extract_glimpse(x,\n"
878            "                         size,\n"
879            "                         off,\n"
880            "                         centered=True,\n"
881            "                         normalized=True, # Stuff before\n"
882            "                         uniform_noise=False,\n"
883            "                         name=\"foo\")# Stuff after\n")
884    _, unused_report, unused_errors, new_text = self._upgrade(text)
885    self.assertEqual(
886        new_text, "tf.image.extract_glimpse(x,\n"
887        "                         size,\n"
888        "                         off,\n"
889        "                         centered=True,\n"
890        "                         normalized=True, # Stuff before\n"
891        "                         noise='uniform' if (False) else 'gaussian',\n"
892        "                         name=\"foo\")# Stuff after\n")
893
894    text = "tf.image.extract_glimpse(x)\n"
895    _, unused_report, errors, new_text = self._upgrade(text)
896    self.assertEqual(new_text, text)
897    self.assertEqual(errors, [])
898
899  def testDropout(self):
900    text = "tf.nn.dropout(x, keep_prob, name=\"foo\")\n"
901    _, unused_report, unused_errors, new_text = self._upgrade(text)
902    self.assertEqual(
903        new_text,
904        "tf.nn.dropout(x, rate=1 - (keep_prob), name=\"foo\")\n",
905    )
906
907    text = "tf.nn.dropout(x, keep_prob=.4, name=\"foo\")\n"
908    _, unused_report, unused_errors, new_text = self._upgrade(text)
909    self.assertEqual(
910        new_text,
911        "tf.nn.dropout(x, rate=1 - (.4), name=\"foo\")\n",
912    )
913
914    text = (
915        "tf.nn.dropout(x,  # Stuff before\n"
916        "              keep_prob=.4,  # Stuff after\n"
917        "              name=\"foo\")\n"
918    )
919    _, unused_report, unused_errors, new_text = self._upgrade(text)
920    self.assertEqual(
921        new_text,
922        "tf.nn.dropout(x,  # Stuff before\n"
923        "              rate=1 - (.4),  # Stuff after\n"
924        "              name=\"foo\")\n",
925    )
926
927    text = "tf.nn.dropout(x)\n"
928    _, unused_report, errors, new_text = self._upgrade(text)
929    self.assertEqual(new_text, text)
930    self.assertIn("tf.nn.dropout called without arguments", errors[0])
931
932  def testDropoutExpr(self):
933    text = "tf.nn.dropout(x, 1 - func(3 + 4.), name=\"foo\")\n"
934    _, unused_report, unused_errors, new_text = self._upgrade(text)
935    self.assertEqual(
936        new_text,
937        "tf.nn.dropout(x, rate=1 - (1 - func(3 + 4.)), name=\"foo\")\n",
938    )
939
940  def testContribL1(self):
941    text = "tf.contrib.layers.l1_regularizer(scale)\n"
942    _, unused_report, unused_errors, new_text = self._upgrade(text)
943    self.assertEqual(
944        new_text,
945        "tf.keras.regularizers.l1(scale)\n",
946    )
947    self.assertNotIn("Dropping scope", unused_report)
948
949    text = "tf.contrib.layers.l1_regularizer(scale, scope)\n"
950    _, unused_report, unused_errors, new_text = self._upgrade(text)
951    self.assertEqual(
952        new_text,
953        "tf.keras.regularizers.l1(scale)\n",
954    )
955    self.assertIn("Dropping scope", unused_report)
956
957    text = (
958        "slim.l1_regularizer(  # Stuff before\n"
959        "                    scale=.4,"
960        "                    scope=\"foo\")\n"
961    )
962    _, unused_report, unused_errors, new_text = self._upgrade(text)
963    self.assertEqual(
964        new_text,
965        "tf.keras.regularizers.l1(  # Stuff before\n"
966        "                    l=.4)\n",
967    )
968    self.assertIn("Dropping scope", unused_report)
969
970  def testContribL2(self):
971    text = "tf.contrib.layers.l2_regularizer(scale)\n"
972    _, unused_report, unused_errors, new_text = self._upgrade(text)
973    self.assertEqual(
974        new_text,
975        "tf.keras.regularizers.l2(0.5 * (scale))\n",
976    )
977    self.assertNotIn("Dropping scope", unused_report)
978
979    text = "tf.contrib.layers.l2_regularizer(scale, scope)\n"
980    _, unused_report, unused_errors, new_text = self._upgrade(text)
981    self.assertEqual(
982        new_text,
983        "tf.keras.regularizers.l2(0.5 * (scale))\n",
984    )
985    self.assertIn("Dropping scope", unused_report)
986
987    text = (
988        "slim.l2_regularizer(  # Stuff before\n"
989        "                    scale=.4,"
990        "                    scope=\"foo\")\n"
991    )
992    _, unused_report, unused_errors, new_text = self._upgrade(text)
993    self.assertEqual(
994        new_text,
995        "tf.keras.regularizers.l2(  # Stuff before\n"
996        "                    l=0.5 * (.4))\n",
997    )
998    self.assertIn("Dropping scope", unused_report)
999
1000  def testContribL2Expr(self):
1001    text = "tf.contrib.layers.l2_regularizer(1 - func(3 + 4.), scope=\"foo\")\n"
1002    _, unused_report, unused_errors, new_text = self._upgrade(text)
1003    self.assertEqual(
1004        new_text,
1005        "tf.keras.regularizers.l2(0.5 * (1 - func(3 + 4.)))\n",
1006    )
1007
1008  def testMathCountNonZeroChanges(self):
1009    text = (
1010        "tf.math.count_nonzero(input_tensor=input, dtype=dtype, name=name, "
1011        "reduction_indices=axis, keep_dims=keepdims)\n"
1012        )
1013    _, unused_report, unused_errors, new_text = self._upgrade(text)
1014    expected_text = (
1015        "tf.math.count_nonzero(input=input, dtype=dtype, name=name, "
1016        "axis=axis, keepdims=keepdims)\n"
1017        )
1018    self.assertEqual(new_text, expected_text)
1019
1020  def testCountNonZeroChanges(self):
1021    text = (
1022        "tf.count_nonzero(input_tensor=input, dtype=dtype, name=name, "
1023        "reduction_indices=axis, keep_dims=keepdims)\n"
1024        )
1025    _, unused_report, unused_errors, new_text = self._upgrade(text)
1026    expected_text = (
1027        "tf.math.count_nonzero(input=input, dtype=dtype, name=name, "
1028        "axis=axis, keepdims=keepdims)\n"
1029        )
1030    self.assertEqual(new_text, expected_text)
1031
1032  def testRandomMultinomialToRandomCategorical(self):
1033    text = (
1034        "tf.random.multinomial(logits, samples, seed, name, output_dtype)\n"
1035        )
1036    _, unused_report, unused_errors, new_text = self._upgrade(text)
1037    expected_text = (
1038        "tf.random.categorical(logits=logits, num_samples=samples, seed=seed, "
1039        "name=name, dtype=output_dtype)\n"
1040        )
1041    self.assertEqual(new_text, expected_text)
1042
1043    text = (
1044        "tf.multinomial(logits, samples, seed, name, output_dtype)\n"
1045        )
1046    _, unused_report, unused_errors, new_text = self._upgrade(text)
1047    expected_text = (
1048        "tf.random.categorical(logits=logits, num_samples=samples, seed=seed, "
1049        "name=name, dtype=output_dtype)\n"
1050        )
1051    self.assertEqual(new_text, expected_text)
1052
1053  def testRandomPoissonConversion(self):
1054    text1 = "tf.random_poisson(lam, shape, dtype)"
1055    text2 = "tf.random.poisson(lam, shape, dtype)"
1056    expected_text = "tf.random.poisson(lam=lam, shape=shape, dtype=dtype)"
1057    _, unused_report, unused_errors, new_text1 = self._upgrade(text1)
1058    self.assertEqual(new_text1, expected_text)
1059    _, unused_report, unused_errors, new_text2 = self._upgrade(text2)
1060    self.assertEqual(new_text2, expected_text)
1061
1062  def testConvolutionOpUpdate(self):
1063    text = (
1064        "tf.nn.convolution(input, filter, padding, strides, dilation_rate, "
1065        "name, data_format)"
1066    )
1067    _, unused_report, unused_errors, new_text = self._upgrade(text)
1068    expected_text = (
1069        "tf.nn.convolution(input=input, filters=filter, padding=padding, "
1070        "strides=strides, dilations=dilation_rate, name=name, "
1071        "data_format=data_format)"
1072    )
1073    self.assertEqual(new_text, expected_text)
1074
1075  def test_substr(self):
1076    text = "tf.substr(input, pos, len, name, unit)\n"
1077    _, unused_report, errors, new_text = self._upgrade(text)
1078    self.assertEqual("tf.strings.substr(input=input, pos=pos, len=len, "
1079                     "name=name, unit=unit)\n", new_text)
1080    self.assertEqual(errors, [])
1081
1082  def testColocateGradientsWithOps(self):
1083    text = "tf.gradients(yx=a, foo=False)\n"
1084    _, unused_report, errors, new_text = self._upgrade(text)
1085    self.assertEqual(text, new_text)
1086    self.assertEqual(errors, [])
1087
1088    text = "tf.gradients(yx=a, colocate_gradients_with_ops=False)\n"
1089    _, report, unused_errors, new_text = self._upgrade(text)
1090    self.assertEqual("tf.gradients(yx=a)\n", new_text)
1091    self.assertIn("tf.gradients no longer takes", report)
1092
1093    text = "tf.gradients(y, x, grad_ys, name, colocate, gate)\n"
1094    expected = ("tf.gradients(ys=y, xs=x, grad_ys=grad_ys, name=name, "
1095                "gate_gradients=gate)\n")
1096    _, unused_report, errors, new_text = self._upgrade(text)
1097    self.assertEqual(expected, new_text)
1098
1099  def testColocateGradientsWithOpsMinimize(self):
1100    text = "optimizer.minimize(a, foo=False)\n"
1101    _, unused_report, errors, new_text = self._upgrade(text)
1102    self.assertEqual(text, new_text)
1103    self.assertEqual(errors, [])
1104
1105    text = "optimizer.minimize(a, colocate_gradients_with_ops=False)\n"
1106    _, report, unused_errors, new_text = self._upgrade(text)
1107    self.assertEqual("optimizer.minimize(a)\n", new_text)
1108    self.assertIn("Optimizer.minimize no longer takes", report)
1109
1110  def testColocateGradientsWithOpsComputeGradients(self):
1111    text = "optimizer.compute_gradients(a, foo=False)\n"
1112    _, unused_report, errors, new_text = self._upgrade(text)
1113    self.assertEqual(text, new_text)
1114    self.assertEqual(errors, [])
1115
1116    text = "optimizer.compute_gradients(a, colocate_gradients_with_ops=False)\n"
1117    _, report, unused_errors, new_text = self._upgrade(text)
1118    self.assertEqual("optimizer.compute_gradients(a)\n", new_text)
1119    self.assertIn("Optimizer.compute_gradients no longer takes", report)
1120
1121  def testColocateGradientsWithHessians(self):
1122    text = "tf.hessians(ys=a, xs=b, colocate_gradients_with_ops=False)\n"
1123    _, report, unused_errors, new_text = self._upgrade(text)
1124    self.assertEqual("tf.hessians(ys=a, xs=b)\n", new_text)
1125    self.assertIn("tf.hessians no longer takes", report)
1126
1127  def testExportSavedModelRename(self):
1128    text = "self.est.export_savedmodel(path)"
1129    _, report, unused_errors, unused_new_text = self._upgrade(text)
1130    self.assertIn(
1131        "rename the method export_savedmodel() to export_saved_model()",
1132        report)
1133
1134  def testArgmin(self):
1135    text = "tf.argmin(input, name=n, dimension=1, output_type=type)"
1136    expected_text = "tf.argmin(input=input, name=n, axis=1, output_type=type)"
1137    _, unused_report, unused_errors, new_text = self._upgrade(text)
1138    self.assertEqual(new_text, expected_text)
1139
1140    text = "tf.argmin(input, 0)"
1141    expected_text = "tf.argmin(input=input, axis=0)"
1142    _, unused_report, unused_errors, new_text = self._upgrade(text)
1143    self.assertEqual(new_text, expected_text)
1144
1145    text = "tf.arg_min(input, 0)"
1146    expected_text = "tf.argmin(input, 0)"
1147    _, unused_report, unused_errors, new_text = self._upgrade(text)
1148    self.assertEqual(new_text, expected_text)
1149
1150  def testArgmax(self):
1151    text = "tf.argmax(input, name=n, dimension=1, output_type=type)"
1152    expected_text = "tf.argmax(input=input, name=n, axis=1, output_type=type)"
1153    _, unused_report, unused_errors, new_text = self._upgrade(text)
1154    self.assertEqual(new_text, expected_text)
1155
1156    text = "tf.argmax(input, 0)"
1157    expected_text = "tf.argmax(input=input, axis=0)"
1158    _, unused_report, unused_errors, new_text = self._upgrade(text)
1159    self.assertEqual(new_text, expected_text)
1160
1161    text = "tf.arg_max(input, 0)"
1162    expected_text = "tf.argmax(input, 0)"
1163    _, unused_report, unused_errors, new_text = self._upgrade(text)
1164    self.assertEqual(new_text, expected_text)
1165
1166  def testAutograph(self):
1167    text = "tf.autograph.to_graph(f, True, arg_values=None, arg_types=None)"
1168    expected_text = "tf.autograph.to_graph(f, True)"
1169    _, unused_report, unused_errors, new_text = self._upgrade(text)
1170    self.assertEqual(new_text, expected_text)
1171
1172    text = ("tf.autograph.to_code"
1173            "(f, False, arg_values=None, arg_types=None, indentation=' ')")
1174    expected_text = "tf.autograph.to_code(f, False)"
1175    _, unused_report, unused_errors, new_text = self._upgrade(text)
1176    self.assertEqual(new_text, expected_text)
1177
1178  def testEstimatorInputs(self):
1179    text = "tf.estimator.inputs.numpy_input_fn(0)"
1180    expected_text = "tf.compat.v1.estimator.inputs.numpy_input_fn(0)"
1181    _, unused_report, unused_errors, new_text = self._upgrade(text)
1182    self.assertEqual(new_text, expected_text)
1183
1184    text = "tf.estimator.inputs.pandas_input_fn(0)"
1185    expected_text = "tf.compat.v1.estimator.inputs.pandas_input_fn(0)"
1186    _, unused_report, unused_errors, new_text = self._upgrade(text)
1187    self.assertEqual(new_text, expected_text)
1188
1189  def testBatchToSpace(self):
1190    text = "tf.batch_to_space_nd(input, block_shape, crops, name)"
1191    expected_text = "tf.batch_to_space(input, block_shape, crops, name)"
1192    _, unused_report, unused_errors, new_text = self._upgrade(text)
1193    self.assertEqual(new_text, expected_text)
1194
1195    text = "tf.batch_to_space(input, crops, block_size, name)"
1196    expected_text = (
1197        "tf.batch_to_space(input=input, crops=crops, block_shape=block_size, "
1198        "name=name)")
1199    _, unused_report, unused_errors, new_text = self._upgrade(text)
1200    self.assertEqual(new_text, expected_text)
1201
1202    text = "tf.manip.batch_to_space_nd(input, block_shape, crops, name)"
1203    expected_text = "tf.batch_to_space(input, block_shape, crops, name)"
1204    _, unused_report, unused_errors, new_text = self._upgrade(text)
1205    self.assertEqual(new_text, expected_text)
1206
1207  def testExtractImagePatches(self):
1208    text = (
1209        "tf.extract_image_patches(images, ksizes=ksizes, strides=strides,"
1210        "rates=rates, padding=padding, name=name)")
1211    expected_text = (
1212        "tf.image.extract_patches(images, sizes=ksizes, strides=strides,"
1213        "rates=rates, padding=padding, name=name)")
1214    _, unused_report, unused_errors, new_text = self._upgrade(text)
1215    self.assertEqual(new_text, expected_text)
1216
1217  def testKerasSavedModel(self):
1218    text = (
1219        "tf.contrib.saved_model.save_keras_model(model, './saved_models')\n"
1220        "tf.contrib.saved_model.load_keras_model(saved_model_path)\n")
1221    expected_text = (
1222        "tf.compat.v1.keras.experimental.export_saved_model(model, "
1223        "'./saved_models')\ntf.compat.v1.keras.experimental."
1224        "load_from_saved_model(saved_model_path)\n"
1225    )
1226    _, report, unused_errors, new_text = self._upgrade(text)
1227    self.assertEqual(new_text, expected_text)
1228    expected_info = "Please use model.save"
1229    self.assertIn(expected_info, report)
1230
1231  def testStatelessMultinomial(self):
1232    text = (
1233        "tf.random.stateless_multinomial(logits, num_samples, seed, "
1234        "output_dtype=dtype, name=name)")
1235    expected_text = (
1236        "tf.random.stateless_categorical(logits, num_samples, seed, "
1237        "dtype=dtype, name=name)")
1238    _, unused_report, unused_errors, new_text = self._upgrade(text)
1239    self.assertEqual(new_text, expected_text)
1240
1241  def testSoftMaxCrossEntropyWithLogitsV2(self):
1242    text = (
1243        "tf.nn.softmax_cross_entropy_with_logits_v2("
1244        "labels=labels, logits=logits, dim=2)")
1245    expected_text = (
1246        "tf.nn.softmax_cross_entropy_with_logits("
1247        "labels=labels, logits=logits, axis=2)")
1248    _, unused_report, errors, new_text = self._upgrade(text)
1249    self.assertEqual(new_text, expected_text)
1250
1251    self.assertFalse(errors)
1252
1253  def testSoftMaxCrossEntropyWithLogits(self):
1254    text = ("tf.nn.softmax_cross_entropy_with_logits("
1255            "labels=labels, logits=logits, dim=2)")
1256    expected_text = (
1257        "tf.nn.softmax_cross_entropy_with_logits("
1258        "labels=tf.stop_gradient(labels), logits=logits, axis=2)")
1259    _, unused_report, unused_errors, new_text = self._upgrade(text)
1260    self.assertEqual(new_text, expected_text)
1261
1262    text = ("tf.nn.softmax_cross_entropy_with_logits("
1263            "labels=foo(bar))")
1264    expected_text = ("tf.nn.softmax_cross_entropy_with_logits("
1265                     "labels=tf.stop_gradient(foo(bar)))")
1266    _, unused_report, unused_errors, new_text = self._upgrade(text)
1267    self.assertEqual(expected_text, new_text)
1268
1269  def testSoftMaxCrossEntropyWithLogitsDoesntNest(self):
1270    text = ("tf.nn.softmax_cross_entropy_with_logits("
1271            "labels=tf.stop_gradient(labels), logits=logits, dim=2)")
1272    expected_text = (
1273        "tf.nn.softmax_cross_entropy_with_logits("
1274        "labels=tf.stop_gradient(labels), logits=logits, axis=2)")
1275    _, unused_report, unused_errors, new_text = self._upgrade(text)
1276    self.assertEqual(new_text, expected_text)
1277
1278    text = ("tf.nn.softmax_cross_entropy_with_logits("
1279            "labels=tf.stop_gradient(foo(bar)))")
1280    expected_text = ("tf.nn.softmax_cross_entropy_with_logits("
1281                     "labels=tf.stop_gradient(foo(bar)))")
1282    _, unused_report, unused_errors, new_text = self._upgrade(text)
1283    self.assertEqual(expected_text, new_text)
1284
1285    text = ("tf.nn.softmax_cross_entropy_with_logits("
1286            "labels=foo())")
1287    expected_text = ("tf.nn.softmax_cross_entropy_with_logits("
1288                     "labels=tf.stop_gradient(foo()))")
1289    _, unused_report, unused_errors, new_text = self._upgrade(text)
1290    self.assertEqual(expected_text, new_text)
1291
1292    text = ("tf.nn.softmax_cross_entropy_with_logits("
1293            "labels=foo().zz())")
1294    expected_text = ("tf.nn.softmax_cross_entropy_with_logits("
1295                     "labels=tf.stop_gradient(foo().zz()))")
1296    _, unused_report, unused_errors, new_text = self._upgrade(text)
1297    self.assertEqual(expected_text, new_text)
1298
1299  def testSparseMatmul(self):
1300    text = ("tf.sparse_matmul(a, b, c, d, e, f, g)\n")
1301    expected_text = ("tf.linalg.matmul(a=a, b=b, transpose_a=c, transpose_b=d, "
1302                     "a_is_sparse=e, b_is_sparse=f, name=g)\n")
1303    _, unused_report, unused_errors, new_text = self._upgrade(text)
1304    self.assertEqual(new_text, expected_text)
1305
1306  def testWeightedMoments(self):
1307    text = "tf.nn.weighted_moments(x, axes, freq, name, kd)"
1308    expected_text = (
1309        "tf.nn.weighted_moments(x=x, axes=axes, frequency_weights=freq, "
1310        "name=name, keepdims=kd)")
1311    _, unused_report, unused_errors, new_text = self._upgrade(text)
1312    self.assertEqual(new_text, expected_text)
1313
1314  def testSparseAdd(self):
1315    text = "tf.sparse.add(a, b, t)"
1316    expected_text = "tf.sparse.add(a=a, b=b, threshold=t)"
1317    _, unused_report, unused_errors, new_text = self._upgrade(text)
1318    self.assertEqual(new_text, expected_text)
1319
1320  def testSparseConcat(self):
1321    text = "tf.sparse.concat(ax, inp, name, exp, concat)"
1322    expected_text = (
1323        "tf.sparse.concat(axis=ax, sp_inputs=inp, name=name, "
1324        "expand_nonconcat_dims=exp, axis=concat)")
1325    _, unused_report, unused_errors, new_text = self._upgrade(text)
1326    self.assertEqual(new_text, expected_text)
1327
1328  def testSeparableConv2D(self):
1329    text = "tf.nn.separable_conv2d(inp, d, pt, strides, pad, rate, name, fmt)"
1330    expected_text = (
1331        "tf.nn.separable_conv2d(input=inp, depthwise_filter=d, "
1332        "pointwise_filter=pt, strides=strides, padding=pad, "
1333        "dilations=rate, name=name, data_format=fmt)")
1334    _, unused_report, unused_errors, new_text = self._upgrade(text)
1335    self.assertEqual(new_text, expected_text)
1336
1337  def testConv2D(self):
1338    text = (
1339        "tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu, "
1340        "data_format)")
1341    expected_text = (
1342        "tf.nn.conv2d(input=input, filters=filter, strides=strides, "
1343        "padding=padding, data_format=data_format)")
1344    _, unused_report, unused_errors, new_text = self._upgrade(text)
1345    self.assertEqual(new_text, expected_text)
1346
1347    text = (
1348        "tf.nn.conv2d(input, filter=filter, strides=strides, padding=padding, "
1349        "use_cudnn_on_gpu=use_cudnn_on_gpu)")
1350    expected_text = ("tf.nn.conv2d(input=input, filters=filter, "
1351                     "strides=strides, padding=padding)")
1352    _, unused_report, unused_errors, new_text = self._upgrade(text)
1353    self.assertEqual(new_text, expected_text)
1354
1355  def testConv2DBackpropFilter(self):
1356    text = (
1357        "tf.nn.conv2d_backprop_filter(input, filter_sizes, out_backprop, "
1358        "strides, padding, use_cudnn_on_gpu, data_format)")
1359    expected_text = (
1360        "tf.compat.v1.nn.conv2d_backprop_filter(input, filter_sizes, "
1361        "out_backprop, strides, padding, use_cudnn_on_gpu, data_format)")
1362    _, unused_report, unused_errors, new_text = self._upgrade(text)
1363    self.assertEqual(new_text, expected_text)
1364
1365  def testConv2DBackpropInput(self):
1366    text = (
1367        "tf.nn.conv2d_backprop_input(input_sizes, filter, out_backprop, "
1368        "strides, padding, use_cudnn_on_gpu, data_format)")
1369    expected_text = (
1370        "tf.nn.conv2d_transpose(output_shape=input_sizes, filters=filter, "
1371        "input=out_backprop, strides=strides, padding=padding, "
1372        "data_format=data_format)")
1373    _, unused_report, unused_errors, new_text = self._upgrade(text)
1374    self.assertEqual(new_text, expected_text)
1375
1376  def testSpacetoBatch(self):
1377    text = "tf.space_to_batch_nd(input, shape, paddings, name)"
1378    expected_text = "tf.space_to_batch(input, shape, paddings, name)"
1379    _, unused_report, unused_errors, new_text = self._upgrade(text)
1380    self.assertEqual(new_text, expected_text)
1381
1382    text = "tf.nn.space_to_batch(input, paddings, block_size, name)"
1383    expected_text = (
1384        "tf.space_to_batch(input=input, paddings=paddings, "
1385        "block_shape=block_size, name=name)")
1386    _, unused_report, unused_errors, new_text = self._upgrade(text)
1387    self.assertEqual(new_text, expected_text)
1388
1389  def testInTopK(self):
1390    text = "tf.math.in_top_k(a, b, c, n)"
1391    expected_text = (
1392        "tf.math.in_top_k(predictions=a, targets=b, k=c, name=n)")
1393    _, unused_report, unused_errors, new_text = self._upgrade(text)
1394    self.assertEqual(new_text, expected_text)
1395
1396  def testDepthToSpace(self):
1397    text = "tf.nn.depth_to_space(input, block_size, name, data_format)"
1398    expected_text = (
1399        "tf.nn.depth_to_space(input=input, block_size=block_size, "
1400        "name=name, data_format=data_format)")
1401    _, unused_report, unused_errors, new_text = self._upgrade(text)
1402    self.assertEqual(new_text, expected_text)
1403
1404  def testEmbeddingLookup(self):
1405    text = ("tf.nn.embedding_lookup(params, ids, partition_strategy, name, "
1406            "validate_indices, max_norm)")
1407    expected_text = ("tf.nn.embedding_lookup(params=params, ids=ids, "
1408                     "partition_strategy=partition_strategy, name=name, "
1409                     "max_norm=max_norm)")
1410    _, unused_report, unused_errors, new_text = self._upgrade(text)
1411    self.assertEqual(new_text, expected_text)
1412
1413  def testEmbeddingLookupSparse(self):
1414    text = ("tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights, "
1415            "partition_strategy, name, combiner, max_norm)")
1416    expected_text = ("tf.nn.embedding_lookup_sparse(params=params, "
1417                     "sp_ids=sp_ids, sp_weights=sp_weights, "
1418                     "partition_strategy=partition_strategy, name=name, "
1419                     "combiner=combiner, max_norm=max_norm)")
1420    _, unused_report, unused_errors, new_text = self._upgrade(text)
1421    self.assertEqual(new_text, expected_text)
1422
1423  def testNnInTopK(self):
1424    text = "tf.nn.in_top_k(predictions, targets, k, name)"
1425    expected_text = ("tf.nn.in_top_k(predictions=predictions, "
1426                     "targets=targets, k=k, name=name)")
1427    _, unused_report, unused_errors, new_text = self._upgrade(text)
1428    self.assertEqual(new_text, expected_text)
1429
1430  def testSpaceToDepth(self):
1431    text = "tf.nn.space_to_depth(input, block_size, name, data_format)"
1432    expected_text = ("tf.nn.space_to_depth(input=input, block_size=block_size, "
1433                     "name=name, data_format=data_format)")
1434    _, unused_report, unused_errors, new_text = self._upgrade(text)
1435    self.assertEqual(new_text, expected_text)
1436
1437  def testPrint(self):
1438    # tf.print() cannot be parsed unless we import print_function
1439    text = """from __future__ import print_function
1440tf.print()
1441tf.print('abc')
1442"""
1443    _, unused_report, unused_errors, new_text = self._upgrade(text)
1444    self.assertEqual(new_text, text)  # Text should stay the same
1445
1446  def testSparseSplit(self):
1447    text = (
1448        "tf.sparse_split(sp_input=sp_input, num_split=num_split, axis=axis, "
1449        "name=name)")
1450    expected_text = (
1451        "tf.sparse.split(sp_input=sp_input, num_split=num_split, axis=axis, "
1452        "name=name)")
1453    _, unused_report, unused_errors, new_text = self._upgrade(text)
1454    self.assertEqual(new_text, expected_text)
1455
1456    text = (
1457        "tf.sparse_split(sp_input=sp_input, num_split=num_split, "
1458        "name=name, split_dim=axis)")
1459    expected_text = (
1460        "tf.sparse.split(sp_input=sp_input, num_split=num_split, "
1461        "name=name, axis=axis)")
1462    _, unused_report, unused_errors, new_text = self._upgrade(text)
1463    self.assertEqual(new_text, expected_text)
1464
1465    text = (
1466        "tf.sparse.split(sp_input=sp_input, num_split=num_split, "
1467        "name=name, split_dim=axis)")
1468    expected_text = (
1469        "tf.sparse.split(sp_input=sp_input, num_split=num_split, "
1470        "name=name, axis=axis)")
1471    _, unused_report, unused_errors, new_text = self._upgrade(text)
1472    self.assertEqual(new_text, expected_text)
1473
1474  def testIterators(self):
1475    for (text, expected) in [
1476        ("(expr + yielding(data)).make_one_shot_iterator()",
1477         "tf.compat.v1.data.make_one_shot_iterator((expr + yielding(data)))"),
1478        ("dataset.make_one_shot_iterator()",
1479         "tf.compat.v1.data.make_one_shot_iterator(dataset)"),
1480        ("dataset.make_one_shot_iterator(shared_name=foo)",
1481         "tf.compat.v1.data.make_one_shot_iterator(dataset, shared_name=foo)"),
1482        ("dataset.make_one_shot_iterator(x, y, z)",
1483         "tf.compat.v1.data.make_one_shot_iterator(dataset, x, y, z)"),
1484        ("dataset.make_initializable_iterator()",
1485         "tf.compat.v1.data.make_initializable_iterator(dataset)"),
1486        ("ds.make_initializable_iterator(shared_name=foo)",
1487         "tf.compat.v1.data.make_initializable_iterator(ds, shared_name=foo)"),
1488        ("dataset.make_initializable_iterator(x, y, z)",
1489         "tf.compat.v1.data.make_initializable_iterator(dataset, x, y, z)"),
1490        ("tf.data.make_one_shot_iterator(dataset)",
1491         "tf.compat.v1.data.make_one_shot_iterator(dataset)"),
1492        ("tf.data.make_one_shot_iterator(dataset, shared_name=foo)",
1493         "tf.compat.v1.data.make_one_shot_iterator(dataset, shared_name=foo)"),
1494        ("tf.data.make_one_shot_iterator(dataset, x, y, z)",
1495         "tf.compat.v1.data.make_one_shot_iterator(dataset, x, y, z)"),
1496        ("tf.data.make_initializable_iterator(dataset)",
1497         "tf.compat.v1.data.make_initializable_iterator(dataset)"),
1498        ("tf.data.make_initializable_iterator(ds, shared_name=foo)",
1499         "tf.compat.v1.data.make_initializable_iterator(ds, shared_name=foo)"),
1500        ("tf.data.make_initializable_iterator(dataset, x, y, z)",
1501         "tf.compat.v1.data.make_initializable_iterator(dataset, x, y, z)"),
1502        ("tf.compat.v1.data.make_one_shot_iterator(dataset)",
1503         "tf.compat.v1.data.make_one_shot_iterator(dataset)"),
1504        ("tf.compat.v1.data.make_one_shot_iterator(dataset, shared_name=foo)",
1505         "tf.compat.v1.data.make_one_shot_iterator(dataset, shared_name=foo)"),
1506        ("tf.compat.v1.data.make_one_shot_iterator(dataset, x, y, z)",
1507         "tf.compat.v1.data.make_one_shot_iterator(dataset, x, y, z)"),
1508        ("tf.compat.v1.data.make_initializable_iterator(dataset)",
1509         "tf.compat.v1.data.make_initializable_iterator(dataset)"),
1510        ("tf.compat.v1.data.make_initializable_iterator(ds, shared_name=foo)",
1511         "tf.compat.v1.data.make_initializable_iterator(ds, shared_name=foo)"),
1512        ("tf.compat.v1.data.make_initializable_iterator(dataset, x, y, z)",
1513         "tf.compat.v1.data.make_initializable_iterator(dataset, x, y, z)")]:
1514      _, unused_report, unused_errors, actual = self._upgrade(text)
1515      self.assertEqual(actual, expected)
1516
1517  def testStructure(self):
1518    for (text, expected) in [
1519        ("tf.data.experimental.DatasetStructure", "tf.data.DatasetSpec"),
1520        ("tf.data.experimental.OptionalStructure", "tf.OptionalSpec"),
1521        ("tf.data.experimental.RaggedTensorStructure", "tf.RaggedTensorSpec"),
1522        ("tf.data.experimental.SparseTensorStructure", "tf.SparseTensorSpec"),
1523        ("tf.data.experimental.Structure", "tf.TypeSpec"),
1524        ("tf.data.experimental.TensorArrayStructure", "tf.TensorArraySpec"),
1525        ("tf.data.experimental.TensorStructure", "tf.TensorSpec"),
1526    ]:
1527      _, unused_report, unused_errors, actual = self._upgrade(text)
1528      self.assertEqual(actual, expected)
1529
1530  def testMapAndBatch(self):
1531    suffix = ".data.experimental.map_and_batch_with_legacy_function(args)"
1532    text = "tf" + suffix
1533    expected = "tf.compat.v1" + suffix
1534    _, unused_report, unused_errors, actual = self._upgrade(text)
1535    self.assertEqual(actual, expected)
1536
1537  def testCast(self):
1538    for (name, dtype) in [("int32", "int32"),
1539                          ("int64", "int64"),
1540                          ("float", "float32"),
1541                          ("double", "float64"),
1542                          ("complex64", "complex64"),
1543                          ("complex128", "complex128"),
1544                          ("bfloat16", "bfloat16")]:
1545      text = "tf.to_%s(x, name='test')" % name
1546      expected_text = "tf.cast(x, name='test', dtype=tf.%s)" % dtype
1547      _, unused_report, unused_errors, new_text = self._upgrade(text)
1548      self.assertEqual(expected_text, new_text)
1549
1550  def testCastPositionalSecondArgument(self):
1551    for (name, dtype) in [("int32", "int32"),
1552                          ("int64", "int64"),
1553                          ("float", "float32"),
1554                          ("double", "float64"),
1555                          ("complex64", "complex64"),
1556                          ("complex128", "complex128"),
1557                          ("bfloat16", "bfloat16")]:
1558      text = "tf.to_%s(x, 'test')" % name
1559      expected_text = "tf.cast(x, name='test', dtype=tf.%s)" % dtype
1560      _, unused_report, unused_errors, new_text = self._upgrade(text)
1561      self.assertEqual(expected_text, new_text)
1562
1563  def testImageResize(self):
1564    for method in ["bilinear", "area", "bicubic", "nearest_neighbor"]:
1565      text = "tf.image.resize_%s(i, s)" % method
1566      expected_text = ("tf.image.resize(i, s, "
1567                       "method=tf.image.ResizeMethod.%s)" % method.upper())
1568      _, unused_report, unused_errors, new_text = self._upgrade(text)
1569      self.assertEqual(expected_text, new_text)
1570
1571  def testImageResizeExtraPositionalArgs(self):
1572    for method in ["bilinear", "area", "bicubic", "nearest_neighbor"]:
1573      text = "tf.image.resize_%s(i, s, a, p)" % method
1574      expected_text = [
1575          "tf.image.resize(i, s, ", "preserve_aspect_ratio=p, ",
1576          "method=tf.image.ResizeMethod.%s)" % method.upper()
1577      ]
1578      _, unused_report, unused_errors, new_text = self._upgrade(text)
1579      for s in expected_text:
1580        self.assertIn(s, new_text)
1581
1582  def testCond(self):
1583    text = "tf.cond(a, b, c, True)"
1584    expected_text = "tf.cond(pred=a, true_fn=b, false_fn=c)"
1585    _, unused_report, errors, new_text = self._upgrade(text)
1586    self.assertEqual(expected_text, new_text)
1587    self.assertIn("tf.cond", errors[0])
1588    self.assertIn("requires manual check", errors[0])
1589
1590  def testParens(self):
1591    text = """
1592def _log_prob(self, x):
1593  return tf.reduce_logsumexp(
1594      (self.mixture_distribution.logits + self.distribution.log_prob(
1595          x[..., tf.newaxis])),
1596          axis=-1)"""
1597    expected_text = """
1598def _log_prob(self, x):
1599  return tf.reduce_logsumexp(
1600      input_tensor=(self.mixture_distribution.logits + self.distribution.log_prob(
1601          x[..., tf.newaxis])),
1602          axis=-1)"""
1603    _, unused_report, unused_errors, new_text = self._upgrade(text)
1604    self.assertEqual(expected_text, new_text)
1605
1606  def testAssertStatements(self):
1607    for name in ["assert_greater", "assert_equal", "assert_none_equal",
1608                 "assert_less", "assert_negative", "assert_positive",
1609                 "assert_non_negative", "assert_non_positive", "assert_near",
1610                 "assert_less", "assert_less_equal", "assert_greater",
1611                 "assert_greater_equal", "assert_integer", "assert_type",
1612                 "assert_scalar"]:
1613      text = "tf.%s(a)" % name
1614      expected_text = "tf.compat.v1.%s(a)" % name
1615      _, report, unused_errors, new_text = self._upgrade(text)
1616      self.assertEqual(expected_text, new_text)
1617      self.assertIn("%s has been" % name, report)
1618
1619      text = "tf.debugging.%s(a)" % name
1620      expected_text = "tf.compat.v1.debugging.%s(a)" % name
1621      _, report, unused_errors, new_text = self._upgrade(text)
1622      self.assertEqual(expected_text, new_text)
1623      self.assertIn("%s has been" % name, report)
1624
1625  def testAssertRankStatements(self):
1626    for name in ["assert_rank", "assert_rank_at_least", "assert_rank_in"]:
1627      text = "tf.%s(a)" % name
1628      expected_text = "tf.compat.v1.%s(a)" % name
1629      _, report, unused_errors, new_text = self._upgrade(text)
1630      self.assertEqual(expected_text, new_text)
1631      self.assertIn("%s has been" % name, report)
1632
1633      text = "tf.debugging.%s(a)" % name
1634      expected_text = "tf.compat.v1.debugging.%s(a)" % name
1635      _, report, unused_errors, new_text = self._upgrade(text)
1636      self.assertEqual(expected_text, new_text)
1637      self.assertIn("%s has been" % name, report)
1638
1639  def test_assert_equal_graph_def(self):
1640    text = ("tf.test.assert_equal_graph_def(a, b, checkpoint_v2=x, "
1641            "hash_table_shared_name=y)")
1642    expected = "tf.test.assert_equal_graph_def(actual=a, expected=b)"
1643    _, _, _, new_text = self._upgrade(text)
1644    self.assertEqual(expected, new_text)
1645
1646  def test_is_tensor_upgrade(self):
1647    text = "tf.contrib.framework.is_tensor(x)"
1648    expected = "tf.is_tensor(x)"
1649    _, _, _, new_text = self._upgrade(text)
1650    self.assertEqual(expected, new_text)
1651
1652  def test_is_tensor_direct_import_upgrade(self):
1653    text = "contrib_framework.is_tensor(x)"
1654    expected = "tf.is_tensor(x)"
1655    _, _, _, new_text = self._upgrade(text)
1656    self.assertEqual(expected, new_text)
1657
1658  def test_CriticalSection_upgrade(self):
1659    text = "tf.contrib.framework.CriticalSection(shared_name='blah')"
1660    expected = "tf.CriticalSection(shared_name='blah')"
1661    _, _, _, new_text = self._upgrade(text)
1662    self.assertEqual(expected, new_text)
1663
1664  def test_sample_distorted_bounding_box(self):
1665    # pylint: disable=line-too-long
1666    text = "tf.image.sample_distorted_bounding_box(a, b, c, d, e, f, g, h, i, j)"
1667    expected = "tf.image.sample_distorted_bounding_box(image_size=a, bounding_boxes=b, seed=c, min_object_covered=e, aspect_ratio_range=f, area_range=g, max_attempts=h, use_image_if_no_bounding_boxes=i, name=j)"
1668    # pylint: enable=line-too-long
1669    _, _, _, new_text = self._upgrade(text)
1670    self.assertEqual(expected, new_text)
1671
1672  def test_contrib_initialize(self):
1673    text = "tf.contrib.summary.initialize"
1674    expected = "tf.compat.v1.summary.initialize"
1675    _, _, _, new_text = self._upgrade(text)
1676    self.assertEqual(expected, new_text)
1677
1678  def test_contrib_framework_argsort(self):
1679    text = "tf.contrib.framework.argsort"
1680    expected = "tf.argsort"
1681    # pylint: enable=line-too-long
1682    _, _, _, new_text = self._upgrade(text)
1683    self.assertEqual(expected, new_text)
1684
1685  def test_flags_bare(self):
1686    _, _, errors, _ = self._upgrade("tf.flags")
1687    self.assertIn("tf.flags and tf.app.flags have been removed", errors[0])
1688
1689  def test_flags_flags(self):
1690    _, _, errors, _ = self._upgrade("tf.flags.FLAGS")
1691    self.assertIn("tf.flags and tf.app.flags have been removed", errors[0])
1692
1693  def test_contrib_estimator_head_deprecation(self):
1694    for contrib_alias in ["tf.contrib.", "contrib_"]:
1695      api_symbols = ["binary_classification_head", "logistic_regression_head",
1696                     "multi_class_head", "multi_head", "multi_label_head",
1697                     "poisson_regression_head", "regression_head"]
1698      for symbol in api_symbols:
1699        text = contrib_alias + "estimator." + symbol
1700        _, report, _, _ = self._upgrade(text)
1701        self.assertIn("`tf.contrib.estimator.*_head` has been deprecated",
1702                      report)
1703
1704  def test_contrib_layers_layer_norm_deprecation(self):
1705    for contrib_alias in ["tf.contrib.", "contrib_"]:
1706      _, report, _, _ = self._upgrade(contrib_alias + "layers.layer_norm")
1707      self.assertIn(
1708          "`tf.contrib.layers.layer_norm` has been deprecated", report)
1709
1710  def test_contrib_rnn_deprecation(self):
1711    _, report, _, _ = self._upgrade("tf.contrib.rnn")
1712    self.assertIn("tf.contrib.rnn.* has been deprecated", report)
1713
1714  def test_contrib_cudnn_rnn_deprecation(self):
1715    _, report, _, _ = self._upgrade("tf.contrib.cudnn_rnn")
1716    self.assertIn("tf.contrib.cudnn_rnn.* has been deprecated", report)
1717
1718  def test_max_pool_2d(self):
1719    text = "tf.nn.max_pool(value=4)"
1720    expected_text = "tf.nn.max_pool2d(input=4)"
1721    _, _, _, new_text = self._upgrade(text)
1722    self.assertEqual(expected_text, new_text)
1723
1724  def test_contrib_estimator_early_stopping(self):
1725    for contrib_alias in ["tf.contrib.", "contrib_"]:
1726      api_symbols = [
1727          "make_early_stopping_hook", "stop_if_higher_hook",
1728          "stop_if_lower_hook",
1729          "stop_if_no_decrease_hook", "stop_if_no_increase_hook"
1730      ]
1731      for symbol in api_symbols:
1732        text = contrib_alias + "estimator." + symbol
1733        expected_text = "tf.estimator.experimental." + symbol
1734        _, _, _, new_text = self._upgrade(text)
1735        self.assertEqual(expected_text, new_text)
1736
1737  def test_contrib_rnn_cell(self):
1738    api_symbols = ["RNNCell", "BasicLSTMCell", "BasicRNNCell", "GRUCell",
1739                   "LSTMCell", "MultiRNNCell"]
1740    for symbol in api_symbols:
1741      text = "tf.contrib.rnn." + symbol
1742      expected_text = "tf.compat.v1.nn.rnn_cell." + symbol
1743      _, _, _, new_text = self._upgrade(text)
1744      self.assertEqual(expected_text, new_text)
1745
1746  def test_contrib_rnn_function(self):
1747    api_symbols = ["static_rnn", "static_state_saving_rnn",
1748                   "static_bidirectional_rnn"]
1749    for symbol in api_symbols:
1750      text = "tf.contrib.rnn." + symbol
1751      expected_text = "tf.compat.v1.nn." + symbol
1752      _, _, _, new_text = self._upgrade(text)
1753      self.assertEqual(expected_text, new_text)
1754
1755  def test_contrib_summary_generic(self):
1756    text = "tf.contrib.summary.generic('foo', myval, meta, 'fam', 42)"
1757    expected = ("tf.compat.v2.summary.write(tag='foo', data=myval, "
1758                "metadata=meta, step=42)")
1759    _, _, errors, new_text = self._upgrade(text)
1760    self.assertEqual(expected, new_text)
1761    # Arg errors come in alphabetical order of arguments, not appearance order.
1762    self.assertIn("'family' argument", errors[0])
1763    self.assertIn("'name' argument", errors[1])
1764    self.assertIn("tf.compat.v2.summary.*", errors[2])
1765
1766  def test_contrib_summary_audio(self):
1767    text = "tf.contrib.summary.audio('foo', myval, 44100, 3, 'fam', 42)"
1768    expected = ("tf.compat.v2.summary.audio(name='foo', data=myval, "
1769                "sample_rate=44100, max_outputs=3, step=42)")
1770    _, _, errors, new_text = self._upgrade(text)
1771    self.assertEqual(expected, new_text)
1772    self.assertIn("'family' argument", errors[0])
1773    self.assertIn("tf.compat.v2.summary.*", errors[1])
1774
1775  def test_contrib_summary_histogram(self):
1776    text = "tf.contrib.summary.histogram('foo', myval, 'fam', 42)"
1777    expected = ("tf.compat.v2.summary.histogram(name='foo', data=myval, "
1778                "step=42)")
1779    _, _, errors, new_text = self._upgrade(text)
1780    self.assertEqual(expected, new_text)
1781    self.assertIn("'family' argument", errors[0])
1782    self.assertIn("tf.compat.v2.summary.*", errors[1])
1783
1784  def test_contrib_summary_image(self):
1785    text = "tf.contrib.summary.image('foo', myval, red, 3, 'fam', 42)"
1786    expected = ("tf.compat.v2.summary.image(name='foo', data=myval, "
1787                "max_outputs=3, step=42)")
1788    _, _, errors, new_text = self._upgrade(text)
1789    self.assertEqual(expected, new_text)
1790    self.assertIn("'bad_color' argument", errors[0])
1791    self.assertIn("'family' argument", errors[1])
1792    self.assertIn("tf.compat.v2.summary.*", errors[2])
1793
1794  def test_contrib_summary_scalar(self):
1795    text = "tf.contrib.summary.scalar('foo', myval, 'fam', 42)"
1796    expected = ("tf.compat.v2.summary.scalar(name='foo', data=myval, "
1797                "step=42)")
1798    _, _, errors, new_text = self._upgrade(text)
1799    self.assertEqual(expected, new_text)
1800    self.assertIn("'family' argument", errors[0])
1801    self.assertIn("tf.compat.v2.summary.*", errors[1])
1802
1803  def test_contrib_summary_generic_nostep(self):
1804    text = "tf.contrib.summary.generic('foo', myval)"
1805    expected = ("tf.compat.v2.summary.write(tag='foo', data=myval, "
1806                "step=tf.compat.v1.train.get_or_create_global_step())")
1807    _, _, errors, new_text = self._upgrade(text)
1808    self.assertEqual(expected, new_text)
1809    self.assertIn("'name' argument", errors[0])
1810    self.assertIn("'step' argument", errors[1])
1811    self.assertIn("tf.compat.v2.summary.*", errors[2])
1812
1813  def test_contrib_summary_audio_nostep(self):
1814    text = "tf.contrib.summary.audio('foo', myval, 44100)"
1815    expected = ("tf.compat.v2.summary.audio(name='foo', data=myval, "
1816                "sample_rate=44100, "
1817                "step=tf.compat.v1.train.get_or_create_global_step())")
1818    _, _, errors, new_text = self._upgrade(text)
1819    self.assertEqual(expected, new_text)
1820    self.assertIn("'step' argument", errors[0])
1821    self.assertIn("tf.compat.v2.summary.*", errors[1])
1822
1823  def test_contrib_summary_histogram_nostep(self):
1824    text = "tf.contrib.summary.histogram('foo', myval)"
1825    expected = ("tf.compat.v2.summary.histogram(name='foo', data=myval, "
1826                "step=tf.compat.v1.train.get_or_create_global_step())")
1827    _, _, errors, new_text = self._upgrade(text)
1828    self.assertEqual(expected, new_text)
1829    self.assertIn("'step' argument", errors[0])
1830    self.assertIn("tf.compat.v2.summary.*", errors[1])
1831
1832  def test_contrib_summary_image_nostep(self):
1833    text = "tf.contrib.summary.image('foo', myval)"
1834    expected = ("tf.compat.v2.summary.image(name='foo', data=myval, "
1835                "step=tf.compat.v1.train.get_or_create_global_step())")
1836    _, _, errors, new_text = self._upgrade(text)
1837    self.assertEqual(expected, new_text)
1838    self.assertIn("'step' argument", errors[0])
1839    self.assertIn("tf.compat.v2.summary.*", errors[1])
1840
1841  def test_contrib_summary_scalar_nostep(self):
1842    text = "tf.contrib.summary.scalar('foo', myval)"
1843    expected = ("tf.compat.v2.summary.scalar(name='foo', data=myval, "
1844                "step=tf.compat.v1.train.get_or_create_global_step())")
1845    _, _, errors, new_text = self._upgrade(text)
1846    self.assertEqual(expected, new_text)
1847    self.assertIn("'step' argument", errors[0])
1848    self.assertIn("tf.compat.v2.summary.*", errors[1])
1849
1850  def test_contrib_summary_graph(self):
1851    text = "tf.contrib.summary.graph(my_graph)"
1852    _, _, errors, _ = self._upgrade(text)
1853    expected_error = "tf.compat.v2.summary.trace"
1854    self.assertIn(expected_error, errors[0])
1855
1856  def test_contrib_summary_import_event(self):
1857    text = "tf.contrib.summary.import_event(my_event)"
1858    _, _, errors, _ = self._upgrade(text)
1859    expected_error = "tf.compat.v2.summary.experimental.write_raw_pb"
1860    self.assertIn(expected_error, errors[0])
1861
1862  def test_contrib_summary_flush(self):
1863    text = "tf.contrib.summary.flush(writer=foo)"
1864    expected = "tf.compat.v2.summary.flush(writer=foo)"
1865    _, _, _, new_text = self._upgrade(text)
1866    self.assertEqual(expected, new_text)
1867
1868  def test_contrib_summary_create_file_writer(self):
1869    text = ("tf.contrib.summary.create_file_writer('my_logdir', 0, 1000, "
1870            "'.foo', 'shared-name')")
1871    expected = ("tf.compat.v2.summary.create_file_writer(logdir='my_logdir', "
1872                "max_queue=0, flush_millis=1000, filename_suffix='.foo')")
1873    _, _, errors, new_text = self._upgrade(text)
1874    self.assertEqual(expected, new_text)
1875    self.assertIn("'name' argument", errors[0])
1876    self.assertIn("no longer re-uses existing event files", errors[1])
1877
1878  def test_contrib_summary_always_record_summaries(self):
1879    text = "tf.contrib.summary.always_record_summaries()"
1880    expected = "tf.compat.v2.summary.record_if(True)"
1881    _, _, _, new_text = self._upgrade(text)
1882    self.assertEqual(expected, new_text)
1883
1884  def test_contrib_summary_never_record_summaries(self):
1885    text = "tf.contrib.summary.never_record_summaries()"
1886    expected = "tf.compat.v2.summary.record_if(False)"
1887    _, _, _, new_text = self._upgrade(text)
1888    self.assertEqual(expected, new_text)
1889
1890  def test_contrib_summary_record_summaries_every_n_global_steps(self):
1891    text = "tf.contrib.summary.record_summaries_every_n_global_steps(10)"
1892    _, _, errors, _ = self._upgrade(text)
1893    expected_error = "replaced by a call to tf.compat.v2.summary.record_if()"
1894    self.assertIn(expected_error, errors[0])
1895
1896  def test_contrib_summary_all_summary_ops(self):
1897    text = "tf.contrib.summary.all_summary_ops()"
1898    expected = "tf.compat.v1.summary.all_v2_summary_ops()"
1899    _, _, _, new_text = self._upgrade(text)
1900    self.assertEqual(expected, new_text)
1901
1902  def test_contrib_summary_full_example(self):
1903    deindent = lambda n, s: "\n".join(line[n:] for line in s.split("\n"))
1904    text = deindent(4, """
1905    import tensorflow as tf
1906    tf.enable_eager_execution()
1907    writer = tf.contrib.summary.create_file_writer(
1908        "/tmp/migration_test", flush_millis=1000)
1909    with writer.as_default(), tf.contrib.summary.always_record_summaries():
1910      tf.contrib.summary.scalar("loss", 0.42)
1911      tf.contrib.summary.histogram("weights", [1.0, 2.0], step=7)
1912      tf.contrib.summary.flush()
1913    """)
1914    expected = deindent(4, """
1915    import tensorflow as tf
1916    tf.compat.v1.enable_eager_execution()
1917    writer = tf.compat.v2.summary.create_file_writer(
1918        logdir="/tmp/migration_test", flush_millis=1000)
1919    with writer.as_default(), tf.compat.v2.summary.record_if(True):
1920      tf.compat.v2.summary.scalar(name="loss", data=0.42, step=tf.compat.v1.train.get_or_create_global_step())
1921      tf.compat.v2.summary.histogram(name="weights", data=[1.0, 2.0], step=7)
1922      tf.compat.v2.summary.flush()
1923    """)
1924    _, _, _, new_text = self._upgrade(text)
1925    self.assertEqual(expected, new_text)
1926
1927  def test_summary_api_warning(self):
1928    text = "tf.summary.scalar('foo', 42)"
1929    _, report, _, _ = self._upgrade(text)
1930    expected_info = "TF 1.x summary API cannot be automatically migrated"
1931    self.assertIn(expected_info, report)
1932
1933  def test_avg_pool_2d(self):
1934    text = "tf.nn.avg_pool(value=4)"
1935    expected_text = "tf.nn.avg_pool2d(input=4)"
1936    _, _, _, new_text = self._upgrade(text)
1937    self.assertEqual(expected_text, new_text)
1938
1939  def test_saved_model_load(self):
1940    text = "tf.saved_model.load(sess, ['foo_graph'])"
1941    expected = "tf.compat.v1.saved_model.load(sess, ['foo_graph'])"
1942    _, _, _, new_text = self._upgrade(text)
1943    self.assertEqual(expected, new_text)
1944
1945  def test_saved_model_load_v2(self):
1946    text = "tf.saved_model.load_v2('/tmp/blah')"
1947    expected = "tf.compat.v2.saved_model.load('/tmp/blah')"
1948    _, _, _, new_text = self._upgrade(text)
1949    self.assertEqual(expected, new_text)
1950
1951  def test_app_flags(self):
1952    text = "flags = tf.app.flags"
1953    expected = "flags = tf.compat.v1.app.flags"
1954    _, _, _, new_text = self._upgrade(text)
1955    self.assertEqual(expected, new_text)
1956
1957  def test_uniform_unit_scaling_initializer(self):
1958    text = "tf.uniform_unit_scaling_initializer(0.5)"
1959    expected_text = ("tf.compat.v1.keras.initializers.VarianceScaling("
1960                     "scale=0.5, distribution=\"uniform\")")
1961    _, _, _, new_text = self._upgrade(text)
1962    self.assertEqual(expected_text, new_text)
1963
1964    text = "tf.initializers.uniform_unit_scaling(0.5)"
1965    expected_text = ("tf.compat.v1.keras.initializers.VarianceScaling("
1966                     "scale=0.5, distribution=\"uniform\")")
1967    _, _, _, new_text = self._upgrade(text)
1968    self.assertEqual(expected_text, new_text)
1969
1970  def test_name_scope(self):
1971    text = "tf.name_scope(None, default_name, [some, values])"
1972    expected_text = "tf.name_scope(name=default_name)"
1973    _, _, _, new_text = self._upgrade(text)
1974    self.assertEqual(expected_text, new_text)
1975
1976    text = "tf.name_scope(default_name=default_name, values=stuff)"
1977    expected_text = "tf.name_scope(name=default_name)"
1978    _, _, _, new_text = self._upgrade(text)
1979    self.assertEqual(expected_text, new_text)
1980
1981    text = "tf.name_scope(name=n, default_name=d, values=s)"
1982    expected_text = "tf.compat.v1.name_scope(name=n, default_name=d, values=s)"
1983    _, report, _, new_text = self._upgrade(text)
1984    self.assertEqual(expected_text, new_text)
1985    self.assertIn("`name` passed to `name_scope`", report)
1986
1987    text = "tf.name_scope(name=None, values=stuff)"
1988    _, _, errors, _ = self._upgrade(text)
1989    self.assertIn("name_scope call with neither name nor default_name",
1990                  errors[0])
1991
1992  @parameterized.parameters(
1993      # Rename parameter: delimiter -> sep and add .to_sparse()
1994      ["tf.string_split('test', delimiter=' ')",
1995       "tf.strings.split(input='test', sep=' ').to_sparse()"],
1996      # Rename parameter: source -> input
1997      ["tf.strings.split(source='test1')",
1998       "tf.strings.split(input='test1').to_sparse()"],
1999      # Use compat.v1 for skip_empty parameter.
2000      ["tf.string_split('test', ' ', True)",
2001       "tf.compat.v1.string_split(source='test', sep=' ', skip_empty=True)"],
2002      ["tf.string_split('test', ' ', skip_empty=False)",
2003       "tf.strings.split(input='test', sep=' ').to_sparse()"],
2004      # Split behavior for sep=None changed.  (In particular, it now splits on
2005      # all whitespace, not just the space character)
2006      ["tf.string_split(x)",
2007       "tf.compat.v1.string_split(source=x)"],
2008      # Split behavior for sep='' changed:
2009      ["tf.string_split(x, '')",
2010       "tf.strings.bytes_split(input=x).to_sparse()"],
2011      ["tf.string_split(x, sep='')",
2012       "tf.strings.bytes_split(input=x).to_sparse()"],
2013      ["tf.string_split(x, delimiter='')",
2014       "tf.strings.bytes_split(input=x).to_sparse()"],
2015      ["tf.string_split(x, '', result_type='RaggedTensor')",
2016       "tf.strings.bytes_split(input=x)"],
2017      # If sep is a variable, we can't tell if it's empty:
2018      ["tf.string_split(x, sep)",
2019       "tf.compat.v1.string_split(source=x, sep=sep)"],
2020      # If sep is a non-empty string literal, then we don't need compat.v1.
2021      ["tf.string_split(x, 'non-empty-sep')",
2022       "tf.strings.split(input=x, sep='non-empty-sep').to_sparse()"],
2023      # Add to_sparse unless result_type is RaggedTensor:
2024      ["tf.string_split(x, ' ')",
2025       "tf.strings.split(input=x, sep=' ').to_sparse()"],
2026      ["tf.string_split(x, ' ', result_type='SparseTensor')",
2027       "tf.strings.split(input=x, sep=' ').to_sparse()"],
2028      ["tf.string_split(x, ' ', result_type='RaggedTensor')",
2029       "tf.strings.split(input=x, sep=' ')"],
2030      ["tf.string_split(x, ' ', result_type=x)",
2031       "tf.compat.v1.string_split(source=x, sep=' ', result_type=x)"],
2032  )  # pyformat: disable
2033  # TODO(b/129398290)
2034  def DISABLED_test_string_split(self, text, expected_text):
2035    """Tests for transforming from tf.string_split."""
2036    _, _, _, new_text = self._upgrade(text)
2037    self.assertEqual(expected_text, new_text)
2038
2039  @parameterized.parameters(
2040      # Add to_sparse unless result_type is RaggedTensor:
2041      ["tf.strings.split(x, sep)",
2042       "tf.strings.split(x, sep).to_sparse()"],
2043      ["tf.strings.split(x, sep, result_type='SparseTensor')",
2044       "tf.strings.split(x, sep).to_sparse()"],
2045      ["tf.strings.split(x, sep, result_type='RaggedTensor')",
2046       "tf.strings.split(x, sep)"],
2047      ["tf.strings.split(x, sep, result_type=x)",
2048       "tf.compat.v1.strings.split(x, sep, result_type=x)"],
2049  )  # pyformat: disable
2050  def test_strings_split(self, text, expected_text):
2051    """Tests for transforming from tf.strings.split."""
2052    _, _, _, new_text = self._upgrade(text)
2053    self.assertEqual(expected_text, new_text)
2054
2055  def test_sdca_to_raw_ops(self):
2056    text = "tf.train.sdca_fprint(input_tensor)"
2057    expected_text = "tf.raw_ops.SdcaFprint(input=input_tensor)"
2058    _, _, _, new_text = self._upgrade(text)
2059    self.assertEqual(expected_text, new_text)
2060
2061    text = "tf.train.sdca_fprint(input, name=n)"
2062    expected_text = "tf.raw_ops.SdcaFprint(input=input, name=n)"
2063    _, _, _, new_text = self._upgrade(text)
2064    self.assertEqual(expected_text, new_text)
2065
2066    text = "tf.train.sdca_shrink_l1(w, l, ll)"
2067    expected_text = "tf.raw_ops.SdcaShrinkL1(weights=w, l1=l, l2=ll)"
2068    _, _, _, new_text = self._upgrade(text)
2069    self.assertEqual(expected_text, new_text)
2070
2071    text = (
2072        "tf.train.sdca_optimizer(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o)")
2073    expected_text = (
2074        "tf.raw_ops.SdcaOptimizer(sparse_example_indices=a, "
2075        "sparse_feature_indices=b, sparse_feature_values=c, dense_features=d, "
2076        "example_weights=e, example_labels=f, sparse_indices=g, "
2077        "sparse_weights=h, dense_weights=i, example_state_data=j, loss_type=k, "
2078        "l1=l, l2=m, num_loss_partitions=n, num_inner_iterations=o)")
2079    _, _, _, new_text = self._upgrade(text)
2080    self.assertEqual(expected_text, new_text)
2081
2082  def test_contrib_to_addons_move(self):
2083    small_mapping = {
2084        "tf.contrib.layers.poincare_normalize":
2085            "tfa.layers.PoincareNormalize",
2086        "tf.contrib.layers.maxout":
2087            "tfa.layers.Maxout",
2088        "tf.contrib.layers.group_norm":
2089            "tfa.layers.GroupNormalization",
2090        "tf.contrib.layers.instance_norm":
2091            "tfa.layers.InstanceNormalization",
2092    }
2093    for symbol, replacement in small_mapping.items():
2094      text = "{}('stuff', *args, **kwargs)".format(symbol)
2095      _, report, _, _ = self._upgrade(text)
2096      self.assertIn(replacement, report)
2097
2098  def testXlaExperimental(self):
2099    text = "tf.xla.experimental.jit_scope(0)"
2100    expected_text = "tf.xla.experimental.jit_scope(0)"
2101    _, _, _, new_text = self._upgrade(text)
2102    self.assertEqual(new_text, expected_text)
2103
2104    text = "tf.xla.experimental.compile(0)"
2105    expected_text = "tf.xla.experimental.compile(0)"
2106    _, _, _, new_text = self._upgrade(text)
2107    self.assertEqual(new_text, expected_text)
2108
2109  def testNnErosion2d(self):
2110    text = "tf.nn.erosion2d(v, k, s, r, p)"
2111    expected_text = "tf.nn.erosion2d(v, k, s, r, p, data_format='NHWC')"
2112    _, _, _, new_text = self._upgrade(text)
2113    self.assertEqual(new_text, expected_text)
2114
2115  def testNnDilation2d(self):
2116    text = "tf.nn.dilation2d(v, k, s, r, p)"
2117    expected_text = "tf.nn.dilation2d(v, k, s, r, p, data_format='NHWC')"
2118    _, _, _, new_text = self._upgrade(text)
2119    self.assertEqual(new_text, expected_text)
2120
2121  def testPywrapTensorflowWarning(self):
2122    text = "tf.pywrap_tensorflow.foo()"
2123    expected = "tf.pywrap_tensorflow.foo()"
2124    _, _, errors, new_text = self._upgrade(text)
2125    self.assertEqual(expected, new_text)
2126    self.assertIn("`tf.pywrap_tensorflow` will not be distributed", errors[0])
2127
2128  def testKerasSaveModelFormat(self):
2129    text = "tf.keras.models.save_model(model, path)"
2130    expected_text = "tf.keras.models.save_model(model, path, save_format='h5')"
2131    _, report, _, new_text = self._upgrade(text)
2132    self.assertEqual(new_text, expected_text)
2133    self.assertNotIn(
2134        "saves to the Tensorflow SavedModel format by default", report)
2135
2136    _, report, _, _ = self._upgrade("model.save(path)")
2137    self.assertIn(
2138        "saves to the Tensorflow SavedModel format by default", report)
2139
2140  def test_distribute_strategy(self):
2141    text = "tf.contrib.distribute.CrossDeviceOps()"
2142    expected = "tf.distribute.CrossDeviceOps()"
2143    _, _, _, new_text = self._upgrade(text)
2144    self.assertEqual(expected, new_text)
2145
2146    text = "tf.contrib.distribute.MirroredStrategy"
2147    expected = "tf.contrib.distribute.MirroredStrategy"
2148    _, _, errors, new_text = self._upgrade(text)
2149    self.assertEqual(expected, new_text)
2150    self.assertIn("migrated to tf.distribute.MirroredStrategy", errors[0])
2151
2152    text = "tf.distribute.MirroredStrategy"
2153    expected = "tf.distribute.MirroredStrategy"
2154    _, report, _, new_text = self._upgrade(text)
2155    self.assertEqual(expected, new_text)
2156    self.assertIn("tf.distribute.MirroredStrategy API has changed", report)
2157    self.assertIn("make_dataset_iterator->experimental_distribute_dataset",
2158                  report)
2159
2160    text = "tf.contrib.distribute.TPUStrategy"
2161    expected = "tf.contrib.distribute.TPUStrategy"
2162    _, _, errors, new_text = self._upgrade(text)
2163    self.assertEqual(expected, new_text)
2164    self.assertIn("migrated to tf.distribute.TPUStrategy",
2165                  errors[0])
2166
2167    text = "tf.contrib.distribute.foo"
2168    expected = "tf.contrib.distribute.foo"
2169    _, report, _, new_text = self._upgrade(text)
2170    self.assertEqual(expected, new_text)
2171    self.assertIn("tf.contrib.distribute.* have been migrated", report)
2172
2173  def test_decode_raw(self):
2174    text = "tf.io.decode_raw(bytes=[1,2,3], output_dtype=tf.int32)"
2175    expected_text = (
2176        "tf.io.decode_raw(input_bytes=[1,2,3], output_dtype=tf.int32)")
2177    _, _, _, new_text = self._upgrade(text)
2178    self.assertEqual(expected_text, new_text)
2179
2180  def testRecomputeGrad(self):
2181    text = "tf.contrib.layers.recompute_grad()"
2182    expected = "tf.recompute_grad()"
2183    _, _, _, new_text = self._upgrade(text)
2184    self.assertEqual(expected, new_text)
2185
2186  def test_load_variable(self):
2187    text = "tf.contrib.framework.load_variable('a')"
2188    expected_text = (
2189        "tf.train.load_variable('a')")
2190    _, _, _, new_text = self._upgrade(text)
2191    self.assertEqual(expected_text, new_text)
2192    text = "tf.contrib.framework.load_variable(checkpoint_dir='a')"
2193    expected_text = (
2194        "tf.train.load_variable(ckpt_dir_or_file='a')")
2195    _, _, _, new_text = self._upgrade(text)
2196    self.assertEqual(expected_text, new_text)
2197
2198  def test_import_rename_analysis(self):
2199    old_symbol = "tf.conj(a)"
2200    new_symbol = "tf.math.conj(a)"
2201
2202    import_header = "import tensorflow as tf\n"
2203    text = import_header + old_symbol
2204    expected_text = "import tensorflow.compat.v2 as tf\n" + new_symbol
2205    _, unused_report, unused_errors, new_text = self._upgrade(
2206        text, import_rename=True)
2207    self.assertEqual(new_text, expected_text)
2208
2209    import_header = "import tensorflow as tf, other_import as y\n"
2210    text = import_header + old_symbol
2211    new_import_header = "import tensorflow.compat.v2 as tf, other_import as y\n"
2212    expected_text = new_import_header + new_symbol
2213    _, unused_report, unused_errors, new_text = self._upgrade(
2214        text, import_rename=True)
2215    self.assertEqual(new_text, expected_text)
2216
2217    import_header = ("import tensorflow as tf\n"
2218                     "import tensorflow.compat.v1 as tf_v1\n"
2219                     "import tensorflow.compat.v2 as tf_v2\n")
2220    text = import_header + old_symbol
2221    expected_header = ("import tensorflow.compat.v2 as tf\n"
2222                       "import tensorflow.compat.v1 as tf_v1\n"
2223                       "import tensorflow.compat.v2 as tf_v2\n")
2224    expected_text = expected_header + new_symbol
2225    _, _, _, new_text = self._upgrade(text, import_rename=True)
2226    self.assertEqual(new_text, expected_text)
2227
2228    import_header = ("import tensorflow.compat.v1 as tf\n"
2229                     "import tensorflow.compat.v1 as tf_v1\n"
2230                     "import tensorflow.compat.v2 as tf_v2\n")
2231    text = import_header + old_symbol
2232    expected_header = ("import tensorflow.compat.v2 as tf\n"
2233                       "import tensorflow.compat.v1 as tf_v1\n"
2234                       "import tensorflow.compat.v2 as tf_v2\n")
2235    expected_text = expected_header + new_symbol
2236    _, _, _, new_text = self._upgrade(
2237        text, import_rename=True, upgrade_compat_v1_import=True)
2238    self.assertEqual(new_text, expected_text)
2239
2240    import_header = ("import tensorflow.compat.v1 as tf\n"
2241                     "import tensorflow.compat.v1 as tf_v1\n"
2242                     "import tensorflow.compat.v2 as tf_v2\n")
2243    text = import_header + old_symbol
2244    expected_header = ("import tensorflow as tf\n"
2245                       "import tensorflow.compat.v1 as tf_v1\n"
2246                       "import tensorflow.compat.v2 as tf_v2\n")
2247    expected_text = expected_header + new_symbol
2248    _, _, _, new_text = self._upgrade(
2249        text, import_rename=False, upgrade_compat_v1_import=True)
2250    self.assertEqual(new_text, expected_text)
2251
2252    import_header = "from tensorflow import foo\n"
2253    text = import_header + old_symbol
2254    expected_text = "from tensorflow.compat.v2 import foo\n" + new_symbol
2255    _, unused_report, unused_errors, new_text = self._upgrade(
2256        text, import_rename=True)
2257    self.assertEqual(new_text, expected_text)
2258
2259    import_header = "from tensorflow import *\n"
2260    text = import_header + old_symbol
2261    expected_text = "from tensorflow.compat.v2 import *\n" + new_symbol
2262    _, unused_report, unused_errors, new_text = self._upgrade(
2263        text, import_rename=True)
2264    self.assertEqual(new_text, expected_text)
2265
2266    import_header = "from tensorflow.foo import bar\n"
2267    text = import_header + old_symbol
2268    expected_text = "from tensorflow.compat.v2.foo import bar\n" + new_symbol
2269    _, unused_report, unused_errors, new_text = self._upgrade(
2270        text, import_rename=True)
2271    self.assertEqual(new_text, expected_text)
2272
2273    import_header = ("from tensorflow import foo as tf\n"
2274                     "from tensorflow.compat import v1 as tf_v1\n"
2275                     "from tensorflow.compat import v2 as tf_v2\n")
2276    text = import_header + old_symbol
2277    expected_header = ("from tensorflow.compat.v2 import foo as tf\n"
2278                       "from tensorflow.compat import v1 as tf_v1\n"
2279                       "from tensorflow.compat import v2 as tf_v2\n")
2280    expected_text = expected_header + new_symbol
2281    _, _, _, new_text = self._upgrade(text, import_rename=True)
2282    self.assertEqual(new_text, expected_text)
2283
2284  def test_import_analysis(self):
2285    old_symbol = "tf.conj(a)"
2286    new_symbol = "tf.math.conj(a)"
2287
2288    # We upgrade the base un-versioned tensorflow aliased as tf
2289    import_header = "import tensorflow as tf\n"
2290    text = import_header + old_symbol
2291    expected_text = import_header + new_symbol
2292    _, unused_report, unused_errors, new_text = self._upgrade(text)
2293    self.assertEqual(new_text, expected_text)
2294
2295    import_header = ("import tensorflow as tf\n"
2296                     "import tensorflow.compat.v1 as tf_v1\n"
2297                     "import tensorflow.compat.v2 as tf_v2\n")
2298    text = import_header + old_symbol
2299    expected_text = import_header + new_symbol
2300    _, _, _, new_text = self._upgrade(text)
2301    self.assertEqual(new_text, expected_text)
2302
2303    # We don't handle unaliased tensorflow imports currently,
2304    # So the upgrade script show log errors
2305    import_header = "import tensorflow\n"
2306    text = import_header + old_symbol
2307    expected_text = import_header + old_symbol
2308    _, _, errors, new_text = self._upgrade(text)
2309    self.assertEqual(new_text, expected_text)
2310    self.assertIn("unaliased `import tensorflow`", "\n".join(errors))
2311
2312    # Upgrading explicitly-versioned tf code is unsafe, but we don't
2313    # need to throw errors when we detect explicitly-versioned tf.
2314    import_header = "import tensorflow.compat.v1 as tf\n"
2315    text = import_header + old_symbol
2316    expected_text = import_header + old_symbol
2317    _, report, errors, new_text = self._upgrade(text)
2318    self.assertEqual(new_text, expected_text)
2319    self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
2320                  report)
2321    self.assertEmpty(errors)
2322
2323    import_header = "from tensorflow.compat import v1 as tf\n"
2324    text = import_header + old_symbol
2325    expected_text = import_header + old_symbol
2326    _, report, errors, new_text = self._upgrade(text)
2327    self.assertEqual(new_text, expected_text)
2328    self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
2329                  report)
2330    self.assertEmpty(errors)
2331
2332    import_header = "from tensorflow.compat import v1 as tf, v2 as tf2\n"
2333    text = import_header + old_symbol
2334    expected_text = import_header + old_symbol
2335    _, report, errors, new_text = self._upgrade(text)
2336    self.assertEqual(new_text, expected_text)
2337    self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
2338                  report)
2339    self.assertEmpty(errors)
2340
2341    import_header = "import tensorflow.compat.v2 as tf\n"
2342    text = import_header + old_symbol
2343    expected_text = import_header + old_symbol
2344    _, report, errors, new_text = self._upgrade(text)
2345    self.assertEqual(new_text, expected_text)
2346    self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`",
2347                  report)
2348    self.assertEmpty(errors)
2349
2350    import_header = "from tensorflow.compat import v1 as tf1, v2 as tf\n"
2351    text = import_header + old_symbol
2352    expected_text = import_header + old_symbol
2353    _, report, errors, new_text = self._upgrade(text)
2354    self.assertEqual(new_text, expected_text)
2355    self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`",
2356                  report)
2357    self.assertEmpty(errors)
2358
2359  def test_api_spec_reset_between_files(self):
2360    for old_symbol, new_symbol in [
2361        ("tf.conj(a)", "tf.math.conj(a)"),
2362        ("tf.to_int32(x)", "tf.cast(x, dtype=tf.int32)")]:
2363
2364      ## Test that the api spec is reset in between files:
2365      import_header = "import tensorflow.compat.v2 as tf\n"
2366      text_a = import_header + old_symbol
2367      expected_text_a = import_header + old_symbol
2368      text_b = old_symbol
2369      expected_text_b = new_symbol
2370      results = self._upgrade_multiple([text_a, text_b])
2371      result_a, result_b = results[0], results[1]
2372      self.assertEqual(result_a[3], expected_text_a)
2373      self.assertEqual(result_b[3], expected_text_b)
2374
2375  def test_model_to_estimator_checkpoint_warning(self):
2376    text = "tf.keras.estimator.model_to_estimator(model)"
2377    _, report, _, _ = self._upgrade(text)
2378    expected_info = "will save object-based checkpoints"
2379    self.assertIn(expected_info, report)
2380
2381  def test_keras_experimental_export_warning(self):
2382    text = "tf.keras.experimental.export_saved_model"
2383    _, report, _, _ = self._upgrade(text)
2384    expected_info = "Please use model.save"
2385    self.assertIn(expected_info, report)
2386
2387
class TestUpgradeFiles(test_util.TensorFlowTestCase):
  """Tests that run `ASTCodeUpgrader.process_file` on real temporary files."""

  def _write_temp_file(self, contents):
    """Writes `contents` to a new temporary file and returns its path.

    The file is registered for deletion when the test finishes, so it is
    removed even if the test fails partway through.
    """
    temp_file = tempfile.NamedTemporaryFile("w", delete=False)
    # Register cleanup before writing: with delete=False the file already
    # exists on disk and must be removed even if the write raises.
    self.addCleanup(os.unlink, temp_file.name)
    with temp_file:
      temp_file.write(contents)
    return temp_file.name

  def _read_file(self, path):
    """Returns the full text of the file at `path`, closing it afterwards."""
    with open(path) as f:
      return f.read()

  def _upgrade_in_place(self, path, **kwargs):
    """Runs the TF 2.0 upgrader on `path` and writes the result back to it.

    Extra keyword arguments are forwarded to `process_file`.
    """
    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
    upgrader.process_file(path, path, **kwargs)

  def testInplace(self):
    """Check to make sure we don't have a file system race."""
    path = self._write_temp_file("tf.conj(a)\n")
    self._upgrade_in_place(path)
    self.assertEqual(self._read_file(path), "tf.math.conj(a)\n")

  def testInplaceNoOutputChangeOnErrorHandling(self):
    """In place file should not be modified when parsing error is handled."""
    original = "print 'a' \n"
    path = self._write_temp_file(original)
    self._upgrade_in_place(path, no_change_to_outfile_on_error=True)
    # The handled parse error leaves the input file byte-for-byte unchanged.
    self.assertEqual(self._read_file(path), original)

  def testInplaceEmptyOutputOnError(self):
    """In place file becomes empty when parsing error is not handled."""
    path = self._write_temp_file("print 'a' \n")
    self._upgrade_in_place(path)
    self.assertEqual(self._read_file(path), "")
2426
2427
if __name__ == "__main__":
  # Hand control to TensorFlow's test entry point
  # (tensorflow.python.platform.test) when this module is run directly.
  test_lib.main()
2430