1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf 2.0 upgrader in safety mode."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import six
21
22from tensorflow.python.framework import test_util
23from tensorflow.python.platform import test as test_lib
24from tensorflow.tools.compatibility import ast_edits
25from tensorflow.tools.compatibility import tf_upgrade_v2_safety
26
27
class TfUpgradeV2SafetyTest(test_util.TensorFlowTestCase):
  """Tests for the TF 2.0 upgrader run with the safety-mode API change spec.

  Every case pushes a small Python snippet through `ast_edits.ASTCodeUpgrader`
  configured with `tf_upgrade_v2_safety.TFAPIChangeSpec`, then inspects either
  the rewritten source or the report text the upgrader produced.
  """

  def _upgrade(self, old_file_text):
    """Upgrades `old_file_text` in memory with the safety-mode change spec.

    Args:
      old_file_text: Python source text to feed to the upgrader.

    Returns:
      A 4-tuple `(count, report, errors, new_text)`. The first three items are
      passed through unchanged from `process_opened_file`; `new_text` is the
      source the upgrader wrote to its in-memory output stream.
    """
    source = six.StringIO(old_file_text)
    sink = six.StringIO()
    spec = tf_upgrade_v2_safety.TFAPIChangeSpec()
    upgrader = ast_edits.ASTCodeUpgrader(spec)
    count, report, errors = upgrader.process_opened_file(
        "test.py", source, "test_out.py", sink)
    return count, report, errors, sink.getvalue()

  def _assertUpgradesTo(self, text, expected_text):
    """Asserts that upgrading `text` produces exactly `expected_text`."""
    new_text = self._upgrade(text)[3]
    self.assertEqual(expected_text, new_text)

  def _assertUnchanged(self, text):
    """Asserts that upgrading `text` leaves it byte-for-byte identical."""
    self._assertUpgradesTo(text, text)

  def testContribWarning(self):
    report = self._upgrade("tf.contrib.foo()")[1]
    self.assertIn("tf.contrib will not be distributed", report)

  def testTensorFlowImport(self):
    # (input source, expected upgraded source) pairs, checked in order.
    cases = (
        ("import tensorflow as tf",
         "import tensorflow.compat.v1 as tf"),
        ("import tensorflow as tf, other_import as y",
         "import tensorflow.compat.v1 as tf, other_import as y"),
        ("import tensorflow",
         "import tensorflow.compat.v1 as tensorflow"),
        ("import tensorflow.foo",
         "import tensorflow.compat.v1.foo"),
        ("import tensorflow.foo as bar",
         "import tensorflow.compat.v1.foo as bar"),
    )
    for text, expected_text in cases:
      self._assertUpgradesTo(text, expected_text)

  def testTensorFlowGoogleImport(self):
    # Imports rooted at `tensorflow.google` must be left alone.
    for text in ("import tensorflow.google as tf",
                 "import tensorflow.google",
                 "import tensorflow.google.compat.v1 as tf",
                 "import tensorflow.google.compat.v2 as tf"):
      self._assertUnchanged(text)

  def testTensorFlowImportInIndent(self):
    # The import sits inside a `try` block: only the module path changes,
    # while the indentation and the trailing comment are preserved.
    before = """
try:
  import tensorflow as tf  # import line

  tf.ones([4, 5])
except AttributeError:
  pass
"""
    after = """
try:
  import tensorflow.compat.v1 as tf  # import line

  tf.ones([4, 5])
except AttributeError:
  pass
"""
    self._assertUpgradesTo(before, after)

  def testTensorFlowFromImport(self):
    cases = (
        ("from tensorflow import foo",
         "from tensorflow.compat.v1 import foo"),
        ("from tensorflow.foo import bar",
         "from tensorflow.compat.v1.foo import bar"),
        ("from tensorflow import *",
         "from tensorflow.compat.v1 import *"),
    )
    for text, expected_text in cases:
      self._assertUpgradesTo(text, expected_text)

  def testTensorFlowImportAlreadyHasCompat(self):
    # Sources that already target a compat module need no rewrite.
    for text in ("import tensorflow.compat.v1 as tf",
                 "import tensorflow.compat.v2 as tf",
                 "from tensorflow.compat import v2 as tf"):
      self._assertUnchanged(text)

  def testTensorFlowGoogleFromImport(self):
    for text in ("from tensorflow.google.compat import v1 as tf",
                 "from tensorflow.google.compat import v2 as tf"):
      self._assertUnchanged(text)

  def testTensorFlowDontChangeContrib(self):
    for text in ("import tensorflow.contrib as foo",
                 "from tensorflow import contrib"):
      self._assertUnchanged(text)

  def test_contrib_to_addons_move(self):
    # contrib symbol -> TF Addons replacement expected to appear in the report.
    addons_moves = (
        ("tf.contrib.layers.poincare_normalize",
         "tfa.layers.PoincareNormalize"),
        ("tf.contrib.layers.maxout",
         "tfa.layers.Maxout"),
        ("tf.contrib.layers.group_norm",
         "tfa.layers.GroupNormalization"),
        ("tf.contrib.layers.instance_norm",
         "tfa.layers.InstanceNormalization"),
    )
    for symbol, replacement in addons_moves:
      call = "{}('stuff', *args, **kwargs)".format(symbol)
      report = self._upgrade(call)[1]
      self.assertIn(replacement, report)
173
if __name__ == "__main__":
  # Hand control to the TensorFlow test entry point imported as `test_lib`
  # (tensorflow.python.platform.test). Test methods belong on the TestCase
  # class above, not inside this guard.
  test_lib.main()
203