1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15# ==============================================================================
16"""TensorFlow API compatibility tests.
17
18This test ensures all changes to the public API of TensorFlow are intended.
19
20If this test fails, it means a change has been made to the public API. Backwards
21incompatible changes are not allowed. You can run the test with
22"--update_goldens" flag set to "True" to update goldens when making changes to
23the public TF python API.
24"""
25
26from __future__ import absolute_import
27from __future__ import division
28from __future__ import print_function
29
30import argparse
31import os
32import re
33import sys
34
35import six
36import tensorflow as tf
37
38from google.protobuf import message
39from google.protobuf import text_format
40
41from tensorflow.python.lib.io import file_io
42from tensorflow.python.framework import test_util
43from tensorflow.python.platform import resource_loader
44from tensorflow.python.platform import test
45from tensorflow.python.platform import tf_logging as logging
46from tensorflow.tools.api.lib import api_objects_pb2
47from tensorflow.tools.api.lib import python_object_to_proto_visitor
48from tensorflow.tools.common import public_api
49from tensorflow.tools.common import traverse
50
# Parsed command-line flags. Stays None until the `__main__` section at the
# bottom of this file assigns the argparse result to it.
FLAGS = None
# Help text for the boolean flag --update_goldens (default: False).
_UPDATE_GOLDENS_HELP = """
     Update stored golden files if API is updated. WARNING: All API changes
     have to be authorized by TensorFlow leads.
"""

# Help text for the boolean flag --only_test_core_api (default: True).
_ONLY_TEST_CORE_API_HELP = """
    Some TF APIs are being moved outside of the tensorflow/ directory. There is
    no guarantee which versions of these APIs will be present when running this
    test. Therefore, do not error out on API changes in non-core TF code
    if this flag is set.
"""

# Help text for the boolean flag --verbose_diffs (default: True).
_VERBOSE_DIFFS_HELP = """
     If set to true, print line by line diffs on all libraries. If set to
     false, only print which libraries have differences.
"""

# Golden folders, one per API version; _KeyToFilePath picks between them.
_API_GOLDEN_FOLDER_V1 = 'tensorflow/tools/api/golden/v1'
_API_GOLDEN_FOLDER_V2 = 'tensorflow/tools/api/golden/v2'
# Text files read by ApiCompatibilityTest.__init__: the README is logged when
# differences are found, the warning is logged when goldens are rewritten.
_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt'
_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt'

# Top-level packages that are skipped when --only_test_core_api is set.
_NON_CORE_PACKAGES = ['estimator']
79
80
# TODO(annarev): remove this once we test with newer version of
# estimator that actually has compat v1 version.
# Runs at import time: when `tf.compat.v1` has no `estimator` attribute,
# `tf.estimator` is aliased onto both `tf.compat.v1` and `tf.compat.v2`.
# Note that `tf.compat.v2` is only patched together with `tf.compat.v1`; it is
# never checked on its own.
# NOTE(review): presumably this lets the V1/V2 API traversals below see an
# estimator module -- confirm before removing.
if not hasattr(tf.compat.v1, 'estimator'):
  tf.compat.v1.estimator = tf.estimator
  tf.compat.v2.estimator = tf.estimator
86
87
def _KeyToFilePath(key, api_version):
  """Builds the golden `.pbtxt` path for an API object key.

  Every upper-case ASCII letter in `key` is rewritten as a dash followed by
  its lower-case form (e.g. `T` becomes `-t`). The resulting name, with a
  `.pbtxt` suffix, is placed in the V2 golden folder when `api_version` is 2
  and in the V1 golden folder for any other value.

  Args:
    key: API object key, e.g. a dotted symbol name. May also be a glob such as
      `'*'`, which passes through unchanged.
    api_version: TensorFlow API version; only the value 2 selects the V2
      folder.

  Returns:
    The golden file path for `key`, relative to the golden folder constants.
  """
  dashed_key = re.sub(
      '([A-Z]{1})', lambda found: '-%s' % found.group(0).lower(), key)
  if api_version == 2:
    golden_folder = _API_GOLDEN_FOLDER_V2
  else:
    golden_folder = _API_GOLDEN_FOLDER_V1
  return os.path.join(golden_folder, '%s.pbtxt' % dashed_key)
102
103
def _FileNameToKey(filename):
  """Recovers the API object key encoded in a golden file name.

  The directory part and the final extension are dropped, then every dash
  that is followed by a lower-case ASCII letter is replaced by that letter in
  upper case (e.g. `-t` becomes `T`).

  Args:
    filename: Path or bare name of a golden file.

  Returns:
    The API object key string.
  """
  stem, _ = os.path.splitext(os.path.basename(filename))
  # found.group(0) is the two-character match "-x"; index 1 is the letter.
  return re.sub(
      '((-[a-z]){1})', lambda found: found.group(0)[1].upper(), stem)
116
117
def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children):
  """Visitor that rejects indirect subclasses of proto `Message`.

  Objects that are not classes, classes unrelated to `message.Message`, and
  `message.Message` itself are accepted silently. A class is also accepted
  when `message.Message` is one of its direct bases.

  Args:
    path: Path of the visited object, used only in the error text.
    parent: The visited object.
    unused_children: Ignored.

  Raises:
    NotImplementedError: If `parent` is a subclass of `message.Message` that
      does not list `message.Message` among its direct bases.
  """
  is_message_class = (
      isinstance(parent, type) and issubclass(parent, message.Message))
  if not is_message_class or parent is message.Message:
    return
  # Direct subclasses of Message are fine.
  if message.Message in parent.__bases__:
    return
  raise NotImplementedError(
      'Object tf.%s is a subclass of a generated proto Message. '
      'They are not yet supported by the API tools.' % path)
130
131
def _FilterNonCoreGoldenFiles(golden_file_list):
  """Drops golden files that belong to non-core packages.

  A file is dropped when the part of its path after the last `/` starts with
  `tensorflow.<package>.` for any package in `_NON_CORE_PACKAGES`.

  Args:
    golden_file_list: Iterable of golden file paths.

  Returns:
    A new list with the remaining paths, in their original order.
  """
  non_core_prefixes = tuple(
      'tensorflow.%s.' % package for package in _NON_CORE_PACKAGES)
  return [
      golden_path for golden_path in golden_file_list
      if not golden_path.rsplit('/')[-1].startswith(non_core_prefixes)
  ]
143
144
def _FilterGoldenProtoDict(golden_proto_dict, omit_golden_symbols_map):
  """Removes the requested member symbols from golden protos.

  Args:
    golden_proto_dict: Dict mapping API object keys to TFAPIObject protos.
    omit_golden_symbols_map: Dict mapping an API object key to the member
      names to remove from that object's golden proto. May be empty or None.

  Returns:
    `golden_proto_dict` itself when there is nothing to omit. Otherwise a
    shallow copy in which each listed key maps to a copy of its proto with the
    named entries removed from `member` and `member_method`. The input protos
    are not modified.
  """
  if not omit_golden_symbols_map:
    return golden_proto_dict

  result = dict(golden_proto_dict)
  for api_key, omitted_names in six.iteritems(omit_golden_symbols_map):
    # Edit a private copy so the caller's proto stays untouched.
    proto_copy = api_objects_pb2.TFAPIObject()
    proto_copy.CopyFrom(result[api_key])
    result[api_key] = proto_copy

    if proto_copy.HasField('tf_module'):
      container = proto_copy.tf_module
    elif proto_copy.HasField('tf_class'):
      container = proto_copy.tf_class
    else:
      # Neither a module nor a class: nothing to filter.
      continue

    for repeated_field in (container.member, container.member_method):
      kept = [
          entry for entry in repeated_field if entry.name not in omitted_names
      ]
      # Clear then extend: protobuf repeated fields disallow slice assignment.
      del repeated_field[:]
      repeated_field.extend(kept)
  return result
166
167
class ApiCompatibilityTest(test.TestCase):
  """Compares the live TensorFlow Python API against stored golden files."""

  def __init__(self, *args, **kwargs):
    """Loads the README and golden-update warning texts used in log output."""
    super(ApiCompatibilityTest, self).__init__(*args, **kwargs)

    golden_update_warning_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _UPDATE_WARNING_FILE)
    self._update_golden_warning = file_io.read_file_to_string(
        golden_update_warning_filename)

    test_readme_filename = os.path.join(
        resource_loader.get_root_dir_with_all_resources(), _TEST_README_FILE)
    self._test_readme_message = file_io.read_file_to_string(
        test_readme_filename)

  def _AssertProtoDictEquals(self,
                             expected_dict,
                             actual_dict,
                             verbose=False,
                             update_goldens=False,
                             additional_missing_object_message='',
                             api_version=2):
    """Diff given dicts of protobufs and report differences a readable way.

    Keys are visited in sorted order so that the reported issue numbering is
    the same on every run.

    Args:
      expected_dict: a dict of TFAPIObject protos constructed from golden files.
      actual_dict: a dict of TFAPIObject protos constructed by reading from the
        TF package linked to the test.
      verbose: Whether to log the full diffs, or simply report which files were
        different.
      update_goldens: Whether to update goldens when there are diffs found.
        When True the test does not fail on differences; golden files are
        deleted or rewritten instead.
      additional_missing_object_message: Message to print when a symbol is
        missing.
      api_version: TensorFlow API version to test. Selects the golden folder
        that is updated when `update_goldens` is True.
    """
    diffs = []
    verbose_diffs = []

    expected_keys = set(expected_dict.keys())
    actual_keys = set(actual_dict.keys())
    only_in_expected = expected_keys - actual_keys
    only_in_actual = actual_keys - expected_keys
    all_keys = expected_keys | actual_keys

    # This will be populated below.
    updated_keys = []

    # Do not truncate diff
    self.maxDiff = None  # pylint: disable=invalid-name

    # Sets have no stable order; sort so "Issue N" lines are reproducible.
    for key in sorted(all_keys):
      diff_message = ''
      verbose_diff_message = ''
      # First check if the key is not found in one or the other.
      if key in only_in_expected:
        diff_message = 'Object %s expected but not found (removed). %s' % (
            key, additional_missing_object_message)
        verbose_diff_message = diff_message
      elif key in only_in_actual:
        diff_message = 'New object %s found (added).' % key
        verbose_diff_message = diff_message
      else:
        # Now we can run an actual proto diff.
        try:
          self.assertProtoEquals(expected_dict[key], actual_dict[key])
        except AssertionError as e:
          updated_keys.append(key)
          diff_message = 'Change detected in python object: %s.' % key
          verbose_diff_message = str(e)

      # All difference cases covered above. If any difference found, add to the
      # list.
      if diff_message:
        diffs.append(diff_message)
        verbose_diffs.append(verbose_diff_message)

    # If diffs are found, handle them based on flags.
    if diffs:
      diff_count = len(diffs)
      logging.error(self._test_readme_message)
      logging.error('%d differences found between API and golden.', diff_count)
      # Both lists always have the same length; they are appended in lockstep.
      messages = verbose_diffs if verbose else diffs
      for issue_number, issue_message in enumerate(messages, 1):
        print('Issue %d\t: %s' % (issue_number, issue_message), file=sys.stderr)

      if update_goldens:
        # Write files if requested.
        logging.warning(self._update_golden_warning)

        # If the keys are only in expected, some objects are deleted.
        # Remove files.
        for key in only_in_expected:
          filepath = _KeyToFilePath(key, api_version)
          file_io.delete_file(filepath)

        # If the files are only in actual (current library), these are new
        # modules. Write them to files. Also record all updates in files.
        for key in only_in_actual | set(updated_keys):
          filepath = _KeyToFilePath(key, api_version)
          file_io.write_string_to_file(
              filepath, text_format.MessageToString(actual_dict[key]))
      else:
        # Fail if we cannot fix the test by updating goldens.
        self.fail('%d differences found between API and golden.' % diff_count)

    else:
      logging.info('No differences found between API and golden.')

  def testNoSubclassOfMessage(self):
    """Traverses `tf` and fails on indirect proto Message subclasses."""
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    # Skip compat.v1 and compat.v2 since they are validated in separate tests.
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)

  def testNoSubclassOfMessageV1(self):
    """Same check for `tf.compat.v1`; a no-op when that module is absent."""
    if not hasattr(tf.compat, 'v1'):
      return
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    if FLAGS.only_test_core_api:
      visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    traverse.traverse(tf.compat.v1, visitor)

  def testNoSubclassOfMessageV2(self):
    """Same check for `tf.compat.v2`; a no-op when that module is absent."""
    if not hasattr(tf.compat, 'v2'):
      return
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    if FLAGS.only_test_core_api:
      visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf.compat.v2, visitor)

  def _checkBackwardsCompatibility(self,
                                   root,
                                   golden_file_pattern,
                                   api_version,
                                   additional_private_map=None,
                                   omit_golden_symbols_map=None):
    """Diffs the API reachable from `root` against the matching golden files.

    Args:
      root: Module object whose public API is traversed.
      golden_file_pattern: File pattern that matches the golden files to
        compare against.
      api_version: TensorFlow API version under test. For version 2,
        `enable_v2_behavior` is additionally treated as private.
      additional_private_map: Optional dict merged into the visitor's
        `private_map`. An entry for an existing key replaces that key's value.
      omit_golden_symbols_map: Optional dict mapping API object keys to member
        names that are removed from the golden protos before diffing.
    """
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.private_map['tf'] = ['contrib']
    if api_version == 2:
      public_api_visitor.private_map['tf'].append('enable_v2_behavior')

    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    if FLAGS.only_test_core_api:
      public_api_visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    if additional_private_map:
      public_api_visitor.private_map.update(additional_private_map)

    traverse.traverse(root, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Read all golden files.
    golden_file_list = file_io.get_matching_files(golden_file_pattern)
    if FLAGS.only_test_core_api:
      golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }
    golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict,
                                               omit_golden_symbols_map)

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version)

  @test_util.run_v1_only('b/120545219')
  def testAPIBackwardsCompatibility(self):
    """Checks the top-level `tf` API against the goldens for its version."""
    # The imported package name tells which API version `tf` resolves to.
    api_version = 2 if '_api.v2' in tf.__name__ else 1
    golden_file_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*', api_version))
    self._checkBackwardsCompatibility(
        tf,
        golden_file_pattern,
        api_version,
        # Skip compat.v1 and compat.v2 since they are validated
        # in separate tests.
        additional_private_map={'tf.compat': ['v1', 'v2']})

    # Also check that V1 API has contrib
    self.assertIn(
        'tensorflow.python.util.lazy_loader.LazyLoader',
        str(type(tf.contrib)))

  @test_util.run_v1_only('b/120545219')
  def testAPIBackwardsCompatibilityV1(self):
    """Checks `tf.compat.v1` against the V1 golden files."""
    api_version = 1
    golden_file_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*', api_version))
    self._checkBackwardsCompatibility(tf.compat.v1, golden_file_pattern,
                                      api_version)

  def testAPIBackwardsCompatibilityV2(self):
    """Checks `tf.compat.v2` against the V2 golden files."""
    api_version = 2
    golden_file_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*', api_version))
    omit_golden_symbols_map = {}
    if FLAGS.only_test_core_api:
      # In TF 2.0 these summary symbols are imported from TensorBoard.
      omit_golden_symbols_map['tensorflow.summary'] = [
          'audio', 'histogram', 'image', 'scalar', 'text']
    self._checkBackwardsCompatibility(
        tf.compat.v2,
        golden_file_pattern,
        api_version,
        additional_private_map={'tf.compat': ['v1', 'v2']},
        omit_golden_symbols_map=omit_golden_symbols_map)
395
396
def _ParseBoolFlag(value):
  """Converts a command-line flag value into a real boolean.

  `argparse` with `type=bool` calls `bool()` on the raw string, which is True
  for every non-empty string, so `--flag=False` would silently enable the
  flag. This parser interprets the text instead.

  Args:
    value: The raw flag value. Bool values are returned unchanged.

  Returns:
    True for `true`, `t`, `yes`, `y` or `1`; False for `false`, `f`, `no`,
    `n`, `0` or the empty string. Matching ignores case and surrounding
    whitespace.

  Raises:
    argparse.ArgumentTypeError: If the value is not a recognized boolean.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in ('true', 't', 'yes', 'y', '1'):
    return True
  # The empty string stays False, matching the previous bool('') behavior.
  if normalized in ('false', 'f', 'no', 'n', '0', ''):
    return False
  raise argparse.ArgumentTypeError(
      'Expected a boolean value (true/false), got %r.' % value)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--update_goldens',
      type=_ParseBoolFlag,
      default=False,
      help=_UPDATE_GOLDENS_HELP)
  # TODO(mikecase): Create Estimator's own API compatibility test or
  # a more general API compatibility test for use for TF components.
  parser.add_argument(
      '--only_test_core_api',
      type=_ParseBoolFlag,
      default=True,  # only_test_core_api default value
      help=_ONLY_TEST_CORE_API_HELP)
  parser.add_argument(
      '--verbose_diffs',
      type=_ParseBoolFlag,
      default=True,
      help=_VERBOSE_DIFFS_HELP)
  FLAGS, unparsed = parser.parse_known_args()

  # Now update argv, so that unittest library does not get confused.
  sys.argv = [sys.argv[0]] + unparsed
  test.main()
415