# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Definition of XLA test case."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import os
import random
import re

import numpy as np

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.compiler.xla import jit
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging

FLAGS = flags.FLAGS

flags.DEFINE_string('test_device', None,
                    'Tensorflow device on which to place operators under test')
flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.')
flags.DEFINE_string('disabled_manifest', None,
                    'Path to a file with a list of tests that should not run.')
flags.DEFINE_string('tf_xla_flags', None,
                    'Value to set the TF_XLA_FLAGS environment variable to')


def parse_disabled_manifest(manifest_content):
  """Parses the text of a disabled-tests manifest.

  Everything from a `#` to the end of a line is treated as a comment. After
  comments are stripped, each non-empty line is one entry made of either:
    1) a single token, which is a test-name regex, or
    2) two whitespace-separated tokens: a test name followed by a
       comma-separated list of `DataType` enum names (e.g. `DT_BFLOAT16`).

  Args:
    manifest_content: String holding the full contents of the manifest.

  Returns:
    A `(disabled_regex, method_types_filter)` tuple. `disabled_regex` is the
    single-token entries joined with `|` (an empty string when there are
    none). `method_types_filter` is a dict mapping the test name of each
    two-token entry to the set of numpy types listed for it.

  Raises:
    ValueError: If an entry has more than two tokens.
  """
  comments_re = re.compile(r'#.*$')
  disabled_tests = []
  disabled_method_types = []
  for line in manifest_content.splitlines():
    stripped = comments_re.sub('', line).strip()
    if not stripped:
      continue
    # Split on any run of whitespace so that entries aligned with several
    # spaces or tabs are not mistaken for entries with too many fields.
    entry = stripped.split()
    if len(entry) == 1:
      disabled_tests.append(entry[0])
    elif len(entry) == 2:
      disabled_method_types.append((entry[0], entry[1].split(',')))
    else:
      raise ValueError('Bad entry in manifest file.')

  disabled_regex = '|'.join(disabled_tests)
  method_types_filter = {}
  for method, types in disabled_method_types:
    method_types_filter[method] = {
        dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
        for name in types
    }
  return disabled_regex, method_types_filter


class XLATestCase(test.TestCase):
  """XLA test cases are parameterized test cases.

  The device under test comes from `--test_device` and the set of types to
  exercise from `--types`. An optional `--disabled_manifest` file can disable
  whole tests or remove individual types from specific test methods; the
  type-set properties below apply that per-method filter.
  """

  def __init__(self, method_name='runTest'):
    """Reads the test flags and builds the per-category type sets.

    Args:
      method_name: Name of the test method to run, forwarded to the base class.
    """
    super(XLATestCase, self).__init__(method_name)
    if 'XLA' in FLAGS.test_device:
      context.context().enable_xla_devices()

    # Check if the mlir bridge has been explicitly enabled or disabled. If
    # is_mlir_bridge_enabled() returns None, the user did not explicitly enable
    # or disable the bridge so do not update enable_mlir_bridge.
    mlir_bridge_enabled = test_util.is_mlir_bridge_enabled()
    if mlir_bridge_enabled is not None:
      context.context().enable_mlir_bridge = bool(mlir_bridge_enabled)

    self.device = FLAGS.test_device
    self.has_custom_call = (self.device == 'XLA_CPU')
    self._all_tf_types = {
        dtypes.as_dtype(types_pb2.DataType.Value(name))
        for name in FLAGS.types.split(',')
    }
    self.int_tf_types = {
        dtype for dtype in self._all_tf_types if dtype.is_integer
    }
    self._float_tf_types = {
        dtype for dtype in self._all_tf_types if dtype.is_floating
    }
    self.complex_tf_types = {
        dtype for dtype in self._all_tf_types if dtype.is_complex
    }
    self._numeric_tf_types = (
        self.int_tf_types | self._float_tf_types | self.complex_tf_types)
    self.quantized_tf_types = {
        dtype for dtype in self._all_tf_types if dtype.is_quantized
    }

    # Quantized types don't have a numpy equivalent, include them in
    # all_tf_types but not in all_types.
    # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
    # and remove all_types.
    self._all_types = {
        dtype.as_numpy_dtype
        for dtype in self._all_tf_types
        if not dtype.is_quantized
    }
    self._int_types = {dtype.as_numpy_dtype for dtype in self.int_tf_types}
    self.signed_int_types = {
        dtype.as_numpy_dtype
        for dtype in self.int_tf_types
        if not dtype.is_unsigned
    }
    self.unsigned_int_types = {
        dtype.as_numpy_dtype
        for dtype in self.int_tf_types
        if dtype.is_unsigned
    }
    self._float_types = {
        dtype.as_numpy_dtype for dtype in self._float_tf_types
    }
    self.complex_types = {
        dtype.as_numpy_dtype for dtype in self.complex_tf_types
    }
    self._numeric_types = (
        self._int_types | self._float_types | self.complex_types)

    # Parse the manifest file, if any, into a regex identifying tests to
    # disable
    # TODO(xpan): Make it text proto if it doesn't scale.
    # Each line of the manifest file specifies an entry. The entry can be
    # 1) TestNameRegex  // E.g. CumprodTest.* Or
    # 2) TestName TypeName  // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16
    # The 1) disables the entire test. While 2) only filter some numeric types
    # so that they are not used in those tests.
    self.disabled_regex = None
    self._method_types_filter = {}

    if FLAGS.disabled_manifest is not None:
      with open(FLAGS.disabled_manifest, 'r') as manifest_file:
        disabled_regex, self._method_types_filter = (
            parse_disabled_manifest(manifest_file.read()))
      if disabled_regex:
        self.disabled_regex = re.compile(disabled_regex)

    if FLAGS.tf_xla_flags is not None:
      os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags

  def _qualified_test_name(self):
    """Returns `ClassName.methodName`, the key format used by the manifest."""
    return '{}.{}'.format(type(self).__name__, self._testMethodName)

  def _disabled_numpy_types(self):
    """Returns the numpy types the manifest disables for the current test."""
    return self._method_types_filter.get(self._qualified_test_name(), set())

  def _disabled_tf_types(self):
    """Returns the manifest-disabled types for the current test as TF dtypes."""
    return {dtypes.as_dtype(t) for t in self._disabled_numpy_types()}

  @property
  def all_tf_types(self):
    """All TF dtypes under test, minus those disabled for this test."""
    return self._all_tf_types - self._disabled_tf_types()

  @property
  def float_types(self):
    """Floating-point numpy types, minus those disabled for this test."""
    return self._float_types - self._disabled_numpy_types()

  @property
  def float_tf_types(self):
    """Floating-point TF dtypes, minus those disabled for this test."""
    # The manifest filter holds numpy types, so it must be converted to TF
    # dtypes before subtracting from a set of TF dtypes (as all_tf_types and
    # numeric_tf_types do); subtracting the numpy types directly would leave
    # the set unfiltered.
    return self._float_tf_types - self._disabled_tf_types()

  @property
  def int_types(self):
    """Integer numpy types, minus those disabled for this test."""
    return self._int_types - self._disabled_numpy_types()

  @property
  def numeric_tf_types(self):
    """Int, float and complex TF dtypes, minus those disabled for this test."""
    return self._numeric_tf_types - self._disabled_tf_types()

  @property
  def numeric_types(self):
    """Int, float and complex numpy types, minus those disabled for this test."""
    return self._numeric_types - self._disabled_numpy_types()

  @property
  def all_types(self):
    """All non-quantized numpy types, minus those disabled for this test."""
    return self._all_types - self._disabled_numpy_types()

  def setUp(self):
    """Skips tests disabled by the manifest and seeds the Python/numpy RNGs."""
    super(XLATestCase, self).setUp()
    name = self._qualified_test_name()
    if self.disabled_regex is not None and self.disabled_regex.match(name):
      logging.info('Disabled test case: %s', name)
      # skipTest() raises, so nothing after this call runs for a skipped test.
      self.skipTest('{} is disabled by manifest.'.format(name))
    logging.info('Start test case: %s', name)

    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)

  def tearDown(self):
    """Runs the base class teardown and logs the end of the test."""
    super(XLATestCase, self).tearDown()
    logging.info('End test case: %s', self._testMethodName)

  @contextlib.contextmanager
  def session(self):
    """Custom implementation of session() for XLA tests.

    We override the standard Tensorflow session() since it is too
    specific to CPU and GPU tests. In particular, we want to disable soft
    placement and explicitly assign ops to devices under test.

    Yields:
      A session to use when running a test case.
    """
    graph = ops.Graph()
    config = context.context().config

    # Grappler can constant fold TensorListFromTensor ops into DT_VARIANT
    # constants which XLA does not understand.  So disable constant folding in
    # these tests.
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)
    with session.Session(
        graph=graph, config=config) as sess, graph.as_default():
      yield sess

  def cached_session(self):
    """Unsupported on XLATestCase; always raises NotImplementedError."""
    raise NotImplementedError(
        'cached_session not supported on XLATestCase, please use session')

  def test_session(self):
    """Unsupported on XLATestCase; always raises NotImplementedError."""
    raise NotImplementedError(
        'test_session not supported on XLATestCase, please use session')

  @contextlib.contextmanager
  def device_scope(self):
    """Scope that runs tests on `self.device`.

    Yields:
      A scope to apply to the operators under test.
    """
    with ops.device('device:{}:0'.format(self.device)):
      yield

  def test_scope(self):
    """Deprecated alias of `device_scope`.

    This should be avoided as the name starts with `test`, so test runners
    treat it as a test. This interferes with class decorators that operate on
    each test method.
    """
    return self.device_scope()


def Benchmark(tf_bench,
              builder_fn,
              use_xla_jit,
              device,
              separate_compiled_gradients=False):
  """Build a graph and run benchmarks against it, with or without XLA.

  Args:
    tf_bench: An instance of tf.test.Benchmark, used to run the benchmark.
    builder_fn: A function that builds a graph when invoked, and returns
        (name, fetches), where name is the name of the test, and fetches
        is a list of tensors to fetch as output.
    use_xla_jit: If true compile with the XLA JIT, otherwise use regular TF.
    device: The tensorflow device to run on, e.g. "cpu", "gpu".
    separate_compiled_gradients: If true put each gradient subgraph into a
      separate compilation scope. This gives fine-grained control over which
      portions of the graph will be compiled as a single unit. Compiling
      gradients separately may yield better performance for some graphs.
      The scope is named based on the scope of the forward computation as well
      as the name of the gradients. As a result, the gradients will be compiled
      in a scope that is separate from both the forward computation, and from
      other gradients.
  """

  with ops.Graph().as_default():
    targets = []
    with ops.device(device):
      jit_scope = jit.experimental_jit_scope
      with jit_scope(
          compile_ops=use_xla_jit,
          separate_compiled_gradients=separate_compiled_gradients):
        name, fetches = builder_fn()

      # We only want to benchmark the operations themselves, and not the data
      # transfer of the result(s).  Non-compiled identity ops ensure XLA
      # doesn't know we're dropping the results, otherwise it might compile
      # away the entire computation.
      for fetch in fetches:
        targets.append(array_ops.identity(fetch).op)

    # TODO(b/132430685):  Should we allow soft placement here?
    config = config_pb2.ConfigProto(allow_soft_placement=True)
    with session.Session(config=config) as sess:
      sess.run(variables.global_variables_initializer())
      xla = 'xla_' if use_xla_jit else ''
      tf_bench.run_op_benchmark(
          sess, targets, name='%s_%s%s' % (name, xla, device))