1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# 15# ============================================================================== 16"""TensorFlow API compatibility tests. 17 18This test ensures all changes to the public API of TensorFlow are intended. 19 20If this test fails, it means a change has been made to the public API. Backwards 21incompatible changes are not allowed. You can run the test with 22"--update_goldens" flag set to "True" to update goldens when making changes to 23the public TF python API. 24""" 25 26from __future__ import absolute_import 27from __future__ import division 28from __future__ import print_function 29 30import argparse 31import os 32import re 33import sys 34 35import six 36import tensorflow as tf 37 38from google.protobuf import message 39from google.protobuf import text_format 40 41from tensorflow.python.lib.io import file_io 42from tensorflow.python.framework import test_util 43from tensorflow.python.platform import resource_loader 44from tensorflow.python.platform import test 45from tensorflow.python.platform import tf_logging as logging 46from tensorflow.tools.api.lib import api_objects_pb2 47from tensorflow.tools.api.lib import python_object_to_proto_visitor 48from tensorflow.tools.common import public_api 49from tensorflow.tools.common import traverse 50 51# FLAGS defined at the bottom: 52FLAGS = None 53# DEFINE_boolean, update_goldens, default False: 54_UPDATE_GOLDENS_HELP = """ 55 Update stored golden files if API is updated. WARNING: All API changes 56 have to be authorized by TensorFlow leads. 57""" 58 59# DEFINE_boolean, only_test_core_api, default False: 60_ONLY_TEST_CORE_API_HELP = """ 61 Some TF APIs are being moved outside of the tensorflow/ directory. There is 62 no guarantee which versions of these APIs will be present when running this 63 test. Therefore, do not error out on API changes in non-core TF code 64 if this flag is set. 65""" 66 67# DEFINE_boolean, verbose_diffs, default True: 68_VERBOSE_DIFFS_HELP = """ 69 If set to true, print line by line diffs on all libraries. If set to 70 false, only print which libraries have differences. 71""" 72 73_API_GOLDEN_FOLDER_V1 = 'tensorflow/tools/api/golden/v1' 74_API_GOLDEN_FOLDER_V2 = 'tensorflow/tools/api/golden/v2' 75_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt' 76_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt' 77 78_NON_CORE_PACKAGES = ['estimator'] 79 80 81# TODO(annarev): remove this once we test with newer version of 82# estimator that actually has compat v1 version. 83if not hasattr(tf.compat.v1, 'estimator'): 84 tf.compat.v1.estimator = tf.estimator 85 tf.compat.v2.estimator = tf.estimator 86 87 88def _KeyToFilePath(key, api_version): 89 """From a given key, construct a filepath. 90 91 Filepath will be inside golden folder for api_version. 92 """ 93 94 def _ReplaceCapsWithDash(matchobj): 95 match = matchobj.group(0) 96 return '-%s' % (match.lower()) 97 98 case_insensitive_key = re.sub('([A-Z]{1})', _ReplaceCapsWithDash, key) 99 api_folder = ( 100 _API_GOLDEN_FOLDER_V2 if api_version == 2 else _API_GOLDEN_FOLDER_V1) 101 return os.path.join(api_folder, '%s.pbtxt' % case_insensitive_key) 102 103 104def _FileNameToKey(filename): 105 """From a given filename, construct a key we use for api objects.""" 106 107 def _ReplaceDashWithCaps(matchobj): 108 match = matchobj.group(0) 109 return match[1].upper() 110 111 base_filename = os.path.basename(filename) 112 base_filename_without_ext = os.path.splitext(base_filename)[0] 113 api_object_key = re.sub('((-[a-z]){1})', _ReplaceDashWithCaps, 114 base_filename_without_ext) 115 return api_object_key 116 117 118def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children): 119 """A Visitor that crashes on subclasses of generated proto classes.""" 120 # If the traversed object is a proto Message class 121 if not (isinstance(parent, type) and issubclass(parent, message.Message)): 122 return 123 if parent is message.Message: 124 return 125 # Check that it is a direct subclass of Message. 126 if message.Message not in parent.__bases__: 127 raise NotImplementedError( 128 'Object tf.%s is a subclass of a generated proto Message. ' 129 'They are not yet supported by the API tools.' % path) 130 131 132def _FilterNonCoreGoldenFiles(golden_file_list): 133 """Filter out non-core API pbtxt files.""" 134 filtered_file_list = [] 135 filtered_package_prefixes = ['tensorflow.%s.' % p for p in _NON_CORE_PACKAGES] 136 for f in golden_file_list: 137 if any( 138 f.rsplit('/')[-1].startswith(pre) for pre in filtered_package_prefixes 139 ): 140 continue 141 filtered_file_list.append(f) 142 return filtered_file_list 143 144 145def _FilterGoldenProtoDict(golden_proto_dict, omit_golden_symbols_map): 146 """Filter out golden proto dict symbols that should be omitted.""" 147 if not omit_golden_symbols_map: 148 return golden_proto_dict 149 filtered_proto_dict = dict(golden_proto_dict) 150 for key, symbol_list in six.iteritems(omit_golden_symbols_map): 151 api_object = api_objects_pb2.TFAPIObject() 152 api_object.CopyFrom(filtered_proto_dict[key]) 153 filtered_proto_dict[key] = api_object 154 module_or_class = None 155 if api_object.HasField('tf_module'): 156 module_or_class = api_object.tf_module 157 elif api_object.HasField('tf_class'): 158 module_or_class = api_object.tf_class 159 if module_or_class is not None: 160 for members in (module_or_class.member, module_or_class.member_method): 161 filtered_members = [m for m in members if m.name not in symbol_list] 162 # Two steps because protobuf repeated fields disallow slice assignment. 163 del members[:] 164 members.extend(filtered_members) 165 return filtered_proto_dict 166 167 168class ApiCompatibilityTest(test.TestCase): 169 170 def __init__(self, *args, **kwargs): 171 super(ApiCompatibilityTest, self).__init__(*args, **kwargs) 172 173 golden_update_warning_filename = os.path.join( 174 resource_loader.get_root_dir_with_all_resources(), _UPDATE_WARNING_FILE) 175 self._update_golden_warning = file_io.read_file_to_string( 176 golden_update_warning_filename) 177 178 test_readme_filename = os.path.join( 179 resource_loader.get_root_dir_with_all_resources(), _TEST_README_FILE) 180 self._test_readme_message = file_io.read_file_to_string( 181 test_readme_filename) 182 183 def _AssertProtoDictEquals(self, 184 expected_dict, 185 actual_dict, 186 verbose=False, 187 update_goldens=False, 188 additional_missing_object_message='', 189 api_version=2): 190 """Diff given dicts of protobufs and report differences a readable way. 191 192 Args: 193 expected_dict: a dict of TFAPIObject protos constructed from golden files. 194 actual_dict: a ict of TFAPIObject protos constructed by reading from the 195 TF package linked to the test. 196 verbose: Whether to log the full diffs, or simply report which files were 197 different. 198 update_goldens: Whether to update goldens when there are diffs found. 199 additional_missing_object_message: Message to print when a symbol is 200 missing. 201 api_version: TensorFlow API version to test. 202 """ 203 diffs = [] 204 verbose_diffs = [] 205 206 expected_keys = set(expected_dict.keys()) 207 actual_keys = set(actual_dict.keys()) 208 only_in_expected = expected_keys - actual_keys 209 only_in_actual = actual_keys - expected_keys 210 all_keys = expected_keys | actual_keys 211 212 # This will be populated below. 213 updated_keys = [] 214 215 for key in all_keys: 216 diff_message = '' 217 verbose_diff_message = '' 218 # First check if the key is not found in one or the other. 219 if key in only_in_expected: 220 diff_message = 'Object %s expected but not found (removed). %s' % ( 221 key, additional_missing_object_message) 222 verbose_diff_message = diff_message 223 elif key in only_in_actual: 224 diff_message = 'New object %s found (added).' % key 225 verbose_diff_message = diff_message 226 else: 227 # Do not truncate diff 228 self.maxDiff = None # pylint: disable=invalid-name 229 # Now we can run an actual proto diff. 230 try: 231 self.assertProtoEquals(expected_dict[key], actual_dict[key]) 232 except AssertionError as e: 233 updated_keys.append(key) 234 diff_message = 'Change detected in python object: %s.' % key 235 verbose_diff_message = str(e) 236 237 # All difference cases covered above. If any difference found, add to the 238 # list. 239 if diff_message: 240 diffs.append(diff_message) 241 verbose_diffs.append(verbose_diff_message) 242 243 # If diffs are found, handle them based on flags. 244 if diffs: 245 diff_count = len(diffs) 246 logging.error(self._test_readme_message) 247 logging.error('%d differences found between API and golden.', diff_count) 248 messages = verbose_diffs if verbose else diffs 249 for i in range(diff_count): 250 print('Issue %d\t: %s' % (i + 1, messages[i]), file=sys.stderr) 251 252 if update_goldens: 253 # Write files if requested. 254 logging.warning(self._update_golden_warning) 255 256 # If the keys are only in expected, some objects are deleted. 257 # Remove files. 258 for key in only_in_expected: 259 filepath = _KeyToFilePath(key, api_version) 260 file_io.delete_file(filepath) 261 262 # If the files are only in actual (current library), these are new 263 # modules. Write them to files. Also record all updates in files. 264 for key in only_in_actual | set(updated_keys): 265 filepath = _KeyToFilePath(key, api_version) 266 file_io.write_string_to_file( 267 filepath, text_format.MessageToString(actual_dict[key])) 268 else: 269 # Fail if we cannot fix the test by updating goldens. 270 self.fail('%d differences found between API and golden.' % diff_count) 271 272 else: 273 logging.info('No differences found between API and golden.') 274 275 def testNoSubclassOfMessage(self): 276 visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) 277 visitor.do_not_descend_map['tf'].append('contrib') 278 # Skip compat.v1 and compat.v2 since they are validated in separate tests. 279 visitor.private_map['tf.compat'] = ['v1', 'v2'] 280 traverse.traverse(tf, visitor) 281 282 def testNoSubclassOfMessageV1(self): 283 if not hasattr(tf.compat, 'v1'): 284 return 285 visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) 286 visitor.do_not_descend_map['tf'].append('contrib') 287 if FLAGS.only_test_core_api: 288 visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES) 289 traverse.traverse(tf.compat.v1, visitor) 290 291 def testNoSubclassOfMessageV2(self): 292 if not hasattr(tf.compat, 'v2'): 293 return 294 visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) 295 visitor.do_not_descend_map['tf'].append('contrib') 296 if FLAGS.only_test_core_api: 297 visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES) 298 visitor.private_map['tf.compat'] = ['v1', 'v2'] 299 traverse.traverse(tf.compat.v2, visitor) 300 301 def _checkBackwardsCompatibility(self, 302 root, 303 golden_file_pattern, 304 api_version, 305 additional_private_map=None, 306 omit_golden_symbols_map=None): 307 # Extract all API stuff. 308 visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor() 309 310 public_api_visitor = public_api.PublicAPIVisitor(visitor) 311 public_api_visitor.private_map['tf'] = ['contrib'] 312 if api_version == 2: 313 public_api_visitor.private_map['tf'].append('enable_v2_behavior') 314 315 public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental'] 316 if FLAGS.only_test_core_api: 317 public_api_visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES) 318 if additional_private_map: 319 public_api_visitor.private_map.update(additional_private_map) 320 321 traverse.traverse(root, public_api_visitor) 322 proto_dict = visitor.GetProtos() 323 324 # Read all golden files. 325 golden_file_list = file_io.get_matching_files(golden_file_pattern) 326 if FLAGS.only_test_core_api: 327 golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list) 328 329 def _ReadFileToProto(filename): 330 """Read a filename, create a protobuf from its contents.""" 331 ret_val = api_objects_pb2.TFAPIObject() 332 text_format.Merge(file_io.read_file_to_string(filename), ret_val) 333 return ret_val 334 335 golden_proto_dict = { 336 _FileNameToKey(filename): _ReadFileToProto(filename) 337 for filename in golden_file_list 338 } 339 golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict, 340 omit_golden_symbols_map) 341 342 # Diff them. Do not fail if called with update. 343 # If the test is run to update goldens, only report diffs but do not fail. 344 self._AssertProtoDictEquals( 345 golden_proto_dict, 346 proto_dict, 347 verbose=FLAGS.verbose_diffs, 348 update_goldens=FLAGS.update_goldens, 349 api_version=api_version) 350 351 @test_util.run_v1_only('b/120545219') 352 def testAPIBackwardsCompatibility(self): 353 api_version = 2 if '_api.v2' in tf.__name__ else 1 354 golden_file_pattern = os.path.join( 355 resource_loader.get_root_dir_with_all_resources(), 356 _KeyToFilePath('*', api_version)) 357 self._checkBackwardsCompatibility( 358 tf, 359 golden_file_pattern, 360 api_version, 361 # Skip compat.v1 and compat.v2 since they are validated 362 # in separate tests. 363 additional_private_map={'tf.compat': ['v1', 'v2']}) 364 365 # Also check that V1 API has contrib 366 self.assertTrue( 367 'tensorflow.python.util.lazy_loader.LazyLoader' 368 in str(type(tf.contrib))) 369 370 @test_util.run_v1_only('b/120545219') 371 def testAPIBackwardsCompatibilityV1(self): 372 api_version = 1 373 golden_file_pattern = os.path.join( 374 resource_loader.get_root_dir_with_all_resources(), 375 _KeyToFilePath('*', api_version)) 376 self._checkBackwardsCompatibility(tf.compat.v1, golden_file_pattern, 377 api_version) 378 379 def testAPIBackwardsCompatibilityV2(self): 380 api_version = 2 381 golden_file_pattern = os.path.join( 382 resource_loader.get_root_dir_with_all_resources(), 383 _KeyToFilePath('*', api_version)) 384 omit_golden_symbols_map = {} 385 if FLAGS.only_test_core_api: 386 # In TF 2.0 these summary symbols are imported from TensorBoard. 387 omit_golden_symbols_map['tensorflow.summary'] = [ 388 'audio', 'histogram', 'image', 'scalar', 'text'] 389 self._checkBackwardsCompatibility( 390 tf.compat.v2, 391 golden_file_pattern, 392 api_version, 393 additional_private_map={'tf.compat': ['v1', 'v2']}, 394 omit_golden_symbols_map=omit_golden_symbols_map) 395 396 397if __name__ == '__main__': 398 parser = argparse.ArgumentParser() 399 parser.add_argument( 400 '--update_goldens', type=bool, default=False, help=_UPDATE_GOLDENS_HELP) 401 # TODO(mikecase): Create Estimator's own API compatibility test or 402 # a more general API compatibility test for use for TF components. 403 parser.add_argument( 404 '--only_test_core_api', 405 type=bool, 406 default=True, # only_test_core_api default value 407 help=_ONLY_TEST_CORE_API_HELP) 408 parser.add_argument( 409 '--verbose_diffs', type=bool, default=True, help=_VERBOSE_DIFFS_HELP) 410 FLAGS, unparsed = parser.parse_known_args() 411 412 # Now update argv, so that unittest library does not get confused. 413 sys.argv = [sys.argv[0]] + unparsed 414 test.main() 415