# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import itertools
import os
import warnings
import numpy as np
import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging

TfTrtIntegrationTestParams = namedtuple(
    "TfTrtIntegrationTestParams",
    [
        "gdef",
        # A list of names of the input placeholder nodes.
        "input_names",
        # A list of lists of input shapes, one inner list (with one shape per
        # input placeholder) for each set of input data to feed.
        "input_dims",
        # A list of names of the output identity nodes.
        "output_names",
        # A list of lists of expected output shapes, one inner list (with one
        # shape per output identity node) for each set of input data.
        "expected_output_dims"
    ])
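
# An illustrative (hypothetical) parameter set for a test with one input
# placeholder fed with two different shape sets:
#   input_names=["input"]
#   input_dims=[[[1, 10]], [[2, 10]]]   # outer: shape sets; inner: inputs
#   output_names=["output"]
#   expected_output_dims=[[[1, 10]], [[2, 10]]]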

RunParams = namedtuple("RunParams", [
    "use_optimizer", "precision_mode", "dynamic_engine", "test_name",
    "use_calibration"
])

ConversionParams = namedtuple("ConversionParams", [
    "max_batch_size", "max_workspace_size_bytes", "precision_mode",
    "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
    "cached_engine_batches", "rewriter_config", "use_calibration"
])

PRECISION_MODES = ["FP32", "FP16", "INT8"]


def IsQuantizationMode(mode):
  return mode == "INT8"


def IsQuantizationWithCalibration(params):
  return IsQuantizationMode(params.precision_mode) and params.use_calibration


class GraphState(object):
  ORIGINAL = 0
  CALIBRATE = 1
  INFERENCE = 2


def OptimizerDisabledRewriterConfig():
  """Returns a RewriterConfig with all default Grappler optimizers disabled."""
  rewriter_config = rewriter_config_pb2.RewriterConfig()

  # Turn off all default Grappler optimizers.
  off = rewriter_config_pb2.RewriterConfig.OFF
  rewriter_config.layout_optimizer = off
  rewriter_config.constant_folding = off
  rewriter_config.shape_optimization = off
  rewriter_config.remapping = off
  rewriter_config.arithmetic_optimization = off
  rewriter_config.dependency_optimization = off
  rewriter_config.loop_optimization = off
  rewriter_config.function_optimization = off
  rewriter_config.debug_stripper = off
  rewriter_config.disable_model_pruning = True
  rewriter_config.scoped_allocator_optimization = off
  rewriter_config.memory_optimization = (
      rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
  rewriter_config.pin_to_host_optimization = off
  rewriter_config.auto_parallel.enable = False

  # Run the meta-optimizer only once for each enabled optimizer.
  rewriter_config.meta_optimizer_iterations = (
      rewriter_config_pb2.RewriterConfig.ONE)
  return rewriter_config
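
# A minimal usage sketch (not used by the tests below): the returned
# RewriterConfig can be attached to a session ConfigProto so that only
# explicitly requested optimizers run:
#
#   graph_options = config_pb2.GraphOptions(
#       rewrite_options=OptimizerDisabledRewriterConfig())
#   config = config_pb2.ConfigProto(graph_options=graph_options)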


class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
  """Class to test TensorFlow-TensorRT integration."""

  @property
  def trt_incompatible_op(self):
    return math_ops.erf

  @property
  def precision_modes(self):
    return ["FP32", "FP16", "INT8"]

  # str is bytes in py2, but unicode in py3.
  def _ToUnicode(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s
      return s.decode("utf-8")
    else:
      if isinstance(s, str):
        return s
      return s.decode("utf-8")

  def _ToBytes(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s.encode("utf-8")
      return s
    else:
      if isinstance(s, str):
        return s.encode("utf-8")
      return s

  def _ToString(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s.encode("utf-8")
      return s
    else:
      if isinstance(s, str):
        return s
      return s.decode("utf-8")
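
  # Illustrative behavior of the helpers above (same on PY2 and PY3):
  #   _ToUnicode(b"trt") -> u"trt"
  #   _ToBytes(u"trt")   -> b"trt"
  #   _ToString returns the native str type of the running interpreter.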

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(TfTrtIntegrationTestBase, self).__init__(methodName)
    self._trt_test_params = None

  def setUp(self):
    """Setup method."""
    super(TfTrtIntegrationTestBase, self).setUp()
    warnings.simplefilter("always")

  def GetParams(self):
    """Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
    raise NotImplementedError()

  def GetConversionParams(self, run_params):
    """Return a ConversionParams for test."""
    batch_list = []
    for dims_list in self._GetParamsCached().input_dims:
      assert dims_list
      # Each list of shapes should have the same batch size.
      input_batches = [dims[0] for dims in dims_list]
      assert max(input_batches) == min(input_batches)
      batch_list.append(input_batches[0])
    return ConversionParams(
        # We use the minimum of all the batch sizes, so that when multiple
        # different input shapes are provided it will always create new engines
        # in the cache, and we can therefore test the cache behavior.
        max_batch_size=min(batch_list),
        max_workspace_size_bytes=1 << 25,
        precision_mode=run_params.precision_mode,
        minimum_segment_size=2,
        is_dynamic_op=run_params.dynamic_engine,
        maximum_cached_engines=1,
        cached_engine_batches=None,
        rewriter_config=None,
        use_calibration=run_params.use_calibration)
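
  # For example (illustrative): if the shape sets in input_dims have batch
  # sizes [2, 4], max_batch_size above becomes 2, so running the batch-4
  # shape set later forces a new engine to be created in the cache.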

  def ShouldRunTest(self, run_params):
    """Whether to run the test."""
    # This setting combination requires quantization nodes to be present in
    # order to build the engine.
    return not (IsQuantizationMode(run_params.precision_mode) and
                not run_params.use_calibration)

  def ExpectedEnginesToBuild(self, run_params):
    """Return the expected engines to build, implemented by subclass."""
    raise NotImplementedError()

  def ExpectedAbsoluteTolerance(self, run_params):
    """The absolute tolerance to compare floating point results."""
    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02

  def ExpectedRelativeTolerance(self, run_params):
    """The relative tolerance to compare floating point results."""
    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02

  def _GetParamsCached(self):
    if self._trt_test_params is None:
      self._trt_test_params = self.GetParams()
    return self._trt_test_params

  def _GetGPUOptions(self):
    gpu_options = config_pb2.GPUOptions()
    gpu_options.allow_growth = True
    return gpu_options

  def _GetConfigProto(self, run_params, graph_state):
    """Get config proto based on specific settings."""
    conversion_params = self.GetConversionParams(run_params)
    if graph_state == GraphState.INFERENCE and run_params.use_optimizer:
      rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
          conversion_params.rewriter_config,
          conversion_params.max_batch_size,
          conversion_params.max_workspace_size_bytes,
          conversion_params.precision_mode,
          conversion_params.minimum_segment_size,
          conversion_params.is_dynamic_op,
          conversion_params.maximum_cached_engines,
          conversion_params.cached_engine_batches,
          conversion_params.use_calibration,
          use_function_backup=IsQuantizationWithCalibration(conversion_params))

      graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
    else:
      graph_options = config_pb2.GraphOptions()
      if conversion_params.rewriter_config is not None:
        graph_options.rewrite_options.CopyFrom(
            conversion_params.rewriter_config)

    config = config_pb2.ConfigProto(
        gpu_options=self._GetGPUOptions(), graph_options=graph_options)
    return config

  def _GetFeedNames(self):
    params = self._GetParamsCached()
    # Construct the feed tensor names by appending :0 to the node names.
    return [input_name + ":0" for input_name in params.input_names]

  def _GetFetchNames(self):
    params = self._GetParamsCached()
    # Construct the fetch tensor names by appending :0 to the node names.
    return [output_name + ":0" for output_name in params.output_names]

  def _GetFeedDict(self, inputs_data, input_shape_index):
    assert input_shape_index < len(inputs_data)
    feeds = self._GetFeedNames()
    return {
        feeds[i]: inputs_data[input_shape_index][i] for i in range(len(feeds))
    }
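
  # For instance (illustrative): with input_names=["a", "b"], the feed dict
  # for shape set 0 is {"a:0": inputs_data[0][0], "b:0": inputs_data[0][1]}.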

  def _RunGraph(self,
                run_params,
                gdef,
                inputs_data,
                config,
                graph_state,
                num_runs=2):
    """Run the given GraphDef multiple times."""
    params = self._GetParamsCached()
    for data in inputs_data:
      assert len(params.input_names) == len(data)

    fetches = self._GetFetchNames()
    g = ops.Graph()
    with g.as_default():
      importer.import_graph_def(graph_def=gdef, name="")
      with self.session(
          graph=g, config=config, use_gpu=True, force_gpu=True) as sess:
        vals = []
        # Run once for each set of input shapes.
        for shape_index in range(len(inputs_data)):
          val = None
          for _ in range(num_runs):
            new_val = sess.run(fetches,
                               self._GetFeedDict(inputs_data, shape_index))
            output_len = len(params.expected_output_dims[shape_index])
            self.assertEqual(output_len, len(new_val))
            for i in range(output_len):
              self.assertEqual(
                  list(params.expected_output_dims[shape_index][i]),
                  list(new_val[i].shape))
            if val is not None:
              self.assertAllClose(val, new_val, atol=1.e-06, rtol=1.e-06)
            val = new_val
          vals.append(val)
        return vals

  def _CreateConverter(self, gdef, session_config, conversion_params):
    """Return a TrtGraphConverter."""
    params = self._GetParamsCached()
    converter = trt_convert.TrtGraphConverter(
        input_graph_def=gdef,
        nodes_blacklist=params.input_names + params.output_names,
        session_config=session_config,
        max_batch_size=conversion_params.max_batch_size,
        max_workspace_size_bytes=conversion_params.max_workspace_size_bytes,
        precision_mode=conversion_params.precision_mode,
        minimum_segment_size=conversion_params.minimum_segment_size,
        is_dynamic_op=conversion_params.is_dynamic_op,
        maximum_cached_engines=conversion_params.maximum_cached_engines,
        cached_engine_batches=conversion_params.cached_engine_batches,
        use_calibration=conversion_params.use_calibration,
        use_function_backup=IsQuantizationWithCalibration(conversion_params))
    return converter

  def _GetCalibratedInferGraph(self, run_params, gdef, inputs_data):
    """Return the TRT-converted GraphDef in INT8 mode."""
    conversion_params = self.GetConversionParams(run_params)
    logging.info(conversion_params)
    assert conversion_params.precision_mode == "INT8"
    assert conversion_params.is_dynamic_op
    assert conversion_params.maximum_cached_engines == 1
    assert not conversion_params.cached_engine_batches
    assert conversion_params.use_calibration
    assert len(inputs_data) == 1  # We only support calibrating a single engine.

    session_config = self._GetConfigProto(run_params, GraphState.CALIBRATE)
    logging.info("Running calibration graph, config:\n%s", str(session_config))

    converter = self._CreateConverter(gdef, session_config, conversion_params)
    int8_gdef = converter.convert()
    self._VerifyGraphDef(run_params, int8_gdef, GraphState.CALIBRATE)

    return converter.calibrate(
        fetch_names=self._GetFetchNames(),
        num_runs=5,
        feed_dict_fn=lambda: self._GetFeedDict(inputs_data, 0))

  def _GetInferGraph(self, run_params, gdef):
    """Return the TRT-converted GraphDef."""
    conversion_params = self.GetConversionParams(run_params)
    logging.info(conversion_params)

    session_config = self._GetConfigProto(run_params, GraphState.INFERENCE)
    logging.info("Creating TRT graph for inference, config:\n%s",
                 str(session_config))
    converter = self._CreateConverter(gdef, session_config, conversion_params)
    return converter.convert()

  def _WriteGraph(self, run_params, gdef, graph_state):
    if graph_state == GraphState.ORIGINAL:
      label = "Original"
    elif graph_state == GraphState.CALIBRATE:
      label = "CalibEngine"
    elif graph_state == GraphState.INFERENCE:
      label = "InferEngine"
    graph_name = (
        self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
        ".pbtxt")
    temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir())
    if temp_dir:
      logging.info("Writing graph to %s/%s", temp_dir, graph_name)
      graph_io.write_graph(gdef, temp_dir, graph_name)

  def _VerifyConnections(self, expected_engines, converted_gdef):
    params = self._GetParamsCached()
    old_to_new_node_map = {
        self._ToString(node.name): self._ToString(node.name)
        for node in params.gdef.node
    }
    for engine_name, node_names in expected_engines.items():
      for node_name in node_names:
        old_to_new_node_map[node_name] = engine_name
    name_to_node_map = {
        self._ToString(node.name): node for node in params.gdef.node
    }

    def _InputName(inp):
      inp = self._ToString(inp)
      prefix = ""
      if inp[0] == "^":
        prefix = "^"
        inp = inp[1:]
      parts = inp.split(":")
      if len(parts) > 1 and parts[-1].isdigit():
        inp = inp[:-len(parts[-1]) - 1]
      return (prefix, inp)
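
    # Illustrative examples:
    #   _InputName("add:1")  -> ("", "add")
    #   _InputName("^const") -> ("^", "const")
    #   _InputName("relu")   -> ("", "relu")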

    expected_input_map = {}
    for node in params.gdef.node:
      name_str = self._ToString(node.name)
      target_node_name = old_to_new_node_map[name_str]
      is_engine_op = (target_node_name != name_str)
      if target_node_name not in expected_input_map:
        expected_input_map[target_node_name] = set()
      input_set = expected_input_map[target_node_name]
      for inp in node.input:
        (prefix, inp_name) = _InputName(inp)
        # Add the input only if it's outside the segment (note that it could be
        # in a different engine).
        if (not is_engine_op or
            old_to_new_node_map[inp_name] != target_node_name):
          if is_engine_op and name_to_node_map[inp_name].op == "Const":
            # Const data input nodes to the segment have been copied into the
            # segment GraphDef and the engine, and the data dependency has been
            # converted to a control dependency.
            input_set.add("^" + old_to_new_node_map[inp_name])
          else:
            input_set.add(prefix + old_to_new_node_map[inp_name])

    actual_input_map = {}
    for node in converted_gdef.node:
      name_str = self._ToString(node.name)
      actual_input_map[name_str] = set()
      input_set = actual_input_map[name_str]
      for inp in node.input:
        (prefix, node_name) = _InputName(inp)
        input_set.add(prefix + node_name)

    self.assertEqual(
        expected_input_map,
        actual_input_map,
        msg="expected:\n%s\nvs actual:\n%s" %
        (sorted(expected_input_map.items()), sorted(actual_input_map.items())))

  def _VerifyGraphDef(self, run_params, gdef, graph_state):
    self._WriteGraph(run_params, gdef, graph_state)

    expected_engines = self.ExpectedEnginesToBuild(run_params)
    num_engines = 0
    functions = [f.signature.name for f in gdef.library.function]
    for node in gdef.node:
      if node.op == "TRTEngineOp":
        logging.info("Found TRTEngineOp: %s", node.name)
        num_engines += 1
        segment_funcdef_name = node.attr["segment_funcdef_name"].s
        function_name = node.name + "_native_segment"
        if IsQuantizationWithCalibration(run_params):
          self.assertNotEmpty(segment_funcdef_name, node.name)
          self.assertIn(function_name, functions)
        else:
          self.assertEmpty(segment_funcdef_name, node.name)
          self.assertNotIn(function_name, functions)
        self.assertIn(node.name, expected_engines)
        self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
        self.assertEqual(
            self._ToBytes(run_params.precision_mode),
            node.attr["precision_mode"].s, node.name)

        is_dynamic_engine = not node.attr["static_engine"].b
        self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
                         node.name)
        self.assertEqual(node.attr["use_calibration"].b,
                         run_params.use_calibration, node.name)

        has_calibration_data = len(node.attr["calibration_data"].s)
        if (IsQuantizationMode(run_params.precision_mode) and
            run_params.use_calibration and graph_state == GraphState.INFERENCE):
          self.assertTrue(has_calibration_data, node.name)
        else:
          self.assertFalse(has_calibration_data, node.name)
    if graph_state == GraphState.ORIGINAL:
      self.assertEqual(0, num_engines)
    else:
      self.assertEqual(num_engines, len(expected_engines))
      if isinstance(expected_engines, dict):
        self._VerifyConnections(expected_engines, gdef)
      # TODO(aaroey): consider verifying the corresponding TF function.

  def RunTest(self, run_params):
    if not self.ShouldRunTest(run_params):
      return
    assert run_params.precision_mode in PRECISION_MODES
    np.random.seed(12345)

    params = self._GetParamsCached()
    input_gdef = params.gdef
    input_dtypes = {}
    for node in input_gdef.node:
      if self._ToString(node.name) in params.input_names:
        assert self._ToString(node.op) == "Placeholder"
        input_dtypes[self._ToString(node.name)] = (
            dtypes.as_dtype(node.attr["dtype"].type).as_numpy_dtype())
    assert len(params.input_names) == len(input_dtypes)

    inputs_data = []
    for inp in params.input_dims:
      current_input_data = []
      for i in range(len(params.input_names)):
        dtype = input_dtypes[params.input_names[i]]
        # Multiply the input by some constant to avoid all-zeros input for
        # integer types.
        scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0
        dims = inp[i]
        # TODO(laigd): add debug options. E.g. we can set the input data to be
        # continuous natural numbers:
        # seq = np.arange(np.prod(dims))
        # seq.resize(dims)
        # input_data.append(scale * seq.astype(dtype))
        current_input_data.append(
            (scale * np.random.random_sample(dims)).astype(dtype))
      inputs_data.append(current_input_data)

    # Verify the original graph.
    self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL)

    # Run the original graph without TRT to get the reference result.
    config_no_trt = self._GetConfigProto(run_params, GraphState.ORIGINAL)
    logging.info("Running original graph w/o TRT, config:\n%s",
                 str(config_no_trt))
    ref_result = self._RunGraph(run_params, input_gdef, inputs_data,
                                config_no_trt, GraphState.ORIGINAL)

    # Run calibration if necessary.
    if (IsQuantizationMode(run_params.precision_mode) and
        run_params.use_calibration):
      infer_gdef = self._GetCalibratedInferGraph(run_params, input_gdef,
                                                 inputs_data)
      self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)
    elif not run_params.use_optimizer:
      infer_gdef = self._GetInferGraph(run_params, input_gdef)
      self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)
    else:
      infer_gdef = input_gdef

    # Run inference.
    infer_config = self._GetConfigProto(run_params, GraphState.INFERENCE)
    logging.info("Running final inference graph, config:\n%s",
                 str(infer_config))
    result = self._RunGraph(run_params, infer_gdef, inputs_data, infer_config,
                            GraphState.INFERENCE)
    self.assertAllClose(
        ref_result,
        result,
        atol=self.ExpectedAbsoluteTolerance(run_params),
        rtol=self.ExpectedRelativeTolerance(run_params))

  def testIdempotence(self):
    # Test that applying the TensorRT optimizer or the offline conversion tool
    # multiple times to the same graph results in the same graph.
    #
    # TODO(aaroey): currently the conversion is not deterministic. This is
    # mainly because during tensorflow::ConvertGraphDefToGraph(), the graph
    # uses EdgeSet, which uses a map keyed by Edge*, so the order of the
    # input/output edges of a node is nondeterministic, and thus the order in
    # which the segmenter contracts edges is nondeterministic. Need to evaluate
    # whether we should fix this.
    pass


def _AddTests(test_class):
  """Adds test methods to TfTrtIntegrationTestBase."""

  def _GetTest(run_params):
    """Gets a single test method based on the parameters."""

    def _Test(self):
      logging.info(
          "Running test %s with parameters: use_optimizer=%s, "
          "precision_mode=%s, dynamic_engine=%s",
          "testTfTrt_" + run_params.test_name, run_params.use_optimizer,
          run_params.precision_mode, run_params.dynamic_engine)
      self.RunTest(run_params)

    return _Test

  use_optimizer_options = [False, True]
  dynamic_engine_options = [False, True]
  use_calibration_options = [False, True]
  opts = itertools.product(use_optimizer_options, PRECISION_MODES,
                           dynamic_engine_options, use_calibration_options)
  for (use_optimizer, precision_mode, dynamic_engine, use_calibration) in opts:
    if IsQuantizationMode(precision_mode):
      if use_optimizer:
        # We ignore the use_optimizer option and always use TrtGraphConverter
        # for INT8 mode, so no need to run it twice.
        continue
      if use_calibration and not dynamic_engine:
        # Static engines are still tested with use_calibration=False; with
        # use_calibration=True only dynamic ops are supported, so skip the
        # static case.
        # TODO(aaroey): construction of a static calibration engine is not
        # supported yet.
        continue
    else:
      if use_calibration:
        # Don't calibrate in FP32 or FP16 mode.
        continue

    conversion = "OptimizerConversion" if use_optimizer else "ToolConversion"
    engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine"
    calibration_type = "UseCalibration" if use_calibration else "NoCalibration"
    test_name = "%s_%s_%s_%s" % (conversion, engine_type, precision_mode,
                                 calibration_type)
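    # e.g. test_name == "ToolConversion_DynamicEngine_INT8_UseCalibration".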
    run_params = RunParams(
        use_optimizer=use_optimizer,
        precision_mode=precision_mode,
        dynamic_engine=dynamic_engine,
        test_name=test_name,
        use_calibration=use_calibration)
    setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params))


if is_tensorrt_enabled():
  _AddTests(TfTrtIntegrationTestBase)
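
# A minimal sketch of a concrete test built on this base class. The names
# below (SimpleSingleEngineTest, the node names, the engine name, and the
# array_ops import) are hypothetical and only illustrate the expected
# structure; they are not part of this module:
#
#   from tensorflow.python.ops import array_ops
#
#   class SimpleSingleEngineTest(TfTrtIntegrationTestBase):
#
#     def GetParams(self):
#       input_name = "input"
#       input_dims = [2, 32, 32, 3]
#       g = ops.Graph()
#       with g.as_default():
#         inp = array_ops.placeholder(
#             dtype=dtypes.float32, shape=[None] + input_dims[1:],
#             name=input_name)
#         x = math_ops.multiply(inp, inp, name="mul")
#         x = math_ops.add(x, inp, name="add")
#         array_ops.identity(x, name="output")
#       return TfTrtIntegrationTestParams(
#           gdef=g.as_graph_def(),
#           input_names=[input_name],
#           input_dims=[[input_dims]],
#           output_names=["output"],
#           expected_output_dims=[[input_dims]])
#
#     def ExpectedEnginesToBuild(self, run_params):
#       # Map each expected TRTEngineOp name to the nodes it should absorb.
#       return {"TRTEngineOp_0": ["mul", "add"]}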