1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Definition of XLA test case.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22import os 23import random 24import re 25 26import numpy as np 27 28from tensorflow.contrib.compiler import jit 29from tensorflow.core.framework import types_pb2 30from tensorflow.core.protobuf import config_pb2 31from tensorflow.python.client import session 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import random_seed 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import variables 37from tensorflow.python.platform import flags 38from tensorflow.python.platform import test 39from tensorflow.python.platform import tf_logging as logging 40 41FLAGS = flags.FLAGS 42 43flags.DEFINE_string('test_device', None, 44 'Tensorflow device on which to place operators under test') 45flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.') 46flags.DEFINE_string('disabled_manifest', None, 47 'Path to a file with a list of tests that should not run.') 48flags.DEFINE_string('tf_xla_flags', None, 49 'Value to set the TF_XLA_FLAGS environment variable to') 50 51 52def parse_disabled_manifest(manifest_content): 53 comments_re = re.compile('#.*$') 54 disabled_tests = [] 55 disabled_method_types = [] 56 for l in manifest_content.splitlines(): 57 stripped = comments_re.sub('', l).strip() 58 if not stripped: 59 continue 60 entry = stripped.split(' ') 61 if len(entry) == 1: 62 disabled_tests.append(entry[0]) 63 elif len(entry) == 2: 64 disabled_method_types.append((entry[0], entry[1].strip().split(','))) 65 else: 66 raise ValueError('Bad entry in manifest file.') 67 68 disabled_regex = '|'.join(disabled_tests) 69 method_types_filter = dict() 70 for method, types in disabled_method_types: 71 method_types_filter[method] = set([ 72 dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype 73 for name in types 74 ]) 75 return disabled_regex, method_types_filter 76 77 78class XLATestCase(test.TestCase): 79 """XLA test cases are parameterized test cases.""" 80 81 def __init__(self, method_name='runTest'): 82 super(XLATestCase, self).__init__(method_name) 83 self.device = FLAGS.test_device 84 self.has_custom_call = (self.device == 'XLA_CPU') 85 self._all_tf_types = set([ 86 dtypes.as_dtype(types_pb2.DataType.Value(name)) 87 for name in FLAGS.types.split(',') 88 ]) 89 self.int_tf_types = set([ 90 dtype for dtype in self._all_tf_types if dtype.is_integer 91 ]) 92 self._float_tf_types = set([ 93 dtype for dtype in self._all_tf_types if dtype.is_floating 94 ]) 95 self.complex_tf_types = set([ 96 dtype for dtype in self._all_tf_types if dtype.is_complex 97 ]) 98 self._numeric_tf_types = set( 99 self.int_tf_types | self._float_tf_types | self.complex_tf_types) 100 self.quantized_tf_types = set( 101 dtype for dtype in self._all_tf_types if dtype.is_quantized) 102 103 # Quantized types don't have a numpy equivalent, include them in 104 # all_tf_types but not in all_types. 105 # TODO(b/115960798): Parametrize tests on TF types instead of numpy types 106 # and remove all_types. 107 self._all_types = set(dtype.as_numpy_dtype 108 for dtype in self._all_tf_types 109 if not dtype.is_quantized) 110 self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) 111 self.signed_int_types = set(dtype.as_numpy_dtype 112 for dtype in self.int_tf_types 113 if not dtype.is_unsigned) 114 self.unsigned_int_types = set(dtype.as_numpy_dtype 115 for dtype in self.int_tf_types 116 if dtype.is_unsigned) 117 self._float_types = set( 118 [dtype.as_numpy_dtype for dtype in self._float_tf_types]) 119 self.complex_types = set([ 120 dtype.as_numpy_dtype for dtype in self.complex_tf_types 121 ]) 122 self._numeric_types = set(self._int_types | self._float_types 123 | self.complex_types) 124 125 # Parse the manifest file, if any, into a regex identifying tests to 126 # disable 127 # TODO(xpan): Make it text proto if it doesn't scale. 128 # Each line of the manifest file specifies an entry. The entry can be 129 # 1) TestNameRegex // E.g. CumprodTest.* Or 130 # 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16 131 # The 1) disables the entire test. While 2) only filter some numeric types 132 # so that they are not used in those tests. 133 self.disabled_regex = None 134 self._method_types_filter = {} 135 136 if FLAGS.disabled_manifest is not None: 137 with open(FLAGS.disabled_manifest, 'r') as manifest_file: 138 disabled_regex, self._method_types_filter = ( 139 parse_disabled_manifest(manifest_file.read())) 140 if disabled_regex: 141 self.disabled_regex = re.compile(disabled_regex) 142 143 if FLAGS.tf_xla_flags is not None: 144 os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags 145 146 @property 147 def all_tf_types(self): 148 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 149 tf_types = set([dtypes.as_dtype(t) 150 for t in self._method_types_filter.get(name, set())]) 151 return self._all_tf_types - tf_types 152 153 @property 154 def float_types(self): 155 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 156 return self._float_types - self._method_types_filter.get(name, set()) 157 158 @property 159 def float_tf_types(self): 160 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 161 return self._float_tf_types - self._method_types_filter.get(name, set()) 162 163 @property 164 def int_types(self): 165 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 166 return self._int_types - self._method_types_filter.get(name, set()) 167 168 @property 169 def numeric_tf_types(self): 170 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 171 tf_types = set([dtypes.as_dtype(t) 172 for t in self._method_types_filter.get(name, set())]) 173 return self._numeric_tf_types - tf_types 174 175 @property 176 def numeric_types(self): 177 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 178 return self._numeric_types - self._method_types_filter.get(name, set()) 179 180 @property 181 def all_types(self): 182 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 183 return self._all_types - self._method_types_filter.get(name, set()) 184 185 def setUp(self): 186 super(XLATestCase, self).setUp() 187 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 188 if self.disabled_regex is not None and self.disabled_regex.match(name): 189 logging.info('Disabled test case: %s', name) 190 self.skipTest('{} is disabled by manifest.'.format(name)) 191 return 192 logging.info('Start test case: %s', name) 193 194 random.seed(random_seed.DEFAULT_GRAPH_SEED) 195 np.random.seed(random_seed.DEFAULT_GRAPH_SEED) 196 197 def tearDown(self): 198 super(XLATestCase, self).tearDown() 199 logging.info('End test case: %s', self._testMethodName) 200 201 @contextlib.contextmanager 202 def test_session(self): 203 """Custom implementation of test_session() for XLA tests. 204 205 We override the standard Tensorflow test_session() since it is too 206 specific to CPU and GPU tests. In particular, we want to disable soft 207 placement and explicitly assign ops to devices under test. 208 209 Yields: 210 A session to use when running a test case. 211 """ 212 graph = ops.Graph() 213 with session.Session(graph=graph) as sess, graph.as_default(): 214 yield sess 215 216 @contextlib.contextmanager 217 def test_scope(self): 218 """Test scope that runs tests on a Tensorflow/XLA device. 219 220 Uses a compilation_scope() to mark operators to compile. 221 222 Yields: 223 A scope to apply to the operators under test. 224 """ 225 with ops.device('device:{}:0'.format(self.device)): 226 yield 227 228 229def Benchmark(tf_bench, 230 builder_fn, 231 use_xla_jit, 232 device, 233 separate_compiled_gradients=False): 234 """Build a graph and run benchmarks against it, with or without XLA. 235 236 Args: 237 tf_bench: An instance of tf.test.Benchmark, used to run the benchmark. 238 builder_fn: A function that builds a graph when invoked, and returns 239 (name, fetches), where name is the name of the test, and fetches 240 is a list of tensors to fetch as output. 241 use_xla_jit: If true compile with the XLA JIT, otherwise use regular TF. 242 device: The tensorflow device to run on, e.g. "cpu", "gpu". 243 separate_compiled_gradients: If true put each gradient subgraph into a 244 separate compilation scope. This gives fine-grained control over which 245 portions of the graph will be compiled as a single unit. Compiling 246 gradients separately may yield better performance for some graphs. 247 The scope is named based on the scope of the forward computation as well 248 as the name of the gradients. As a result, the gradients will be compiled 249 in a scope that is separate from both the forward computation, and from 250 other gradients. 251 """ 252 253 with ops.Graph().as_default(): 254 name = None 255 targets = [] 256 with ops.device(device): 257 fetches = [] 258 jit_scope = jit.experimental_jit_scope 259 with jit_scope( 260 compile_ops=use_xla_jit, 261 separate_compiled_gradients=separate_compiled_gradients): 262 name, fetches = builder_fn() 263 264 # We only want to benchmark the operations themselves, and not the data 265 # transfer of the result(s). Non-compiled identity ops ensure XLA 266 # doesn't know we're dropping the results, otherwise it might compile 267 # away the entire computation. 268 for fetch in fetches: 269 targets.append(array_ops.identity(fetch).op) 270 271 config = config_pb2.ConfigProto(allow_soft_placement=True) 272 with session.Session(config=config) as sess: 273 sess.run(variables.global_variables_initializer()) 274 xla = 'xla_' if use_xla_jit else '' 275 tf_bench.run_op_benchmark( 276 sess, targets, name='%s_%s%s' % (name, xla, device)) 277