# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import itertools
import os
import warnings
import numpy as np
import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging

TfTrtIntegrationTestParams = namedtuple(
    "TfTrtIntegrationTestParams",
    [
        "gdef",
        # A list of names of the input placeholder nodes.
        "input_names",
        # A list of lists of input shapes for the input placeholder nodes.
        "input_dims",
        # A list of names of the output identity nodes.
        "output_names",
        # A list of lists of expected output shapes of the output identity
        # nodes.
        "expected_output_dims"
    ])

RunParams = namedtuple("RunParams", [
    "use_optimizer", "precision_mode", "dynamic_engine", "test_name",
    "use_calibration"
])

ConversionParams = namedtuple("ConversionParams", [
    "max_batch_size", "max_workspace_size_bytes", "precision_mode",
    "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
    "cached_engine_batches", "rewriter_config", "use_calibration"
])

PRECISION_MODES = ["FP32", "FP16", "INT8"]
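
# Example (hypothetical values, for illustration only): a test whose graph has
# a single float32 placeholder "input" of shape [2, 32, 32, 3] and a single
# identity output "output" of the same shape would build its params as:
#
#   TfTrtIntegrationTestParams(
#       gdef=my_graph_def,
#       input_names=["input"],
#       input_dims=[[[2, 32, 32, 3]]],  # one set of shapes, one input
#       output_names=["output"],
#       expected_output_dims=[[[2, 32, 32, 3]]])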


def IsQuantizationMode(mode):
  return mode == "INT8"


def IsQuantizationWithCalibration(params):
  return IsQuantizationMode(params.precision_mode) and params.use_calibration


class GraphState(object):
  ORIGINAL = 0
  CALIBRATE = 1
  INFERENCE = 2


def OptimizerDisabledRewriterConfig():
  """Returns a RewriterConfig with all default Grappler optimizers disabled."""
  rewriter_config = rewriter_config_pb2.RewriterConfig()

  # Turn off all default Grappler optimizers.
  off = rewriter_config_pb2.RewriterConfig.OFF
  rewriter_config.layout_optimizer = off
  rewriter_config.constant_folding = off
  rewriter_config.shape_optimization = off
  rewriter_config.remapping = off
  rewriter_config.arithmetic_optimization = off
  rewriter_config.dependency_optimization = off
  rewriter_config.loop_optimization = off
  rewriter_config.function_optimization = off
  rewriter_config.debug_stripper = off
  rewriter_config.disable_model_pruning = True
  rewriter_config.scoped_allocator_optimization = off
  rewriter_config.memory_optimization = (
      rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
  rewriter_config.pin_to_host_optimization = off
  rewriter_config.auto_parallel.enable = False

  # Run only once for each enabled optimizer.
  rewriter_config.meta_optimizer_iterations = (
      rewriter_config_pb2.RewriterConfig.ONE)
  return rewriter_config
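
# Illustrative sketch (not part of this module's API): a session created with
# this config runs the GraphDef exactly as written, with no Grappler rewrites:
#
#   config = config_pb2.ConfigProto(
#       graph_options=config_pb2.GraphOptions(
#           rewrite_options=OptimizerDisabledRewriterConfig()))
#   with session.Session(config=config) as sess:
#     sess.run(fetches, feed_dict)
#
# (`session` refers to tensorflow.python.client.session, which this module
# does not import.)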


class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
  """Class to test TensorFlow-TensorRT integration."""

  @property
  def trt_incompatible_op(self):
    return math_ops.erf

  @property
  def precision_modes(self):
    return ["FP32", "FP16", "INT8"]

  # str is bytes in py2, but unicode in py3.
  def _ToUnicode(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s
      return s.decode("utf-8")
    else:
      if isinstance(s, str):
        return s
      return s.decode("utf-8")

  def _ToBytes(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s.encode("utf-8")
      return s
    else:
      if isinstance(s, str):
        return s.encode("utf-8")
      return s

  def _ToString(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s.encode("utf-8")
      return s
    else:
      if isinstance(s, str):
        return s
      return s.decode("utf-8")

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(TfTrtIntegrationTestBase, self).__init__(methodName)
    self._trt_test_params = None

  def setUp(self):
    """Setup method."""
    super(TfTrtIntegrationTestBase, self).setUp()
    warnings.simplefilter("always")

  def GetParams(self):
    """Return a TfTrtIntegrationTestParams for the test; implemented by subclass."""
    raise NotImplementedError()
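
  # A minimal subclass sketch (hypothetical names, for illustration only):
  #
  #   class SimpleSingleEngineTest(TfTrtIntegrationTestBase):
  #
  #     def GetParams(self):
  #       # Build a graph with a TRT-compatible segment, export it as a
  #       # GraphDef, and describe its inputs and outputs.
  #       return TfTrtIntegrationTestParams(
  #           gdef=my_graph_def,
  #           input_names=["input"],
  #           input_dims=[[[2, 32, 32, 3]]],
  #           output_names=["output"],
  #           expected_output_dims=[[[2, 32, 32, 3]]])
  #
  #     def ExpectedEnginesToBuild(self, run_params):
  #       # Map each expected engine to the original nodes it replaces.
  #       return {"TRTEngineOp_0": ["conv", "bias_add", "relu"]}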

  def GetConversionParams(self, run_params):
    """Return a ConversionParams for test."""
    batch_list = []
    for dims_list in self._GetParamsCached().input_dims:
      assert dims_list
      # Each list of shapes should have the same batch size.
      input_batches = [dims[0] for dims in dims_list]
      assert max(input_batches) == min(input_batches)
      batch_list.append(input_batches[0])
    return ConversionParams(
        # We use the minimum of all the batch sizes, so when multiple different
        # input shapes are provided it'll always create new engines in the
        # cache, and we can therefore test the cache behavior.
        max_batch_size=min(batch_list),
        max_workspace_size_bytes=1 << 25,
        precision_mode=run_params.precision_mode,
        minimum_segment_size=2,
        is_dynamic_op=run_params.dynamic_engine,
        maximum_cached_engines=1,
        cached_engine_batches=None,
        rewriter_config=None,
        use_calibration=run_params.use_calibration)

  def ShouldRunTest(self, run_params):
    """Whether to run the test."""
    # This setting combination requires quantization nodes to be present in
    # order to build the engine.
    return not (IsQuantizationMode(run_params.precision_mode) and
                not run_params.use_calibration)

  def ExpectedEnginesToBuild(self, run_params):
    """Return the expected engines to build, implemented by subclass."""
    raise NotImplementedError()

  def ExpectedAbsoluteTolerance(self, run_params):
    """The absolute tolerance to compare floating point results."""
    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02

  def ExpectedRelativeTolerance(self, run_params):
    """The relative tolerance to compare floating point results."""
    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02

  def _GetParamsCached(self):
    if self._trt_test_params is None:
      self._trt_test_params = self.GetParams()
    return self._trt_test_params

  def _GetGPUOptions(self):
    gpu_options = config_pb2.GPUOptions()
    gpu_options.allow_growth = True
    return gpu_options

  def _GetConfigProto(self, run_params, graph_state):
    """Get config proto based on specific settings."""
    conversion_params = self.GetConversionParams(run_params)
    if graph_state == GraphState.INFERENCE and run_params.use_optimizer:
      rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
          conversion_params.rewriter_config,
          conversion_params.max_batch_size,
          conversion_params.max_workspace_size_bytes,
          conversion_params.precision_mode,
          conversion_params.minimum_segment_size,
          conversion_params.is_dynamic_op,
          conversion_params.maximum_cached_engines,
          conversion_params.cached_engine_batches,
          conversion_params.use_calibration,
          use_function_backup=IsQuantizationWithCalibration(conversion_params))

      graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
    else:
      graph_options = config_pb2.GraphOptions()
      if conversion_params.rewriter_config is not None:
        graph_options.rewrite_options.CopyFrom(
            conversion_params.rewriter_config)

    config = config_pb2.ConfigProto(
        gpu_options=self._GetGPUOptions(), graph_options=graph_options)
    return config

  def _GetFeedNames(self):
    params = self._GetParamsCached()
    # Construct the feed tensor names by appending :0 to the node names.
    return [input_name + ":0" for input_name in params.input_names]

  def _GetFetchNames(self):
    params = self._GetParamsCached()
    # Construct the fetch tensor names by appending :0 to the node names.
    return [output_name + ":0" for output_name in params.output_names]

  def _GetFeedDict(self, inputs_data, input_shape_index):
    assert input_shape_index < len(inputs_data)
    feeds = self._GetFeedNames()
    return {
        feeds[i]: inputs_data[input_shape_index][i] for i in range(len(feeds))
    }
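
  # For example (hypothetical names), with input_names=["a", "b"] the feed
  # names are ["a:0", "b:0"], and _GetFeedDict(inputs_data, 1) returns
  #   {"a:0": inputs_data[1][0], "b:0": inputs_data[1][1]}
  # i.e. the data for the second set of input shapes, keyed by tensor name.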

  def _RunGraph(self,
                run_params,
                gdef,
                inputs_data,
                config,
                graph_state,
                num_runs=2):
    """Run the given GraphDef multiple times."""
    params = self._GetParamsCached()
    for data in inputs_data:
      assert len(params.input_names) == len(data)

    fetches = self._GetFetchNames()
    g = ops.Graph()
    with g.as_default():
      importer.import_graph_def(graph_def=gdef, name="")
    with self.session(
        graph=g, config=config, use_gpu=True, force_gpu=True) as sess:
      vals = []
      # Run once for each set of input shapes.
      for shape_index in range(len(inputs_data)):
        val = None
        for _ in range(num_runs):
          new_val = sess.run(fetches,
                             self._GetFeedDict(inputs_data, shape_index))
          output_len = len(params.expected_output_dims[shape_index])
          self.assertEqual(output_len, len(new_val))
          for i in range(output_len):
            self.assertEqual(
                list(params.expected_output_dims[shape_index][i]),
                list(new_val[i].shape))
          if val is not None:
            self.assertAllClose(val, new_val, atol=1.e-06, rtol=1.e-06)
          val = new_val
        vals.append(val)
      return vals

  def _CreateConverter(self, gdef, session_config, conversion_params):
    """Return a TrtGraphConverter."""
    params = self._GetParamsCached()
    converter = trt_convert.TrtGraphConverter(
        input_graph_def=gdef,
        nodes_blacklist=params.input_names + params.output_names,
        session_config=session_config,
        max_batch_size=conversion_params.max_batch_size,
        max_workspace_size_bytes=conversion_params.max_workspace_size_bytes,
        precision_mode=conversion_params.precision_mode,
        minimum_segment_size=conversion_params.minimum_segment_size,
        is_dynamic_op=conversion_params.is_dynamic_op,
        maximum_cached_engines=conversion_params.maximum_cached_engines,
        cached_engine_batches=conversion_params.cached_engine_batches,
        use_calibration=conversion_params.use_calibration,
        use_function_backup=IsQuantizationWithCalibration(conversion_params))
    return converter

  def _GetCalibratedInferGraph(self, run_params, gdef, inputs_data):
    """Return the TRT-converted GraphDef in INT8 mode."""
    conversion_params = self.GetConversionParams(run_params)
    logging.info(conversion_params)
    assert conversion_params.precision_mode == "INT8"
    assert conversion_params.is_dynamic_op
    assert conversion_params.maximum_cached_engines == 1
    assert not conversion_params.cached_engine_batches
    assert conversion_params.use_calibration
    assert len(inputs_data) == 1  # We only support calibrating one engine.

    session_config = self._GetConfigProto(run_params, GraphState.CALIBRATE)
    logging.info("Running calibration graph, config:\n%s", str(session_config))

    converter = self._CreateConverter(gdef, session_config, conversion_params)
    int8_gdef = converter.convert()
    self._VerifyGraphDef(run_params, int8_gdef, GraphState.CALIBRATE)

    return converter.calibrate(
        fetch_names=self._GetFetchNames(),
        num_runs=5,
        feed_dict_fn=lambda: self._GetFeedDict(inputs_data, 0))
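
  # The INT8 flow above is a two-step process: convert() segments the graph
  # and inserts TRTEngineOps (with native-segment function backups), then
  # calibrate() executes the converted graph num_runs times on real input
  # data to collect quantization ranges, returning a GraphDef whose engine
  # nodes carry a populated calibration_data attribute. Sketch (illustrative
  # only, with hypothetical tensor names):
  #
  #   converter = self._CreateConverter(gdef, session_config,
  #                                     conversion_params)
  #   converter.convert()
  #   calibrated_gdef = converter.calibrate(
  #       fetch_names=["output:0"],
  #       num_runs=5,
  #       feed_dict_fn=lambda: {"input:0": calibration_batch})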

  def _GetInferGraph(self, run_params, gdef):
    """Return the TRT-converted GraphDef."""
    conversion_params = self.GetConversionParams(run_params)
    logging.info(conversion_params)

    session_config = self._GetConfigProto(run_params, GraphState.INFERENCE)
    logging.info("Creating TRT graph for inference, config:\n%s",
                 str(session_config))
    converter = self._CreateConverter(gdef, session_config, conversion_params)
    return converter.convert()

  def _WriteGraph(self, run_params, gdef, graph_state):
    if graph_state == GraphState.ORIGINAL:
      label = "Original"
    elif graph_state == GraphState.CALIBRATE:
      label = "CalibEngine"
    elif graph_state == GraphState.INFERENCE:
      label = "InferEngine"
    graph_name = (
        self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
        ".pbtxt")
    temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir())
    if temp_dir:
      logging.info("Writing graph to %s/%s", temp_dir, graph_name)
      graph_io.write_graph(gdef, temp_dir, graph_name)

  def _VerifyConnections(self, expected_engines, converted_gdef):
    params = self._GetParamsCached()
    old_to_new_node_map = {
        self._ToString(node.name): self._ToString(node.name)
        for node in params.gdef.node
    }
    for engine_name, node_names in expected_engines.items():
      for node_name in node_names:
        old_to_new_node_map[node_name] = engine_name
    name_to_node_map = {
        self._ToString(node.name): node for node in params.gdef.node
    }

    def _InputName(inp):
      inp = self._ToString(inp)
      prefix = ""
      if inp[0] == "^":
        prefix = "^"
        inp = inp[1:]
      parts = inp.split(":")
      if len(parts) > 1 and parts[-1].isdigit():
        inp = inp[:-len(parts[-1]) - 1]
      return (prefix, inp)
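
    # For example: _InputName("^dep") == ("^", "dep"),
    # _InputName("node:1") == ("", "node"), _InputName("node") == ("", "node");
    # the control-dependency marker is split off and any trailing output-slot
    # index is dropped.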

    expected_input_map = {}
    for node in params.gdef.node:
      name_str = self._ToString(node.name)
      target_node_name = old_to_new_node_map[name_str]
      is_engine_op = (target_node_name != name_str)
      if target_node_name not in expected_input_map:
        expected_input_map[target_node_name] = set()
      input_set = expected_input_map[target_node_name]
      for inp in node.input:
        (prefix, inp_name) = _InputName(inp)
        # Add the input only if it's outside the segment (note that it could be
        # in a different engine).
        if (not is_engine_op or
            old_to_new_node_map[inp_name] != target_node_name):
          if is_engine_op and name_to_node_map[inp_name].op == "Const":
            # Const data input nodes to the segment have been copied to the
            # segment graphdef and the engine, and the dependency has been
            # converted to a control dependency.
            input_set.add("^" + old_to_new_node_map[inp_name])
          else:
            input_set.add(prefix + old_to_new_node_map[inp_name])

    actual_input_map = {}
    for node in converted_gdef.node:
      name_str = self._ToString(node.name)
      actual_input_map[name_str] = set()
      input_set = actual_input_map[name_str]
      for inp in node.input:
        (prefix, node_name) = _InputName(inp)
        input_set.add(prefix + node_name)

    self.assertEqual(
        expected_input_map,
        actual_input_map,
        msg="expected:\n%s\nvs actual:\n%s" %
        (sorted(expected_input_map.items()), sorted(actual_input_map.items())))

  def _VerifyGraphDef(self, run_params, gdef, graph_state):
    self._WriteGraph(run_params, gdef, graph_state)

    expected_engines = self.ExpectedEnginesToBuild(run_params)
    num_engines = 0
    functions = [f.signature.name for f in gdef.library.function]
    for node in gdef.node:
      if node.op == "TRTEngineOp":
        logging.info("Found TRTEngineOp: " + node.name)
        num_engines += 1
        segment_funcdef_name = node.attr["segment_funcdef_name"].s
        function_name = node.name + "_native_segment"
        if IsQuantizationWithCalibration(run_params):
          self.assertNotEmpty(segment_funcdef_name, node.name)
          self.assertIn(function_name, functions)
        else:
          self.assertEmpty(segment_funcdef_name, node.name)
          self.assertNotIn(function_name, functions)
        self.assertIn(node.name, expected_engines)
        self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
        self.assertEqual(
            self._ToBytes(run_params.precision_mode),
            node.attr["precision_mode"].s, node.name)

        is_dynamic_engine = not node.attr["static_engine"].b
        self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
                         node.name)
        self.assertEqual(node.attr["use_calibration"].b,
                         run_params.use_calibration, node.name)

        has_calibration_data = len(node.attr["calibration_data"].s)
        if (IsQuantizationMode(run_params.precision_mode) and
            run_params.use_calibration and
            graph_state == GraphState.INFERENCE):
          self.assertTrue(has_calibration_data, node.name)
        else:
          self.assertFalse(has_calibration_data, node.name)
    if graph_state == GraphState.ORIGINAL:
      self.assertEqual(0, num_engines)
    else:
      self.assertEqual(num_engines, len(expected_engines))
      if isinstance(expected_engines, dict):
        self._VerifyConnections(expected_engines, gdef)
    # TODO(aaroey): consider verifying the corresponding TF function.

  def RunTest(self, run_params):
    if not self.ShouldRunTest(run_params):
      return
    assert run_params.precision_mode in PRECISION_MODES
    np.random.seed(12345)

    params = self._GetParamsCached()
    input_gdef = params.gdef
    input_dtypes = {}
    for node in input_gdef.node:
      if self._ToString(node.name) in params.input_names:
        assert self._ToString(node.op) == "Placeholder"
        input_dtypes[self._ToString(node.name)] = (
            dtypes.as_dtype(node.attr["dtype"].type).as_numpy_dtype())
    assert len(params.input_names) == len(input_dtypes)

    inputs_data = []
    for inp in params.input_dims:
      current_input_data = []
      for i in range(len(params.input_names)):
        dtype = input_dtypes[params.input_names[i]]
        # Multiply the input by some constant to avoid all-zeros input for
        # integer types.
        scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0
        dims = inp[i]
        # TODO(laigd): add debug options. E.g. we can set the input data to be
        # continuous natural numbers:
        # seq = np.arange(np.prod(dims))
        # seq.resize(dims)
        # current_input_data.append(scale * seq.astype(dtype))
        current_input_data.append(
            (scale * np.random.random_sample(dims)).astype(dtype))
      inputs_data.append(current_input_data)

    # Verify the original graph.
    self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL)

    # Run the original graph without TensorRT to get the reference result.
    config_no_trt = self._GetConfigProto(run_params, GraphState.ORIGINAL)
    logging.info("Running original graph w/o trt, config:\n%s",
                 str(config_no_trt))
    ref_result = self._RunGraph(run_params, input_gdef, inputs_data,
                                config_no_trt, GraphState.ORIGINAL)

    # Run calibration if necessary.
    if (IsQuantizationMode(run_params.precision_mode) and
        run_params.use_calibration):
      infer_gdef = self._GetCalibratedInferGraph(run_params, input_gdef,
                                                 inputs_data)
      self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)
    elif not run_params.use_optimizer:
      infer_gdef = self._GetInferGraph(run_params, input_gdef)
      self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)
    else:
      infer_gdef = input_gdef

    # Run inference.
    infer_config = self._GetConfigProto(run_params, GraphState.INFERENCE)
    logging.info("Running final inference graph, config:\n%s",
                 str(infer_config))
    result = self._RunGraph(run_params, infer_gdef, inputs_data, infer_config,
                            GraphState.INFERENCE)
    self.assertAllClose(
        ref_result,
        result,
        atol=self.ExpectedAbsoluteTolerance(run_params),
        rtol=self.ExpectedRelativeTolerance(run_params))

  def testIdempotence(self):
    # Test that applying the TensorRT optimizer or the offline conversion
    # tool multiple times to the same graph results in the same graph.
    #
    # TODO(aaroey): currently the conversion is not deterministic. This is
    # mainly because during tensorflow::ConvertGraphDefToGraph() the graph
    # uses EdgeSet, which is backed by a map keyed by Edge*, so the order of
    # the input/output edges of a node is nondeterministic, and thus the
    # order in which the segmenter contracts edges is nondeterministic. Need
    # to evaluate whether we should fix this.
    pass


def _AddTests(test_class):
  """Adds test methods to TfTrtIntegrationTestBase."""

  def _GetTest(run_params):
    """Gets a single test method based on the parameters."""

    def _Test(self):
      logging.info(
          "Running test %s with parameters: use_optimizer=%s, "
          "precision_mode=%s, dynamic_engine=%s",
          "testTfTrt_" + run_params.test_name, run_params.use_optimizer,
          run_params.precision_mode, run_params.dynamic_engine)
      self.RunTest(run_params)

    return _Test

  use_optimizer_options = [False, True]
  dynamic_engine_options = [False, True]
  use_calibration_options = [False, True]
  opts = itertools.product(use_optimizer_options, PRECISION_MODES,
                           dynamic_engine_options, use_calibration_options)
  for (use_optimizer, precision_mode, dynamic_engine, use_calibration) in opts:
    if IsQuantizationMode(precision_mode):
      if use_optimizer:
        # We ignore the use_optimizer option and always use TrtGraphConverter
        # for INT8 mode, so there is no need to run it twice.
        continue
      if use_calibration and not dynamic_engine:
        # Static engines are only supported (and tested) with
        # use_calibration=False; with use_calibration=True only dynamic ops
        # are supported.
        # TODO(aaroey): construction of static calibration engines is not
        # supported yet.
        continue
    else:
      if use_calibration:
        # Don't calibrate in FP32 or FP16 mode.
        continue

    conversion = "OptimizerConversion" if use_optimizer else "ToolConversion"
    engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine"
    calibration_type = "UseCalibration" if use_calibration else "NoCalibration"
    test_name = "%s_%s_%s_%s" % (conversion, engine_type, precision_mode,
                                 calibration_type)
    run_params = RunParams(
        use_optimizer=use_optimizer,
        precision_mode=precision_mode,
        dynamic_engine=dynamic_engine,
        test_name=test_name,
        use_calibration=use_calibration)
    setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params))
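
# The loop above registers one test method per supported parameter
# combination. For example, (use_optimizer=False, precision_mode="FP16",
# dynamic_engine=True, use_calibration=False) becomes a method named
# testTfTrt_ToolConversion_DynamicEngine_FP16_NoCalibration.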


if is_tensorrt_enabled():
  _AddTests(TfTrtIntegrationTestBase)
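
# A concrete test module (hypothetical sketch) subclasses
# TfTrtIntegrationTestBase as sketched after GetParams() above, and ends with
# the usual test runner stanza:
#
#   from tensorflow.python.platform import test
#
#   if __name__ == "__main__":
#     test.main()
#
# Because _AddTests() attaches the testTfTrt_* methods to
# TfTrtIntegrationTestBase itself, every subclass inherits them automatically.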