# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Definition of XLA test case."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import os
23import random
24import re
25
26import numpy as np
27
28from tensorflow.core.framework import types_pb2
29from tensorflow.core.protobuf import config_pb2
30from tensorflow.core.protobuf import rewriter_config_pb2
31from tensorflow.python.client import session
32from tensorflow.python.compiler.xla import jit
33from tensorflow.python.eager import context
34from tensorflow.python.framework import dtypes
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import random_seed
37from tensorflow.python.framework import test_util
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import variables
40from tensorflow.python.platform import flags
41from tensorflow.python.platform import test
42from tensorflow.python.platform import tf_logging as logging
43
44FLAGS = flags.FLAGS

flags.DEFINE_string('test_device', None,
                    'TensorFlow device on which to place operators under test.')
flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.')
flags.DEFINE_string('disabled_manifest', None,
                    'Path to a file with a list of tests that should not run.')
flags.DEFINE_string('tf_xla_flags', None,
                    'Value to set the TF_XLA_FLAGS environment variable to.')

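
# Note (illustrative, not prescriptive): these flags are normally passed by
# the test harness rather than set by hand. A hypothetical invocation of a
# test built on this module might look like
#
#   some_xla_op_test --test_device=XLA_GPU \
#       --types=DT_FLOAT,DT_INT32,DT_COMPLEX64 \
#       --disabled_manifest=/path/to/disabled_manifest.txt
#
# where the binary name, device string, and type list depend on the build
# target and platform.
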

def parse_disabled_manifest(manifest_content):
  """Parses the content of a disabled-test manifest.

  Each non-empty, non-comment line is either a test name regex, which
  disables the matching tests entirely, or a test name followed by a
  comma-separated list of DataType names, which disables only those types
  for that test. A '#' starts a comment.

  Args:
    manifest_content: String contents of the manifest file.

  Returns:
    A pair (disabled_regex, method_types_filter): a '|'-joined regex of
    fully disabled tests, and a dict mapping a test name to the set of
    numpy dtypes disabled for it.

  Raises:
    ValueError: If an entry has more than two space-separated fields.
  """
  comments_re = re.compile('#.*$')
  disabled_tests = []
  disabled_method_types = []
  for l in manifest_content.splitlines():
    stripped = comments_re.sub('', l).strip()
    if not stripped:
      continue
    entry = stripped.split(' ')
    if len(entry) == 1:
      disabled_tests.append(entry[0])
    elif len(entry) == 2:
      disabled_method_types.append((entry[0], entry[1].strip().split(',')))
    else:
      raise ValueError('Bad entry in manifest file.')

  disabled_regex = '|'.join(disabled_tests)
  method_types_filter = {}
  for method, types in disabled_method_types:
    method_types_filter[method] = set([
        dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
        for name in types
    ])
  return disabled_regex, method_types_filter
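
# A minimal, illustrative manifest for parse_disabled_manifest (the entries
# mirror the examples in the XLATestCase.__init__ comment below):
#
#   # Disable every test in CumprodTest.
#   CumprodTest.*
#   # Run AdamOptimizerTest.testSharing, but never with bfloat16.
#   AdamOptimizerTest.testSharing DT_BFLOAT16
#
# Parsing this yields disabled_regex == 'CumprodTest.*' and a
# method_types_filter mapping 'AdamOptimizerTest.testSharing' to the set
# containing the numpy bfloat16 dtype.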


class XLATestCase(test.TestCase):
  """XLA test cases are parameterized test cases."""

  def __init__(self, method_name='runTest'):
    super(XLATestCase, self).__init__(method_name)
    if 'XLA' in FLAGS.test_device:
      context.context().enable_xla_devices()

    # Check if the mlir bridge has been explicitly enabled or disabled. If
    # is_mlir_bridge_enabled() returns None, the user did not explicitly enable
    # or disable the bridge, so do not update enable_mlir_bridge.
    if test_util.is_mlir_bridge_enabled():
      context.context().enable_mlir_bridge = True
    elif test_util.is_mlir_bridge_enabled() is not None:
      context.context().enable_mlir_bridge = False

    self.device = FLAGS.test_device
    self.has_custom_call = (self.device == 'XLA_CPU')
    self._all_tf_types = set([
        dtypes.as_dtype(types_pb2.DataType.Value(name))
        for name in FLAGS.types.split(',')
    ])
    self.int_tf_types = set([
        dtype for dtype in self._all_tf_types if dtype.is_integer
    ])
    self._float_tf_types = set([
        dtype for dtype in self._all_tf_types if dtype.is_floating
    ])
    self.complex_tf_types = set([
        dtype for dtype in self._all_tf_types if dtype.is_complex
    ])
    self._numeric_tf_types = set(
        self.int_tf_types | self._float_tf_types | self.complex_tf_types)
    self.quantized_tf_types = set(
        dtype for dtype in self._all_tf_types if dtype.is_quantized)

    # Quantized types don't have a numpy equivalent, so include them in
    # all_tf_types but not in all_types.
    # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
    # and remove all_types.
    self._all_types = set(dtype.as_numpy_dtype
                          for dtype in self._all_tf_types
                          if not dtype.is_quantized)
    self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
    self.signed_int_types = set(dtype.as_numpy_dtype
                                for dtype in self.int_tf_types
                                if not dtype.is_unsigned)
    self.unsigned_int_types = set(dtype.as_numpy_dtype
                                  for dtype in self.int_tf_types
                                  if dtype.is_unsigned)
    self._float_types = set(
        [dtype.as_numpy_dtype for dtype in self._float_tf_types])
    self.complex_types = set([
        dtype.as_numpy_dtype for dtype in self.complex_tf_types
    ])
    self._numeric_types = set(self._int_types | self._float_types
                              | self.complex_types)

    # Parse the manifest file, if any, into a regex identifying tests to
    # disable.
    # TODO(xpan): Make it text proto if it doesn't scale.
    # Each line of the manifest file specifies an entry. The entry can be
    # 1) TestNameRegex  // E.g. CumprodTest.*, or
    # 2) TestName TypeName  // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16
    # Form 1) disables the entire test, while form 2) only filters out the
    # listed types so that they are not used in those tests.
    self.disabled_regex = None
    self._method_types_filter = {}

    if FLAGS.disabled_manifest is not None:
      with open(FLAGS.disabled_manifest, 'r') as manifest_file:
        disabled_regex, self._method_types_filter = (
            parse_disabled_manifest(manifest_file.read()))
        if disabled_regex:
          self.disabled_regex = re.compile(disabled_regex)

    if FLAGS.tf_xla_flags is not None:
      os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags

  @property
  def all_tf_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    tf_types = set([dtypes.as_dtype(t)
                    for t in self._method_types_filter.get(name, set())])
    return self._all_tf_types - tf_types

  @property
  def float_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._float_types - self._method_types_filter.get(name, set())

  @property
  def float_tf_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._float_tf_types - self._method_types_filter.get(name, set())

  @property
  def int_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._int_types - self._method_types_filter.get(name, set())

  @property
  def numeric_tf_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    tf_types = set([dtypes.as_dtype(t)
                    for t in self._method_types_filter.get(name, set())])
    return self._numeric_tf_types - tf_types

  @property
  def numeric_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._numeric_types - self._method_types_filter.get(name, set())

  @property
  def all_types(self):
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    return self._all_types - self._method_types_filter.get(name, set())
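
  # Illustrative note: with a (hypothetical) manifest entry such as
  #   FooTest.testBar DT_BFLOAT16
  # the properties above, when read from inside FooTest.testBar, return their
  # usual sets minus bfloat16, so loops like `for dtype in self.float_types:`
  # silently skip the disabled type for that one test.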

  def setUp(self):
    super(XLATestCase, self).setUp()
    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
    if self.disabled_regex is not None and self.disabled_regex.match(name):
      logging.info('Disabled test case: %s', name)
      self.skipTest('{} is disabled by manifest.'.format(name))
      return
    logging.info('Start test case: %s', name)

    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)

  def tearDown(self):
    super(XLATestCase, self).tearDown()
    logging.info('End test case: %s', self._testMethodName)

  @contextlib.contextmanager
  def session(self):
    """Custom implementation of session() for XLA tests.

    We override the standard TensorFlow session() since it is too
    specific to CPU and GPU tests. In particular, we want to disable soft
    placement and explicitly assign ops to devices under test.

    Yields:
      A session to use when running a test case.
    """
    graph = ops.Graph()
    config = context.context().config

    # Grappler can constant-fold TensorListFromTensor ops into DT_VARIANT
    # constants, which XLA does not understand, so disable constant folding in
    # these tests.
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)
    with session.Session(
        graph=graph, config=config) as sess, graph.as_default():
      yield sess

  def cached_session(self):
    raise NotImplementedError(
        'cached_session not supported on XLATestCase, please use session')

  def test_session(self):
    raise NotImplementedError(
        'test_session not supported on XLATestCase, please use session')

  @contextlib.contextmanager
  def device_scope(self):
    """Scope that runs tests on `self.device`.

    Yields:
      A scope to apply to the operators under test.
    """
    with ops.device('device:{}:0'.format(self.device)):
      yield

  def test_scope(self):
    """Deprecated alias of `device_scope`.

    This should be avoided as the name starts with `test`, so test runners
    treat it as a test. This interferes with class decorators that operate on
    each test method.
    """
    return self.device_scope()
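
  # Illustrative sketch (hypothetical test class, not part of this module):
  # derived tests typically combine session() and device_scope() so that the
  # ops under test are pinned to self.device, e.g.
  #
  #   class NegOpTest(xla_test.XLATestCase):
  #
  #     def testNeg(self):
  #       for dtype in self.float_types:
  #         with self.session() as sess, self.device_scope():
  #           x = array_ops.placeholder(dtype, shape=[3])
  #           out = sess.run(-x, {x: np.array([1, 2, 3], dtype=dtype)})
  #           self.assertAllEqual(out, np.array([-1, -2, -3], dtype=dtype))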


def Benchmark(tf_bench,
              builder_fn,
              use_xla_jit,
              device,
              separate_compiled_gradients=False):
  """Build a graph and run benchmarks against it, with or without XLA.

  Args:
    tf_bench: An instance of tf.test.Benchmark, used to run the benchmark.
    builder_fn: A function that builds a graph when invoked, and returns
        (name, fetches), where name is the name of the test, and fetches
        is a list of tensors to fetch as output.
    use_xla_jit: If true, compile with the XLA JIT; otherwise use regular TF.
    device: The TensorFlow device to run on, e.g. "cpu", "gpu".
    separate_compiled_gradients: If true, put each gradient subgraph into a
      separate compilation scope. This gives fine-grained control over which
      portions of the graph will be compiled as a single unit. Compiling
      gradients separately may yield better performance for some graphs.
      The scope is named based on the scope of the forward computation as well
      as the name of the gradients. As a result, the gradients will be compiled
      in a scope that is separate from both the forward computation, and from
      other gradients.
  """

  with ops.Graph().as_default():
    name = None
    targets = []
    with ops.device(device):
      fetches = []
      jit_scope = jit.experimental_jit_scope
      with jit_scope(
          compile_ops=use_xla_jit,
          separate_compiled_gradients=separate_compiled_gradients):
        name, fetches = builder_fn()

      # We only want to benchmark the operations themselves, and not the data
      # transfer of the result(s).  Non-compiled identity ops ensure XLA
      # doesn't know we're dropping the results, otherwise it might compile
      # away the entire computation.
      for fetch in fetches:
        targets.append(array_ops.identity(fetch).op)

    # TODO(b/132430685):  Should we allow soft placement here?
    config = config_pb2.ConfigProto(allow_soft_placement=True)
    with session.Session(config=config) as sess:
      sess.run(variables.global_variables_initializer())
      xla = 'xla_' if use_xla_jit else ''
      tf_bench.run_op_benchmark(
          sess, targets, name='%s_%s%s' % (name, xla, device))
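
# Illustrative usage (hypothetical names, not part of this module): a
# benchmark built on Benchmark() usually runs the same builder once with and
# once without the XLA JIT so the reported timings can be compared, e.g.
#
#   def _square_builder():
#     x = variables.Variable(np.ones([1024, 1024], dtype=np.float32))
#     return 'square_1024', [x * x]
#
#   class SquareBenchmark(test.Benchmark):
#
#     def benchmarkSquare(self):
#       Benchmark(self, _square_builder, use_xla_jit=True, device='gpu')
#       Benchmark(self, _square_builder, use_xla_jit=False, device='gpu')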