# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf 2.0 upgrader in safety mode."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.framework import test_util
from tensorflow.python.platform import test as test_lib
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import tf_upgrade_v2_safety


class TfUpgradeV2SafetyTest(test_util.TensorFlowTestCase):
  """Tests the AST upgrader driven by `tf_upgrade_v2_safety.TFAPIChangeSpec`."""

  def _upgrade(self, old_file_text):
    """Runs the safety-mode upgrader over an in-memory source string.

    Args:
      old_file_text: Python source text to upgrade.

    Returns:
      A tuple `(count, report, errors, new_text)`. The first three values are
      what `ASTCodeUpgrader.process_opened_file` returns; `new_text` is the
      rewritten source written to the output buffer.
    """
    in_file = six.StringIO(old_file_text)
    out_file = six.StringIO()
    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2_safety.TFAPIChangeSpec())
    count, report, errors = (
        upgrader.process_opened_file("test.py", in_file,
                                     "test_out.py", out_file))
    return count, report, errors, out_file.getvalue()

  def testContribWarning(self):
    """Using tf.contrib produces a report entry about its removal."""
    text = "tf.contrib.foo()"
    _, report, _, _ = self._upgrade(text)
    expected_info = "tf.contrib will not be distributed"
    self.assertIn(expected_info, report)

  def testTensorFlowImport(self):
    """Plain `import tensorflow...` statements are redirected to compat.v1."""
    text = "import tensorflow as tf"
    expected_text = "import tensorflow.compat.v1 as tf"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "import tensorflow as tf, other_import as y"
    expected_text = "import tensorflow.compat.v1 as tf, other_import as y"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    # An un-aliased import gains an alias so the `tensorflow` name still binds.
    text = "import tensorflow"
    expected_text = "import tensorflow.compat.v1 as tensorflow"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "import tensorflow.foo"
    expected_text = "import tensorflow.compat.v1.foo"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "import tensorflow.foo as bar"
    expected_text = "import tensorflow.compat.v1.foo as bar"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  def testTensorFlowGoogleImport(self):
    """Imports of `tensorflow.google` are left unchanged."""
    text = "import tensorflow.google as tf"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(text, new_text)

    text = "import tensorflow.google"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(text, new_text)

    text = "import tensorflow.google.compat.v1 as tf"
    expected_text = "import tensorflow.google.compat.v1 as tf"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "import tensorflow.google.compat.v2 as tf"
    expected_text = "import tensorflow.google.compat.v2 as tf"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  def testTensorFlowImportInIndent(self):
    """An import nested in an indented block is rewritten in place."""
    text = """
try:
  import tensorflow as tf  # import line

  tf.ones([4, 5])
except AttributeError:
  pass
"""

    expected_text = """
try:
  import tensorflow.compat.v1 as tf  # import line

  tf.ones([4, 5])
except AttributeError:
  pass
"""
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  def testTensorFlowFromImport(self):
    """`from tensorflow... import` statements are redirected to compat.v1."""
    text = "from tensorflow import foo"
    expected_text = "from tensorflow.compat.v1 import foo"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "from tensorflow.foo import bar"
    expected_text = "from tensorflow.compat.v1.foo import bar"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "from tensorflow import *"
    expected_text = "from tensorflow.compat.v1 import *"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  def testTensorFlowImportAlreadyHasCompat(self):
    """Imports that already target tensorflow.compat are left unchanged."""
    text = "import tensorflow.compat.v1 as tf"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(text, new_text)

    text = "import tensorflow.compat.v2 as tf"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(text, new_text)

    text = "from tensorflow.compat import v2 as tf"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(text, new_text)

  def testTensorFlowGoogleFromImport(self):
    """`from tensorflow.google.compat import ...` is left unchanged."""
    text = "from tensorflow.google.compat import v1 as tf"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(text, new_text)

    text = "from tensorflow.google.compat import v2 as tf"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(text, new_text)

  def testTensorFlowDontChangeContrib(self):
    """Imports of tensorflow.contrib are left unchanged."""
    text = "import tensorflow.contrib as foo"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(text, new_text)

    text = "from tensorflow import contrib"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(text, new_text)

  def test_contrib_to_addons_move(self):
    """Calls to moved contrib layers mention their tfa replacement in the report."""
    small_mapping = {
        "tf.contrib.layers.poincare_normalize":
            "tfa.layers.PoincareNormalize",
        "tf.contrib.layers.maxout":
            "tfa.layers.Maxout",
        "tf.contrib.layers.group_norm":
            "tfa.layers.GroupNormalization",
        "tf.contrib.layers.instance_norm":
            "tfa.layers.InstanceNormalization",
    }
    for symbol, replacement in small_mapping.items():
      text = "{}('stuff', *args, **kwargs)".format(symbol)
      _, report, _, _ = self._upgrade(text)
      self.assertIn(replacement, report)


if __name__ == "__main__":
  test_lib.main()