1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Definition of XLA test case."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import os
23import random
24import re
25
26import numpy as np
27
28from tensorflow.contrib.compiler import jit
29from tensorflow.core.framework import types_pb2
30from tensorflow.core.protobuf import config_pb2
31from tensorflow.python.client import session
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import random_seed
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import variables
37from tensorflow.python.platform import flags
38from tensorflow.python.platform import test
39from tensorflow.python.platform import tf_logging as logging
40
FLAGS = flags.FLAGS

# String-valued command-line flags read by XLATestCase. Each one defaults to
# None; the (name, help) pairs below are registered in order.
for _flag_name, _flag_help in (
    ('test_device',
     'Tensorflow device on which to place operators under test'),
    ('types', 'Types to test. Comma-separated list.'),
    ('disabled_manifest',
     'Path to a file with a list of tests that should not run.'),
    ('tf_xla_flags',
     'Value to set the TF_XLA_FLAGS environment variable to'),
):
  flags.DEFINE_string(_flag_name, None, _flag_help)
# Keep the loop temporaries out of the module namespace.
del _flag_name, _flag_help
50
51
def parse_disabled_manifest(manifest_content):
  """Parses the contents of a disabled-test manifest.

  After `#` comments are stripped, every non-blank line is one of:
    * `TestNameRegex` -- a regular expression (e.g. `CumprodTest.*`) for
      `ClassName.methodName` names of tests to disable entirely; or
    * `TestName TypeName[,TypeName...]` -- a `ClassName.methodName` followed
      by a comma-separated list of `DataType` enum names (e.g. `DT_BFLOAT16`)
      that must not be used by that test.
  Fields may be separated by any run of whitespace. Several lines may name
  the same test; their type lists are merged.

  Args:
    manifest_content: Text content of the manifest file.

  Returns:
    A tuple `(disabled_regex, method_types_filter)`. `disabled_regex` is the
    `|`-joined string of all single-field entries ('' if there are none).
    `method_types_filter` maps each test name to the set of numpy dtypes to
    exclude for that test.

  Raises:
    ValueError: If a line has more than two fields, or (from the protobuf
      enum lookup) names an unknown `DataType`.
  """
  comments_re = re.compile('#.*$')
  disabled_tests = []
  # Test name -> set of DataType enum names. Merged, not overwritten, when a
  # test appears on more than one line.
  disabled_type_names = {}
  for line in manifest_content.splitlines():
    stripped = comments_re.sub('', line).strip()
    if not stripped:
      continue
    # split() with no separator splits on any whitespace run, so tabs and
    # repeated spaces between the two fields are accepted.
    entry = stripped.split()
    if len(entry) == 1:
      disabled_tests.append(entry[0])
    elif len(entry) == 2:
      method, type_list = entry
      names = disabled_type_names.setdefault(method, set())
      # Ignore empty names produced by stray commas (e.g. "DT_FLOAT,").
      names.update(name for name in type_list.split(',') if name)
    else:
      raise ValueError('Bad entry in manifest file: {!r}'.format(line))

  disabled_regex = '|'.join(disabled_tests)
  method_types_filter = {
      method: set(
          dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
          for name in names)
      for method, names in disabled_type_names.items()
  }
  return disabled_regex, method_types_filter
76
77
class XLATestCase(test.TestCase):
  """XLA test cases are parameterized test cases.

  The dtypes to exercise are read from the comma-separated `--types` flag
  (`DataType` enum names such as `DT_FLOAT`). The device under test is read
  from `--test_device`. An optional `--disabled_manifest` file can disable
  whole tests, or withhold individual dtypes from specific tests. The
  dtype-valued properties (`all_types`, `float_types`, `float_tf_types`, ...)
  return the configured dtypes minus those disabled for the running test.
  """

  def __init__(self, method_name='runTest'):
    super(XLATestCase, self).__init__(method_name)
    self.device = FLAGS.test_device
    self.has_custom_call = (self.device == 'XLA_CPU')

    # TensorFlow DTypes requested via --types, grouped by category.
    # NOTE(review): assumes --types is set; with its default of None the
    # split() below raises AttributeError.
    self._all_tf_types = {
        dtypes.as_dtype(types_pb2.DataType.Value(name))
        for name in FLAGS.types.split(',')
    }
    self.int_tf_types = {t for t in self._all_tf_types if t.is_integer}
    self._float_tf_types = {t for t in self._all_tf_types if t.is_floating}
    self.complex_tf_types = {t for t in self._all_tf_types if t.is_complex}
    self._numeric_tf_types = (
        self.int_tf_types | self._float_tf_types | self.complex_tf_types)
    self.quantized_tf_types = {
        t for t in self._all_tf_types if t.is_quantized}

    # Quantized types don't have a numpy equivalent, include them in
    # all_tf_types but not in all_types.
    # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
    # and remove all_types.
    self._all_types = {
        t.as_numpy_dtype for t in self._all_tf_types if not t.is_quantized}
    self._int_types = {t.as_numpy_dtype for t in self.int_tf_types}
    self.signed_int_types = {
        t.as_numpy_dtype for t in self.int_tf_types if not t.is_unsigned}
    self.unsigned_int_types = {
        t.as_numpy_dtype for t in self.int_tf_types if t.is_unsigned}
    self._float_types = {t.as_numpy_dtype for t in self._float_tf_types}
    self.complex_types = {t.as_numpy_dtype for t in self.complex_tf_types}
    self._numeric_types = (
        self._int_types | self._float_types | self.complex_types)

    # Parse the manifest file, if any, into a regex identifying tests to
    # disable.
    # TODO(xpan): Make it text proto if it doesn't scale.
    # Each line of the manifest file specifies an entry. The entry can be
    # 1) TestNameRegex  // E.g. CumprodTest.* Or
    # 2) TestName TypeName[,TypeName...]
    #    // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16
    # 1) disables the entire test, while 2) only removes some dtypes from
    # the dtype sets seen by that test.
    self.disabled_regex = None
    # Test name -> set of numpy dtypes to withhold from that test.
    self._method_types_filter = {}

    if FLAGS.disabled_manifest is not None:
      with open(FLAGS.disabled_manifest, 'r') as manifest_file:
        disabled_regex, self._method_types_filter = (
            parse_disabled_manifest(manifest_file.read()))
      if disabled_regex:
        self.disabled_regex = re.compile(disabled_regex)

    if FLAGS.tf_xla_flags is not None:
      os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags

  def _xla_test_name(self):
    """Returns the running test's name as `ClassName.methodName`."""
    return '{}.{}'.format(type(self).__name__, self._testMethodName)

  def _disabled_np_types_for_test(self):
    """Returns the numpy dtypes the manifest disables for the running test."""
    return self._method_types_filter.get(self._xla_test_name(), set())

  def _disabled_tf_types_for_test(self):
    """Returns the disabled dtypes for the running test as TF DTypes.

    The manifest filter stores numpy dtypes. They must be converted before
    being subtracted from a set of DTypes, because a DType is not found by
    membership lookup in a set of numpy types.
    """
    return {dtypes.as_dtype(t) for t in self._disabled_np_types_for_test()}

  @property
  def all_tf_types(self):
    return self._all_tf_types - self._disabled_tf_types_for_test()

  @property
  def float_types(self):
    return self._float_types - self._disabled_np_types_for_test()

  @property
  def float_tf_types(self):
    # Previously subtracted numpy dtypes directly from a set of DTypes, which
    # never removed anything; use the DType-converted filter instead.
    return self._float_tf_types - self._disabled_tf_types_for_test()

  @property
  def int_types(self):
    return self._int_types - self._disabled_np_types_for_test()

  @property
  def numeric_tf_types(self):
    return self._numeric_tf_types - self._disabled_tf_types_for_test()

  @property
  def numeric_types(self):
    return self._numeric_types - self._disabled_np_types_for_test()

  @property
  def all_types(self):
    return self._all_types - self._disabled_np_types_for_test()

  def setUp(self):
    super(XLATestCase, self).setUp()
    name = self._xla_test_name()
    if self.disabled_regex is not None and self.disabled_regex.match(name):
      logging.info('Disabled test case: %s', name)
      self.skipTest('{} is disabled by manifest.'.format(name))
      return
    logging.info('Start test case: %s', name)

    # Seed Python's and numpy's RNGs so generated test inputs are
    # reproducible across runs.
    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)

  def tearDown(self):
    super(XLATestCase, self).tearDown()
    # Log the same `ClassName.methodName` form that setUp() logs.
    logging.info('End test case: %s', self._xla_test_name())

  @contextlib.contextmanager
  def test_session(self):
    """Custom implementation of test_session() for XLA tests.

    We override the standard Tensorflow test_session() since it is too
    specific to CPU and GPU tests. This version creates a fresh graph and a
    Session with no custom config (soft placement is not requested), so ops
    must be placed explicitly on the device under test.

    Yields:
      A session to use when running a test case.
    """
    graph = ops.Graph()
    with session.Session(graph=graph) as sess, graph.as_default():
      yield sess

  @contextlib.contextmanager
  def test_scope(self):
    """Test scope that runs tests on a Tensorflow/XLA device.

    Places operators created inside the scope on `device:<test_device>:0`,
    where `<test_device>` is the value of the `--test_device` flag.

    Yields:
      A scope to apply to the operators under test.
    """
    with ops.device('device:{}:0'.format(self.device)):
      yield
228
def Benchmark(tf_bench,
              builder_fn,
              use_xla_jit,
              device,
              separate_compiled_gradients=False):
  """Builds a graph with `builder_fn` and benchmarks it, with or without XLA.

  Args:
    tf_bench: An instance of tf.test.Benchmark, used to run the benchmark.
    builder_fn: A zero-argument callable that builds a graph and returns
        (name, fetches): the benchmark's name and a list of tensors to fetch.
    use_xla_jit: If true compile with the XLA JIT, otherwise use regular TF.
    device: The tensorflow device to run on, e.g. "cpu", "gpu".
    separate_compiled_gradients: If true put each gradient subgraph into its
      own compilation scope, named after the forward computation's scope and
      the gradient's name. The gradients are then compiled separately from
      both the forward computation and from each other, which gives
      finer-grained control over what is compiled as one unit and may yield
      better performance for some graphs.
  """
  with ops.Graph().as_default():
    with ops.device(device):
      with jit.experimental_jit_scope(
          compile_ops=use_xla_jit,
          separate_compiled_gradients=separate_compiled_gradients):
        name, fetches = builder_fn()

      # Benchmark the computation itself, not the transfer of its results.
      # The identity ops are created outside the JIT scope, so they are not
      # compiled. This keeps XLA from seeing that the results are dropped and
      # compiling the whole computation away.
      targets = [array_ops.identity(fetch).op for fetch in fetches]

    xla_tag = 'xla_' if use_xla_jit else ''
    soft_placement = config_pb2.ConfigProto(allow_soft_placement=True)
    with session.Session(config=soft_placement) as sess:
      sess.run(variables.global_variables_initializer())
      tf_bench.run_op_benchmark(
          sess, targets, name='%s_%s%s' % (name, xla_tag, device))
277