1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Wrapper for Python TPU tests.
16
17The py_tpu_test macro will actually use this file as its main, building and
18executing the user-provided test file as a py_binary instead. This lets us do
19important work behind the scenes, without complicating the tests themselves.
20
21The main responsibilities of this file are:
22  - Define standard set of model flags if test did not. This allows us to
23    safely set flags at the Bazel invocation level using --test_arg.
24  - Pick a random directory on GCS to use for each test case, and set it as the
25    default value of --model_dir. This is similar to how Bazel provides each
26    test with a fresh local directory in $TEST_TMPDIR.
27"""
28
29from __future__ import absolute_import
30from __future__ import division
31from __future__ import print_function
32
33import ast
34import importlib
35import os
36import sys
37import uuid
38
39from tensorflow.python.platform import flags
40from tensorflow.python.util import tf_inspect
41
FLAGS = flags.FLAGS

# Flags owned by the wrapper itself; these are always defined here.
flags.DEFINE_string(
    'wrapped_tpu_test_module_relative',
    default=None,
    help='The Python-style relative path to the user-given test. If test is in same '
    'directory as BUILD file as is common, then "test.py" would be ".test".')
flags.DEFINE_string(
    'test_dir_base',
    default=os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'),
    help='GCS path to root directory for temporary test files.')
flags.DEFINE_string(
    'bazel_repo_root',
    default='tensorflow/python',
    help='Substring of a bazel filepath beginning the python absolute import path.')

# Names of the flags every TPU test should accept. maybe_define_flags() defines
# any of these that the wrapped test did not define itself.
REQUIRED_FLAGS = [
    'tpu',
    'zone',
    'project',
    'model_dir',
]
56
57
def maybe_define_flags():
  """Defines every flag named in REQUIRED_FLAGS that does not exist yet.

  A flag that is already registered (for example by the wrapped test module)
  is left alone: the DuplicateFlagError raised on redefinition is ignored.
  """
  for flag_name in REQUIRED_FLAGS:
    try:
      flags.DEFINE_string(flag_name, None, 'flag defined by test lib')
    except flags.DuplicateFlagError:
      # Already defined elsewhere; keep the existing definition.
      continue
65
66
def set_random_test_dir():
  """Pick a random directory under --test_dir_base, set as --model_dir default.

  A fresh uuid4 hex name is joined onto --test_dir_base and installed as the
  *default* of --model_dir via FLAGS.set_default.

  Raises:
    ValueError: if --test_dir_base is None, i.e. it was neither passed on the
      command line nor defaulted from the TEST_UNDECLARED_OUTPUTS_DIR
      environment variable.
  """
  base_dir = FLAGS.test_dir_base
  if base_dir is None:
    # Without this check os.path.join(None, ...) fails with an opaque
    # TypeError that does not mention which flag is missing.
    raise ValueError(
        '--test_dir_base is not set and TEST_UNDECLARED_OUTPUTS_DIR is not '
        'present in the environment; cannot pick a random --model_dir.')
  path = os.path.join(base_dir, uuid.uuid4().hex)
  FLAGS.set_default('model_dir', path)
71
72
def calculate_parent_python_path(test_filepath, repo_root=None):
  """Returns the absolute import path for the containing directory.

  Args:
    test_filepath: The filepath which Bazel invoked
      (ex: /filesystem/path/tensorflow/tensorflow/python/tpu/tpu_test)
    repo_root: '/'-separated substring of `test_filepath` at which the Python
      absolute import path begins (ex: tensorflow/python). Defaults to the
      value of --bazel_repo_root.

  Returns:
    Absolute import path of parent (ex: tensorflow.python.tpu).

  Raises:
    ValueError: if the repo root does not appear within test_filepath.
  """
  if repo_root is None:
    repo_root = FLAGS.bazel_repo_root

  # The repo root and the later splitting use '/', so convert platform
  # separators first (a no-op where os.sep is already '/').
  normalized_path = test_filepath.replace(os.sep, '/')
  normalized_root = repo_root.replace(os.sep, '/')

  # Find the last occurrence of the repo root and drop everything before it.
  _, found_root, tail = normalized_path.rpartition(normalized_root)
  if not found_root:
    raise ValueError('Filepath "%s" does not contain repo root "%s"' %
                     (test_filepath, repo_root))
  path = normalized_root + tail

  # Drop the last component, which is the name of the test wrapper.
  path = path.rsplit('/', 1)[0]

  # Convert the directory separators into dots.
  return path.replace('/', '.')
98
99
def import_user_module():
  """Imports the flag-specified user test code.

  This runs all top-level statements in the user module, specifically flag
  definitions.

  Returns:
    The user test module.

  Raises:
    ValueError: if --wrapped_tpu_test_module_relative is unset or empty, or if
      the invoked binary path does not contain --bazel_repo_root.
  """
  relative_name = FLAGS.wrapped_tpu_test_module_relative
  if not relative_name:
    # importlib would otherwise fail with an unhelpful AttributeError on None.
    raise ValueError('--wrapped_tpu_test_module_relative must be set to the '
                     'relative module name of the test (e.g. ".test").')
  # sys.argv[0] is the path of the binary Bazel invoked; its directory yields
  # the package against which the relative module name is resolved.
  return importlib.import_module(relative_name,
                                 calculate_parent_python_path(sys.argv[0]))
111
112
def _is_test_class(obj):
  """Reports whether `obj` is a test class (not a test instance!).

  Classes are matched by the *name* "TestCase" anywhere in their MRO rather
  than by identity, because tests are written against different underlying
  test libraries, each with its own TestCase base.

  Args:
    obj: An arbitrary object from within a module.

  Returns:
    True iff obj is a class that is, or inherits at some point from, a class
    named "TestCase".
  """
  if not tf_inspect.isclass(obj):
    return False
  return any(base.__name__ == 'TestCase' for base in tf_inspect.getmro(obj))
126
127
# This module's own namespace; writing into it adds module-level names.
module_variables = globals()
129
130
def move_test_classes_into_scope(wrapped_test_module):
  """Re-exports the wrapped module's test classes from this module.

  The test runner finds tests by inspecting the main module for TestCase
  classes. Binding each of the wrapped module's test classes to a name in this
  module therefore makes the runner execute the wrapped TestCases. Each class
  is bound as `tpu_test_imported_<original name>`.

  Args:
    wrapped_test_module: The user-provided test code to run.
  """
  for attr_name, attr_value in vars(wrapped_test_module).items():
    if not _is_test_class(attr_value):
      continue
    exported_name = 'tpu_test_imported_%s' % attr_name
    module_variables[exported_name] = attr_value
144
145
def run_user_main(wrapped_test_module):
  """Runs the "if __name__ == '__main__'" block of a module.

  TensorFlow practice is to end a test module with such a block, which might
  call an API compat function before calling test.main().

  Since the block is a statement, not a function, we can't reference it
  directly. Instead we parse the user module's source, locate the last
  top-level block with that exact condition, and execute its body with the
  user module's namespace as locals so its imports and variables are available.

  Args:
    wrapped_test_module: The user-provided test code to run.

  Raises:
    NotImplementedError: If the main block was not found in the module. This
      should not be caught, as it is likely an error on the user's part --
      absltest is all too happy to report a successful status (and zero tests
      executed) if a user forgets to end a module with "test.main()".
  """
  tree = ast.parse(tf_inspect.getsource(wrapped_test_module))

  # String representation of just the condition `__name__ == "__main__"`.
  # ast.dump omits position attributes by default, so comparing dumps matches
  # the condition wherever it appears in the file.
  target = ast.dump(ast.parse('if __name__ == "__main__": pass').body[0].test)

  # `tree.body` is the list of top-level statements in the module, like imports
  # and class definitions. Search for the main block, starting from the end.
  main_block = next(
      (stmt for stmt in reversed(tree.body)
       if isinstance(stmt, ast.If) and ast.dump(stmt.test) == target),
      None)
  if main_block is None:
    raise NotImplementedError(
        'Could not find `if __name__ == "__main__":` block in %s.' %
        wrapped_test_module.__name__)

  new_ast = ast.Module(body=main_block.body, type_ignores=[])
  # The nodes keep the line numbers from parsing the module source, so
  # compiling under the module's file name makes tracebacks from the main
  # block point at the user's source instead of an anonymous '<ast>'.
  filename = getattr(wrapped_test_module, '__file__', None) or '<ast>'
  # Globals are this wrapper module's namespace; names are looked up in the
  # user module's namespace (the locals) first, and assignments land there.
  exec(  # pylint:disable=exec-used
      compile(new_ast, filename, 'exec'),
      globals(),
      wrapped_test_module.__dict__,
  )
187
188
if __name__ == '__main__':
  # Partially parse flags, since module to import is specified by flag.
  # known_only=True parses the flags already defined and returns the rest
  # (e.g. flags the user module has yet to define) untouched in `unparsed`.
  unparsed = FLAGS(sys.argv, known_only=True)
  # Importing the user module runs its top-level flag definitions.
  user_module = import_user_module()
  # Runs after the import so that flags the user module already defined are
  # kept (maybe_define_flags ignores DuplicateFlagError).
  maybe_define_flags()
  # Parse remaining flags.
  FLAGS(unparsed)
  set_random_test_dir()

  # Expose the user's TestCase classes in this (main) module, then run the
  # user's own main block, which is expected to call test.main().
  # NOTE(review): run_user_main execs that block with this module's globals,
  # so module-level names defined here are visible to it; keep names stable.
  move_test_classes_into_scope(user_module)
  run_user_main(user_module)
200