1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <cstring>
21 #include <map>
22 #include <memory>
23 #include <set>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/memory/memory.h"
31 #include "absl/strings/match.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/string_view.h"
35 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
36 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
37 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
38 #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
39 #include "tensorflow/core/common_runtime/graph_constructor.h"
40 #include "tensorflow/core/framework/node_def.pb.h"  // NOLINT
41 #include "tensorflow/core/framework/node_def_builder.h"
42 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
43 #include "tensorflow/core/framework/tensor_shape.h"
44 #include "tensorflow/core/framework/tensor_shape.pb.h"  // NOLINT
45 #include "tensorflow/core/framework/types.h"
46 #include "tensorflow/core/graph/algorithm.h"
47 #include "tensorflow/core/graph/graph.h"
48 #include "tensorflow/core/grappler/op_types.h"
49 #include "tensorflow/core/lib/core/errors.h"
50 #include "tensorflow/core/lib/core/status.h"
51 #include "tensorflow/core/lib/strings/numbers.h"
52 #include "tensorflow/core/lib/strings/str_util.h"
53 #include "tensorflow/core/lib/strings/strcat.h"
54 #include "tensorflow/core/platform/logging.h"
55 #include "tensorflow/core/platform/mutex.h"
56 #include "tensorflow/core/platform/protobuf.h"
57 #include "tensorflow/core/platform/tensor_coding.h"
58 #include "tensorflow/core/platform/types.h"
59 #include "tensorflow/core/profiler/lib/annotated_traceme.h"
60 #include "tensorflow/core/public/version.h"
61 #include "tensorflow/core/util/env_var.h"
62 #include "tensorflow/core/util/strided_slice_op.h"
63 
64 #if GOOGLE_CUDA && GOOGLE_TENSORRT
65 #include "third_party/tensorrt/NvInfer.h"
66 #include "third_party/tensorrt/NvInferPlugin.h"
67 
68 // Check if the types are equal. Cast them to int first so that the failure
69 // log message prints the values.
70 #define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
71 
72 #define TFTRT_INTERNAL_ERROR_AT_NODE(node)                           \
73   do {                                                               \
74     return errors::Internal("TFTRT::", __FUNCTION__, ":", __LINE__,  \
75                             " failed to add TRT layer, at: ", node); \
76   } while (0)
77 
78 #define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \
79   do {                                           \
80     if (ptr == nullptr) {                        \
81       TFTRT_INTERNAL_ERROR_AT_NODE(node);        \
82     }                                            \
83   } while (0)
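// Illustrative usage sketch (editorial example, not part of the original
// file): an op converter typically guards a newly added TRT layer with these
// macros, e.g.
//
//   nvinfer1::IShuffleLayer* layer =
//       params->converter->network()->addShuffle(*input_tensor);
//   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.name());
//
// so that a failed layer creation is reported as errors::Internal with the
// enclosing function, line number, and node name.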
84 
85 namespace tensorflow {
86 namespace tensorrt {
87 namespace convert {
88 
89 using absl::StrAppend;
90 using absl::StrCat;
91 
92 namespace {
93 
94 #define ADD_LAYER(layer_name)              \
95   case nvinfer1::LayerType::k##layer_name: \
96     return #layer_name;
97 
98 const char* LayerTypeToString(nvinfer1::LayerType layer_type) {
99   switch (layer_type) {
100     ADD_LAYER(CONVOLUTION)
101     ADD_LAYER(FULLY_CONNECTED)
102     ADD_LAYER(ACTIVATION)
103     ADD_LAYER(POOLING)
104     ADD_LAYER(LRN)
105     ADD_LAYER(SCALE)
106     ADD_LAYER(SOFTMAX)
107     ADD_LAYER(DECONVOLUTION)
108     ADD_LAYER(CONCATENATION)
109     ADD_LAYER(ELEMENTWISE)
110     ADD_LAYER(PLUGIN)
111     ADD_LAYER(RNN)
112     ADD_LAYER(UNARY)
113     ADD_LAYER(PADDING)
114     ADD_LAYER(SHUFFLE)
115     ADD_LAYER(REDUCE)
116     ADD_LAYER(TOPK)
117     ADD_LAYER(GATHER)
118     ADD_LAYER(MATRIX_MULTIPLY)
119     ADD_LAYER(RAGGED_SOFTMAX)
120     ADD_LAYER(CONSTANT)
121     ADD_LAYER(RNN_V2)
122     ADD_LAYER(IDENTITY)
123     ADD_LAYER(PLUGIN_V2)
124     ADD_LAYER(SLICE)
125 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
126     ADD_LAYER(SHAPE)
127     ADD_LAYER(PARAMETRIC_RELU)
128     ADD_LAYER(RESIZE)
129 #endif
130 #if IS_TRT_VERSION_GE(7, 0, 0, 0)
131     ADD_LAYER(TRIP_LIMIT)
132     ADD_LAYER(RECURRENCE)
133     ADD_LAYER(ITERATOR)
134     ADD_LAYER(LOOP_OUTPUT)
135     ADD_LAYER(SELECT)
136     ADD_LAYER(FILL)
137 #endif
138   }
139   return "UNKNOWN_LAYER";
140 }
141 
142 #undef ADD_LAYER
143 
144 // Sets the ILayer name in the form of
145 // <engine_name>/<tf_related_part>:<trt_operation_name>.
146 void SetLayerNameHelper(nvinfer1::ILayer* layer, absl::string_view engine_name,
147                         absl::string_view tf_name) {
148   const char* trt_name = LayerTypeToString(layer->getType());
149   layer->setName(
150       absl::StrCat(engine_name, "/", tf_name, ":", trt_name).c_str());
151 }
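// For example, with engine_name "my_engine" and tf_name "model/conv1", a
// convolution layer passed to SetLayerNameHelper is named
// "my_engine/model/conv1:CONVOLUTION".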
152 
153 // Returns a string in the form of <sub_op_name><sub_op_instance>.
154 std::string GetLayerNameSuffix(absl::string_view sub_op_name,
155                                absl::optional<int> sub_op_instance) {
156   std::string op_suffix(sub_op_name);
157   if (sub_op_instance.has_value()) {
158     op_suffix =
159         absl::StrCat(op_suffix, "_", std::to_string(sub_op_instance.value()));
160   }
161   return op_suffix;
162 }
163 
164 }  // namespace
165 
166 bool IsEngineInput(absl::string_view name) {
167   return absl::StartsWith(name, IONamePrefixes::kInputPHName);
168 }
169 bool IsEngineOutput(absl::string_view name) {
170   return absl::StartsWith(name, IONamePrefixes::kOutputPHName);
171 }
172 
173 class TFAttrs {
174  public:
175   explicit TFAttrs(const NodeDef& tf_node) {
176     for (const auto& attr : tf_node.attr()) {
177       attrs_.insert({attr.first, &attr.second});
178     }
179   }
180 
181   bool count(const string& key) const { return attrs_.count(key); }
182 
183   AttrValue const* at(const string& key) const {
184     if (!attrs_.count(key)) {
185       LOG(FATAL) << "Attribute not found: " << key;
186     }
187     return attrs_.at(key);
188   }
189 
190   template <typename T>
191   T get(const string& key) const;
192 
193   template <typename T>
194   T get(const string& key, const T& default_value) const {
195     return attrs_.count(key) ? this->get<T>(key) : default_value;
196   }
197 
198  private:
199   std::map<string, AttrValue const*> attrs_;
200 };
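// Usage sketch (hypothetical attribute names): for a NodeDef carrying a "T"
// dtype attribute and an optional "data_format" string attribute,
//
//   TFAttrs attrs(node_def);
//   DataType dtype = attrs.get<DataType>("T");
//   string data_format = attrs.get<string>("data_format", "NHWC");
//
// Note that at() (and therefore get() without a default) LOG(FATAL)s on a
// missing attribute, so callers should use count() or supply a default.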
201 
202 template <>
203 string TFAttrs::get<string>(const string& key) const {
204   return this->at(key)->s();
205 }
206 
207 template <>
208 std::vector<int64> TFAttrs::get<std::vector<int64>>(const string& key) const {
209   auto attr = this->at(key)->list().i();
210   return std::vector<int64>(attr.begin(), attr.end());
211 }
212 
213 template <>
214 std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
215   auto attr = this->at(key)->list().f();
216   return std::vector<float>(attr.begin(), attr.end());
217 }
218 
219 template <>
220 nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
221   nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
222   TF_CHECK_OK(TfTypeToTrtType(this->at(key)->type(), &trt_dtype));
223   return trt_dtype;
224 }
225 
226 template <>
227 DataType TFAttrs::get<DataType>(const string& key) const {
228   return this->at(key)->type();
229 }
230 
231 template <>
232 float TFAttrs::get<float>(const string& key) const {
233   return this->at(key)->f();
234 }
235 
236 template <>
237 bool TFAttrs::get<bool>(const string& key) const {
238   return this->at(key)->b();
239 }
240 
241 template <>
242 int64 TFAttrs::get<int64>(const string& key) const {
243   return this->at(key)->i();
244 }
245 
246 template <typename Container>
247 Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out,
248                                  bool ignore_first_dim = false) {
249   PartialTensorShape tensor_shape;
250   TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape));
251   *out = TensorShapeToTrtDims(tensor_shape, ignore_first_dim);
252   return Status::OK();
253 }
254 
255 // TODO(laigd): use this utility function in more places.
256 Status RemoveBatchDimension(nvinfer1::Dims* dims) {
257   if (dims->nbDims < 2) {
258     return errors::InvalidArgument(
259         "Dropping batch dimension requires dims with rank>=2.");
260   }
261   std::copy(dims->d + 1, dims->d + dims->nbDims, dims->d);
262   dims->nbDims--;
263   return Status::OK();
264 }
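// For example, dims {nbDims=4, d={8, 3, 5, 5}} becomes {nbDims=3, d={3, 5, 5}}
// after RemoveBatchDimension; a rank-1 input is rejected as InvalidArgument.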
265 
266 void GetOutputProperties(const grappler::GraphProperties& graph_properties,
267                          const Node* node, const int out_port,
268                          PartialTensorShape* shape, DataType* dtype) {
269   if (graph_properties.HasOutputProperties(node->name())) {
270     auto output_params = graph_properties.GetOutputProperties(node->name());
271     auto out_shape = output_params.at(out_port);
272     *dtype = out_shape.dtype();
273     *shape = out_shape.shape();
274   } else {
275     LOG(INFO) << "Unknown output shape " << node->name();
276     *dtype = node->output_type(out_port);
277   }
278 }
279 
280 void GetInputProperties(const grappler::GraphProperties& graph_properties,
281                         const Node* node, const int in_port,
282                         PartialTensorShape* shape, DataType* dtype) {
283   if (graph_properties.HasInputProperties(node->name())) {
284     auto input_params = graph_properties.GetInputProperties(node->name());
285     auto in_shape = input_params.at(in_port);
286     *dtype = in_shape.dtype();
287     *shape = in_shape.shape();
288   } else {
289     *dtype = node->input_type(in_port);
290   }
291 }
292 
293 // This function checks if a tensor is compatible with TRT.
294 //
295 // We check that the shape and datatype are compatible with TensorRT. We also
296 // return the corresponding trt_dtype, the trt_dims and the batch_size (latter
297 // is only needed in implicit batch mode).
298 //
299 // The return status indicates whether the tensor is compatible.
300 //
301 // For implicit batch mode, when validation_only == false, we also check that
302 // all input dimensions (besides the batch dimension) are known dimensions.
303 Status ValidateTensorProperties(const string& producer_node_type,
304                                 const DataType dtype,
305                                 const PartialTensorShape& shape,
306                                 const bool use_implicit_batch,
307                                 bool validation_only,
308                                 nvinfer1::DataType* trt_dtype,
309                                 nvinfer1::Dims* trt_dims, int* batch_size) {
310   // Convert data type.
311   TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, trt_dtype));
312 
313   // Convert shape.
314   if (shape.dims() < 0) {
315     return errors::InvalidArgument("Input tensor rank is unknown.");
316   }
317   // Add 1 to maximum rank for implicit batch dim.
318   const int max_rank = nvinfer1::Dims::MAX_DIMS + (use_implicit_batch ? 1 : 0);
319   if (shape.dims() > max_rank) {
320     return errors::OutOfRange("Input tensor rank is greater than ", max_rank);
321   }
322   if (use_implicit_batch && (producer_node_type != "Const") &&
323       (shape.dims() < 1)) {
324     return errors::InvalidArgument(
325         "Scalar input tensor is not supported since the first dimension "
326         "is treated as batch dimension by TRT");
327   }
328   *trt_dims = TensorShapeToTrtDims(shape,
329                                    /*ignore_first_dim=*/use_implicit_batch);
330   // Get batch size for tensor if it will not be included in the shape.
331   if (use_implicit_batch) {
332     *batch_size = shape.dim_size(0);
333   }
334 
335   // Don't convert empty tensors (dim value of 0).
336   const int first_trt_dim = use_implicit_batch ? 1 : 0;
337   for (int d = first_trt_dim; d < shape.dims(); ++d) {
338     if (shape.dim_size(d) == 0) {
339       return errors::Unimplemented(
340           "Input tensor with shape ", shape.DebugString(),
341           " is an empty tensor, which is not supported by TRT");
342     }
343   }
344 
345   if (validation_only) return Status::OK();
346 
347   // Following checks are only used during TRT engine creation time.
348   if (use_implicit_batch) {
349     for (int d = first_trt_dim; d < shape.dims(); ++d) {
350       if (shape.dim_size(d) < 0) {
351         return errors::InvalidArgument(
352             "Input tensor with shape ", shape.DebugString(),
353             " has an unknown non-batch dimension at dim ", d);
354       }
355     }
356   }
357   return Status::OK();
358 }
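// Example: in implicit batch mode, a non-Const producer with dtype DT_FLOAT
// and shape [8, 224, 224, 3] yields trt_dtype=kFLOAT, trt_dims={3, {224, 224,
// 3}} and batch_size=8. The same tensor with shape [8, -1, 224, 3] still
// passes validation_only, but is rejected at engine creation time because of
// the unknown non-batch dimension.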
359 
360 Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
361                             const TRT_TensorOrWeights& operand_r,
362                             const bool check_feasibility,
363                             const bool use_implicit_batch,
364                             nvinfer1::Dims* operand_l_new_dims,
365                             nvinfer1::Dims* operand_r_new_dims) {
366   // The TensorRT elementwise op supports broadcast but requires both tensors
367   // to be of identical rank.
368   //
369   // We consider case of:
370   //   1. operand_l to be a Tensor & operand_r to be a Const;
371   //   2. operand_l to be a Tensor & operand_r to be a Tensor;
372   // note: const op const (constant folding) should fall back to TensorFlow
373   //
374   // broadcast scheme:
375   //       T:  1 3 5    (tensor would not have batch dimension)
376   //       W:  1 1 3 1  (weight would have all explicit dimensions)
377   // i. fill in explicit dimensions
378   //    -> T: -1 1 3 5  (we put a -1 for batch dimension)
379   //    -> W:  1 1 3 1
380   // ii. compare broadcast feasibility
381   //
382   // We cannot support the following since TensorRT does not allow manipulation
383   // on batch dimension, we cannot generate output with proper shape
384   //    T: 3 5 1
385   //    W: 1 1 1  1 3 5 1
386   // -> T: 1 1 1 -1 3 5 1
387   // -> W: 1 1 1  1 3 5 1
388   // ***************************************************************************
389   if (!operand_l.is_tensor() && !operand_r.is_tensor()) {
390     return errors::InvalidArgument(
391         "Broadcasting requires at least one of the operands be tensors");
392   }
393 
394   const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1;
395   auto compute_output_dims = [use_implicit_batch](
396                                  const TRT_TensorOrWeights& input,
397                                  int broadcast_num_dims, int* output_dims_array,
398                                  nvinfer1::Dims* output_dims) {
399     const nvinfer1::Dims input_dims = input.GetTrtDims();
400     std::fill(output_dims_array, output_dims_array + max_nb_dims, 1);
401     std::copy(input_dims.d, input_dims.d + input_dims.nbDims,
402               output_dims_array + broadcast_num_dims - input_dims.nbDims);
403     if (use_implicit_batch && input.is_tensor()) {
404       const int true_input_dims = input_dims.nbDims + 1;
405       if (true_input_dims < broadcast_num_dims) {
406         return errors::InvalidArgument(
407             "Broadcasting beyond batch dimension is not supported ",
408             "(tensor #dims ", true_input_dims, " vs broadcast #dims ",
409             broadcast_num_dims, ")");
410       }
411       // Set the batch dimension to -1, since batch size is not supposed to
412       // be broadcasted.
413       output_dims_array[0] = -1;
414     }
415     // Copy to output dimensions
416     if (use_implicit_batch) {
417       // Strip batch dimension while copying
418       output_dims->nbDims = broadcast_num_dims - 1;
419       std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims,
420                 output_dims->d);
421     } else {
422       output_dims->nbDims = broadcast_num_dims;
423       std::copy(output_dims_array, output_dims_array + broadcast_num_dims,
424                 output_dims->d);
425     }
426 
427     return Status::OK();
428   };
429 
430   // Compute the output dimensions.
431   const int broadcast_num_dims =
432       std::max(operand_l.GetTrtDims().nbDims +
433                    (use_implicit_batch && operand_l.is_tensor()),
434                operand_r.GetTrtDims().nbDims +
435                    (use_implicit_batch && operand_r.is_tensor()));
436   int output_l[max_nb_dims], output_r[max_nb_dims];
437   TF_RETURN_IF_ERROR(compute_output_dims(operand_l, broadcast_num_dims,
438                                          output_l, operand_l_new_dims));
439   TF_RETURN_IF_ERROR(compute_output_dims(operand_r, broadcast_num_dims,
440                                          output_r, operand_r_new_dims));
441 
442   // Compare broadcast feasibility
443   if (check_feasibility) {
444     for (int i = 0; i < broadcast_num_dims; ++i) {
445       if (!use_implicit_batch && (output_l[i] == -1 || output_r[i] == -1)) {
446         // If the condition is true then we are in explicit batch mode and (at
447         // least) one of the input dimensions is unknown. In other words, we
448         // are in dynamic shape mode. During conversion time we only see -1 for
449         // the unknown shapes, therefore we cannot decide on the feasibility of
450         // broadcast over the unknown dimensions. Therefore we just continue to
451         // the next dimension. In dynamic shape mode TRT can only check the
452         // feasibility of the broadcast when the actual input dimensions are
453         // specified by SetTrtEngineInputs and the inference job is launched by
454         // TrtEnqueue.
455         continue;
456       }
457       if ((output_l[i] != output_r[i]) && (output_l[i] != 1) &&
458           (output_r[i] != 1)) {
459         return errors::InvalidArgument("Infeasible broadcast scheme (",
460                                        "batch_dim: ", output_l[0], ", ",
461                                        DebugString(*operand_l_new_dims), " vs ",
462                                        "batch_dim: ", output_r[0], ", ",
463                                        DebugString(*operand_r_new_dims), ")");
464       }
465     }
466   }
467   return Status::OK();
468 }
469 
470 nvinfer1::ITensor* Converter::CreateConstantLayer(
471     const TRT_ShapedWeights& weights, const nvinfer1::Dims& dims) {
472   nvinfer1::Weights trt_weights = weights.GetTrtWeights();
473   nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights);
474   if (!layer) return nullptr;
475   SetLayerName(layer, "_tftrt_constant_",
476                std::to_string(next_constant_layer_id_));
477   next_constant_layer_id_++;
478   nvinfer1::ITensor* trt_tensor = layer->getOutput(0);
479 #if !IS_TRT_VERSION_GE(5, 1, 3, 0)
480   // TODO(laigd): there is a bug in the TensorRT 5.0 library that, if we don't
481   // set the data type below, it will always be kFLOAT regardless of what the
482   // data type of the weights is. Once NVIDIA fixes this bug, we should remove
483   // the data type setting logic below and the test should still pass.
484   trt_tensor->setType(trt_weights.type);
485 #endif
486   return trt_tensor;
487 }
488 
489 // Creates a scalar constant and fills it with the given value.
490 template <typename T>
491 Status CreateScalarConstant(
492     OpConverterParams* params, T value, nvinfer1::ITensor** tensor,
493     nvinfer1::DataType trt_type = nvinfer1::DataType::kINT32,
494     const nvinfer1::Dims& dims = {1, {1}}) {
495   TRT_ShapedWeights weights =
496       params->weight_store->GetTempWeights(trt_type, dims);
497   TF_RETURN_IF_ERROR(weights.SetValues(value));
498   *tensor = params->converter->CreateConstantLayer(weights, dims);
499   TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
500   params->converter->ProvideQuantizationRange(*tensor, value, value);
501   return Status::OK();
502 }
503 
504 // Creates a constant with the same rank as dims, where each dimension has
505 // size = 1.
506 Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
507                                          const nvinfer1::Dims& dims,
508                                          nvinfer1::ITensor** tensor,
509                                          const char* dtype_attr_name = "T") {
510   nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT;  // Default to FP32.
511   TFAttrs attrs(params->node_def);
512   if (attrs.count(dtype_attr_name)) {
513     DataType dtype = attrs.get<DataType>(dtype_attr_name);
514     TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, &trt_type));
515   }
516 
517   // In order to be broadcastable, the number of dims has to match.
518   nvinfer1::Dims broadcastable_dims(dims);
519   for (int i = 0; i < broadcastable_dims.nbDims; i++) {
520     broadcastable_dims.d[i] = 1;
521   }
522   return CreateScalarConstant(params, value, tensor, trt_type,
523                               broadcastable_dims);
524 }
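// For example, given dims {3, {2, 4, 5}} the constant is created with dims
// {3, {1, 1, 1}}, which TRT can broadcast against the original tensor in an
// elementwise layer.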
525 
526 // Convert an axis from TF format to TRT format while validating. TF format
527 // includes the batch dimension, while TRT does not if implicit batching is used
528 // (i.e. for tensors). TF can also use negative indices.
529 Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name,
530                    bool use_implicit_batch, int* trt_axis) {
531   const int tf_nb_dims = trt_nb_dims + (use_implicit_batch ? 1 : 0);
532   // Check bounds.
533   if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) {
534     return errors::InvalidArgument(
535         "Axis value of ", tf_axis, " is out of bounds, must be in range [",
536         -tf_nb_dims, ", ", tf_nb_dims, "), at ", node_name);
537   }
538   // Make negative axis positive.
539   if (tf_axis < 0) tf_axis += tf_nb_dims;
540   // Don't allow axis to be the batch dimension.
541   if (use_implicit_batch && tf_axis == 0) {
542     return errors::Unimplemented(
543         "TensorRT does not allow manipulation of the batch dimension, at ",
544         node_name);
545   }
546   // Remove batch dimension if it is implicit.
547   *trt_axis = use_implicit_batch ? tf_axis - 1 : tf_axis;
548   return Status::OK();
549 }
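// Example: with trt_nb_dims=3 and use_implicit_batch=true the TF tensor has
// rank 4, so tf_axis=-1 maps to tf_axis=3 and then trt_axis=2, while
// tf_axis=0 (the batch dimension) is rejected as Unimplemented.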
550 
551 inline bool DimsEqual(const nvinfer1::Dims& dim_l,
552                       const nvinfer1::Dims& dim_r) {
553   if (dim_l.nbDims != dim_r.nbDims) {
554     return false;
555   }
556   for (int i = 0; i < dim_l.nbDims; i++) {
557     if (dim_l.d[i] != dim_r.d[i]) {
558       return false;
559     }
560   }
561   return true;
562 }
563 
564 bool AllLengthsEqual(const std::vector<std::vector<int>>& inputs) {
565   if (inputs.size() == 0) return true;
566   int length = inputs.at(0).size();
567   for (int i = 1; i < inputs.size(); i++) {
568     if (inputs.at(i).size() != length) return false;
569   }
570   return true;
571 }
572 
573 inline nvinfer1::Dims GetTrtDimsForTensor(const Tensor& tensor) {
574   nvinfer1::Dims dims;
575   dims.nbDims = tensor.dims();
576   for (int i = 0; i < dims.nbDims; i++) {
577     dims.d[i] = tensor.dim_size(i);
578   }
579   return dims;
580 }
581 
582 int64_t Prod(const nvinfer1::Dims& dims) {
583   int64_t count = 1;
584   for (int d = 0; d < dims.nbDims; ++d) {
585     count *= dims.d[d];
586   }
587   return count;
588 }
589 
590 // Returns the total number of elements in a TensorRT weights' dimensions.
591 // Returning 0 means either some dim is 0 or the number of dims is 0 (TensorRT
592 // doesn't allow scalar weights).
593 // Note that for a TF scalar constant, we always convert it to dims [1].
594 int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims) {
595   if (dims.nbDims == 0) return 0;
596   return Prod(dims);
597 }
598 
599 // Returns the total number of elements in an ITensor's dimensions.
600 // Returns 1 if the number of dims is 0 (the total number is fully determined by
601 // the batch size).
602 // Returns -1 if any dimension is unknown.
603 int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims) {
604   if (!HasStaticShape(dims)) return -1;
605   return Prod(dims);
606 }
607 
608 bool DimsHaveSameSize(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs,
609                       bool is_tensor) {
610   if (is_tensor) {
611     return TrtTensorDimsNumElements(lhs) == TrtTensorDimsNumElements(rhs);
612   }
613   return TrtWeightDimsNumElements(lhs) == TrtWeightDimsNumElements(rhs);
614 }
615 
616 // Returns whether both dimensions are fully specified and have the same total
617 // number of elements.
618 bool AreDimsStaticWithSameSize(const nvinfer1::Dims& lhs,
619                                const nvinfer1::Dims& rhs, bool is_tensor) {
620   if (!HasStaticShape(lhs) || !HasStaticShape(rhs)) return false;
621   return DimsHaveSameSize(lhs, rhs, is_tensor);
622 }
623 
624 bool AreDimsStaticWithDifferentSize(const nvinfer1::Dims& lhs,
625                                     const nvinfer1::Dims& rhs, bool is_tensor) {
626   if (!HasStaticShape(lhs) || !HasStaticShape(rhs)) return false;
627   return !DimsHaveSameSize(lhs, rhs, is_tensor);
628 }
629 
630 static std::vector<std::pair<int, int>> CreateSamePadding(
631     const nvinfer1::Dims& stride, const nvinfer1::Dims& kernel,
632     const std::vector<int64_t>& input_dims) {
633   std::vector<std::pair<int, int>> padding(input_dims.size());
634   CHECK_EQ(stride.nbDims, input_dims.size());  // TODO(jie): N+C? NC+?
635 
636   for (size_t i = 0; i < input_dims.size(); ++i) {
637     // Formula to calculate the padding
638     int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
639             input_dims[i];
640     p = (p > 0) ? p : 0;
641 
642     // Right precedence padding, like in TensorFlow
643     int left = p / 2;
644     int right = p - left;
645 
646     VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right
647             << "paras: " << input_dims[i] << ", " << stride.d[i] << ", "
648             << "kernel: " << kernel.d[i];
649     padding[i] = {left, right};
650   }
651   return padding;
652 }
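// Worked example for one spatial dimension: input_dim=7, stride=2, kernel=3
// gives p = ((7 - 1) / 2) * 2 + 3 - 7 = 2, split into {left=1, right=1}; any
// odd remainder goes to the right, matching TensorFlow's SAME padding.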
653 
654 string GetCommonNameScope(const string& op_name_a, const string& op_name_b) {
655   size_t last_scope_separator = 0;
656   const size_t min_size = std::min(op_name_a.size(), op_name_b.size());
657   for (size_t i = 0; i < min_size; ++i) {
658     if (op_name_a[i] != op_name_b[i]) break;
659     if (op_name_a[i] == '/') last_scope_separator = i + 1;
660   }
661   return op_name_a.substr(0, last_scope_separator);
662 }
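// For example, GetCommonNameScope("model/block1/conv", "model/block1/bn")
// returns "model/block1/", while two names with no shared scope return "".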
663 
664 // Verifies that shapes of the given inputs match after masking the specified
665 // dimension.
666 Status VerifyShapesMatch(absl::Span<const TRT_TensorOrWeights> inputs,
667                          int masked_dim, absl::string_view node_name) {
668   size_t num_inputs = inputs.size();
669   if (num_inputs <= 1) return Status::OK();
670 
671   const nvinfer1::Dims dims_0 = inputs.at(0).GetTrtDims();
672   for (size_t i = 1; i < num_inputs; ++i) {
673     const nvinfer1::Dims dim_i = inputs.at(i).GetTrtDims();
674     if (dim_i.nbDims != dims_0.nbDims) {
675       return errors::InvalidArgument(
676           "Received inputs with inconsistent rank, at ", node_name);
677     }
678     for (size_t j = 0; j < dims_0.nbDims; ++j) {
679       // Dynamic dimensions will be verified at runtime.
680       if (dim_i.d[j] == -1 || dims_0.d[j] == -1) continue;
681       if (dim_i.d[j] != dims_0.d[j] && j != masked_dim) {
682         return errors::InvalidArgument(
683             "Received inputs with inconsistent shape, at ", node_name);
684       }
685     }
686   }
687   return Status::OK();
688 }
689 
690 TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type) : type_(type) {
691   shape_.nbDims = 0;
692 }
693 
694 TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type,
695                                      nvinfer1::Dims dims, Tensor tensor)
696     : shape_(dims), type_(type), tensor_(tensor) {}
697 
698 TRT_ShapedWeights::TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
699     : shape_(rhs.shape_), type_(rhs.type_), tensor_(rhs.tensor_) {}
700 
701 int64_t TRT_ShapedWeights::count() const {
702   return TrtWeightDimsNumElements(shape_);
703 }
704 
705 nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const {
706   return nvinfer1::Weights{type_, GetValues(), count()};
707 }
708 
709 template <typename T>
710 Status TRT_ShapedWeights::SetValues(T value) {
711   switch (type_) {
712     case nvinfer1::DataType::kFLOAT: {
713       float* ptr = tensor_.flat<float>().data();
714       std::fill(ptr, ptr + count(), value);
715       break;
716     }
717     case nvinfer1::DataType::kHALF: {
718       Eigen::half* ptr = tensor_.flat<Eigen::half>().data();
719       std::fill(ptr, ptr + count(), Eigen::half(value));
720       break;
721     }
722     case nvinfer1::DataType::kINT32: {
723       int32* ptr = tensor_.flat<int32>().data();
724       std::fill(ptr, ptr + count(), value);
725       break;
726     }
727     default:
728       return errors::InvalidArgument("Unsupported data type ",
729                                      tensorflow::tensorrt::DebugString(type_));
730   }
731   return Status::OK();
732 }
733 
734 size_t TRT_ShapedWeights::size_bytes() const {
735   size_t data_type_size = -1;
736   switch (type_) {
737     case nvinfer1::DataType::kFLOAT:
738     case nvinfer1::DataType::kINT32:
739       data_type_size = 4;
740       break;
741     case nvinfer1::DataType::kHALF:
742       data_type_size = 2;
743       break;
744     case nvinfer1::DataType::kINT8:
745 #if IS_TRT_VERSION_GE(7, 0, 0, 0)
746     case nvinfer1::DataType::kBOOL:
747 #endif
748       data_type_size = 1;
749       break;
750   }
751   return this->count() * data_type_size;
752 }
753 
754 string TRT_ShapedWeights::DebugString() const {
755   return StrCat(
756       "TRT_ShapedWeights(shape=", tensorflow::tensorrt::DebugString(shape_),
757       ", type=", tensorflow::tensorrt::DebugString(type_),
758       ", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
759 }
760 
761 // A fake ITensor implementation used to check whether the TF-TRT converter can
762 // handle a specific node. We only need shape and type information, and the
763 // converter won't (and shouldn't) use this to build the TRT network.
764 class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor {
765  public:
766   SimpleITensor(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& trt_dims)
767       : trt_dtype_(trt_dtype), trt_dims_(trt_dims) {}
768 
769   void setName(const char* name) override {}
770 
771   const char* getName() const override { return ""; }
772 
773   void setDimensions(nvinfer1::Dims dimensions) override {
774     trt_dims_ = dimensions;
775   }
776 
777   nvinfer1::Dims getDimensions() const override { return trt_dims_; }
778 
779   void setType(nvinfer1::DataType trt_dtype) override {
780     trt_dtype_ = trt_dtype;
781   }
782 
783   nvinfer1::DataType getType() const override { return trt_dtype_; }
784 
785   bool isNetworkInput() const override { return false; }
786 
787   bool isNetworkOutput() const override { return false; }
788 
789   void setBroadcastAcrossBatch(bool broadcastAcrossBatch) override {}
790 
791   bool getBroadcastAcrossBatch() const override { return false; }
792 
793   nvinfer1::TensorLocation getLocation() const override {
794     // This is arbitrary, since we don't use it.
795     return nvinfer1::TensorLocation::kDEVICE;
796   }
797 
798   void setLocation(nvinfer1::TensorLocation location) override {}
799 
800 #if IS_TRT_VERSION_GE(5, 0, 0, 0)
801   bool setDynamicRange(float min, float max) override { return true; }
802 
803   float getDynamicRange() const override { return 0; }
804 #endif
805 
806 #if IS_TRT_VERSION_GE(5, 1, 0, 0)
807   bool dynamicRangeIsSet() const override { return true; }
808 
809   void resetDynamicRange() override {}
810 
811   float getDynamicRangeMin() const override { return 0.f; }
812 
813   float getDynamicRangeMax() const override { return 0.f; }
814 #endif
815 
816 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
817   void setAllowedFormats(nvinfer1::TensorFormats formats) override {}
818 
819   nvinfer1::TensorFormats getAllowedFormats() const override { return 1; }
820 
821   bool isShapeTensor() const override { return false; }
822 
823   bool isExecutionTensor() const override { return true; }
824 #endif
825 
826  private:
827   nvinfer1::DataType trt_dtype_;
828   nvinfer1::Dims trt_dims_;
829 };
830 
831 TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::ITensor* tensor,
832                                          int batch_size)
833     : tensor_(tensor),
834       batch_size_(batch_size),
835       initialized_(true),
836       is_tensor_(true) {}
837 
838 TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
839                                          const nvinfer1::Dims& trt_dims,
840                                          int batch_size)
841     : simple_itensor_(new SimpleITensor(trt_dtype, trt_dims)),
842       batch_size_(batch_size),
843       initialized_(true),
844       is_tensor_(true) {}
845 
846 TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
847     : weights_(weights), initialized_(true), is_tensor_(false) {}
848 
849 TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
850     : tensor_(rhs.tensor_),
851       simple_itensor_(rhs.simple_itensor_),
852       batch_size_(rhs.batch_size_),
853       weights_(rhs.weights_),
854       initialized_(rhs.initialized_),
855       is_tensor_(rhs.is_tensor_) {}
856 
857 void TRT_TensorOrWeights::operator=(const TRT_TensorOrWeights& rhs) {
858   tensor_ = rhs.tensor_;
859   simple_itensor_ = rhs.simple_itensor_;
860   batch_size_ = rhs.batch_size_;
861   weights_ = rhs.weights_;
862   initialized_ = rhs.initialized_;
863   is_tensor_ = rhs.is_tensor_;
864 }
865 
866 nvinfer1::ITensor* TRT_TensorOrWeights::tensor() const {
867   CHECK(is_tensor());
868   return tensor_ == nullptr ? simple_itensor_.get() : tensor_;
869 }
870 
871 nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
872   if (is_tensor()) {
873     return tensor()->getDimensions();
874   } else {
875     return weights().shape_;
876   }
877 }
878 
879 Status TRT_TensorOrWeights::GetTfType(DataType* tf_type) const {
880   if (is_tensor()) {
881     nvinfer1::DataType trt_type = tensor()->getType();
882     return TrtTypeToTfType(trt_type, tf_type);
883   }
884 
885   if (is_weights()) {
886     *tf_type = weights().GetTensor().dtype();
887     return Status::OK();
888   }
889   return errors::Internal("The object is probably not initialized");
890 }
891 
892 string TRT_TensorOrWeights::DebugString() const {
893   string output = "TRT_TensorOrWeights(type=";
894   if (is_tensor()) {
895     StrAppend(&output, "tensor=", tensorflow::tensorrt::DebugString(*tensor()),
896               ", batch_size=", batch_size_);
897   } else {
898     StrAppend(&output, "weights=", weights_.DebugString());
899   }
900   StrAppend(&output, ")");
901   return output;
902 }
903 
904 // Performs a 5-dimensional reorder of data on the CPU. This is done once at
905 // convert time and does not affect GPU inference performance.
906 // Example: reorder NDHWC (TensorFlow) -> NCDHW (TensorRT).
907 template <typename T>
908 void Reorder5(const nvinfer1::Dims& shape, const T* idata,
909               const nvinfer1::Dims& istrides, T* odata,
910               const nvinfer1::Dims& ostrides) {
911   for (int k = 0; k < shape.d[0]; ++k) {
912     for (int c = 0; c < shape.d[1]; ++c) {
913       for (int d = 0; d < shape.d[2]; ++d) {
914         for (int r = 0; r < shape.d[3]; ++r) {
915           for (int s = 0; s < shape.d[4]; ++s) {
916             odata[k * ostrides.d[0] + c * ostrides.d[1] + d * ostrides.d[2] +
917                   r * ostrides.d[3] + s * ostrides.d[4]] =
918                 idata[k * istrides.d[0] + c * istrides.d[1] +
919                       d * istrides.d[2] + r * istrides.d[3] +
920                       s * istrides.d[4]];
921           }
922         }
923       }
924     }
925   }
926 }
927 
928 // TODO(jie): reorder4 & reorder2 should be merged?
929 // TODO(aaroey): fix the order of parameters.
930 template <typename T>
931 void Reorder4(const nvinfer1::DimsNCHW& shape, const T* idata,
932               const nvinfer1::DimsNCHW& istrides, T* odata,
933               const nvinfer1::DimsNCHW& ostrides) {
934   for (int n = 0; n < shape.n(); ++n) {
935     for (int c = 0; c < shape.c(); ++c) {
936       for (int h = 0; h < shape.h(); ++h) {
937         for (int w = 0; w < shape.w(); ++w) {
938           odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
939                 w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
940                                           h * istrides.h() + w * istrides.w()];
941         }
942       }
943     }
944   }
945 }
946 
947 template <typename T>
948 void Reorder2(const nvinfer1::DimsHW& shape, const T* idata,
949               const nvinfer1::DimsHW& istrides, T* odata,
950               const nvinfer1::DimsHW& ostrides) {
951   for (int h = 0; h < shape.h(); ++h) {
952     for (int w = 0; w < shape.w(); ++w) {
953       odata[h * ostrides.h() + w * ostrides.w()] =
954           idata[h * istrides.h() + w * istrides.w()];
955     }
956   }
957 }
958 
959 // TODO(jie): fallback to tensorflow!!
960 void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
961                    TRT_ShapedWeights* oweights) {
962   const int c = iweights.shape_.d[0];
963   const int k = iweights.shape_.d[1];
964   oweights->shape_.d[0] = k;
965   oweights->shape_.d[1] = c;
966   const nvinfer1::DimsHW istrides = {1, k};
967   const nvinfer1::DimsHW ostrides = {c, 1};
968   switch (iweights.TrtDType()) {
969     case nvinfer1::DataType::kFLOAT: {
970       Reorder2({k, c}, static_cast<float const*>(iweights.GetValues()),
971                istrides, static_cast<float*>(oweights->GetValues()), ostrides);
972       break;
973     }
974     case nvinfer1::DataType::kHALF: {
975       Reorder2({k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
976                istrides, static_cast<Eigen::half*>(oweights->GetValues()),
977                ostrides);
978       break;
979     }
980     default:
981       LOG(FATAL) << "Unsupported type in reorder expected fp32 or fp16 but got "
982                  << DebugString(iweights.TrtDType());
983   }
984 }
985 
986 void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
987                        TRT_ShapedWeights* oweights, const int num_groups) {
988   CHECK(iweights.TrtDType() == oweights->TrtDType());
989   CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
990   // K indexes over output channels, C over input channels, and R and S over the
991   // height and width of the convolution
992   const int r = iweights.shape_.d[0];
993   const int s = iweights.shape_.d[1];
994   // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G
995   const int c = iweights.shape_.d[2] / num_groups;
996   const int k = iweights.shape_.d[3] * num_groups;
997   VLOG(2) << "num_groups: " << num_groups << "c" << iweights.shape_.d[2]
998           << " then " << c << "k" << iweights.shape_.d[3] << " then " << k
999           << "r" << iweights.shape_.d[0] << " then " << r << "s"
1000           << iweights.shape_.d[1] << " then " << s;
1001   oweights->shape_.d[0] = k / num_groups;
1002   oweights->shape_.d[1] = c * num_groups;
1003   oweights->shape_.d[2] = r;
1004   oweights->shape_.d[3] = s;
1005   const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
1006   const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
1007   switch (iweights.TrtDType()) {
1008     case nvinfer1::DataType::kFLOAT: {
1009       Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
1010                istrides, static_cast<float*>(oweights->GetValues()), ostrides);
1011       break;
1012     }
1013     case nvinfer1::DataType::kHALF: {
1014       Reorder4({k, c, r, s},
1015                static_cast<Eigen::half const*>(iweights.GetValues()), istrides,
1016                static_cast<Eigen::half*>(oweights->GetValues()), ostrides);
1017       break;
1018     }
1019 
1020     default:
1021       LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got "
1022                  << DebugString(iweights.TrtDType());
1023   }
1024 }
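// For example (num_groups == 1), a TF Conv2D filter with RSCK shape
// [3, 3, 8, 16] is reordered into KCRS shape [16, 8, 3, 3]; for grouped or
// depthwise convolutions num_groups adjusts C and K to match TRT's GKcRS
// layout as described above.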
1025 
1026 // Initializes a Dims object with an arbitrary number of dimensions.
1027 nvinfer1::Dims InitDimsN(std::initializer_list<int> list) {
1028   nvinfer1::Dims dim;
1029   dim.nbDims = list.size();
1030   std::copy(list.begin(), list.end(), dim.d);
1031   return dim;
1032 }
1033 
1034 // Reorder 3D convolution weights from TF to TRT
1035 void ReorderDRSCKToKCDRS(const TRT_ShapedWeights& iweights,
1036                          TRT_ShapedWeights* oweights, const int num_groups) {
1037   DCHECK(iweights.TrtDType() == oweights->TrtDType());
1038   CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
1039   // K indexes over output channels, C over input channels, and R, S, D over the
1040 // height, width, and depth of the convolution.
1041   const int d = iweights.shape_.d[0];
1042   const int r = iweights.shape_.d[1];
1043   const int s = iweights.shape_.d[2];
1044   // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G
1045   const int c = iweights.shape_.d[3] / num_groups;
1046   const int k = iweights.shape_.d[4] * num_groups;
1047 
1048   VLOG(2) << "num_groups: " << num_groups << ", c: " << iweights.shape_.d[3]
1049           << " becomes " << c << ", k: " << iweights.shape_.d[4] << " becomes "
1050           << k << ", d: " << d << ", r: " << r << ", s: " << s;
1051 
1052   oweights->shape_.d[0] = iweights.shape_.d[4];  // k / num_groups;
1053   oweights->shape_.d[1] = iweights.shape_.d[3];  // c * num_groups;
1054   oweights->shape_.d[2] = d;
1055   oweights->shape_.d[3] = r;
1056   oweights->shape_.d[4] = s;
1057 
1058   nvinfer1::Dims shape =
1059       InitDimsN({k, c, d, r, s});  // KCDRS shape (same as output)
1060 
1061   nvinfer1::Dims ostrides =
1062       InitDimsN({c * d * r * s, d * r * s, r * s, s,
1063                  1});  // Output = KCDRS = k*CDRS + c*DRS + d*RS + r*S + s
1064 
1065   nvinfer1::Dims istrides =
1066       InitDimsN({1, k, r * s * c * k, s * c * k,
1067                  c * k});  // Input = DRSCK = k*1 + c*K + d*RSCK + r*SCK + s*CK
1068 
1069   switch (iweights.TrtDType()) {
1070     case nvinfer1::DataType::kFLOAT: {
1071       Reorder5(shape, static_cast<float const*>(iweights.GetValues()), istrides,
1072                static_cast<float*>(oweights->GetValues()), ostrides);
1073       break;
1074     }
1075     case nvinfer1::DataType::kHALF: {
1076       Reorder5(shape, static_cast<Eigen::half const*>(iweights.GetValues()),
1077                istrides, static_cast<Eigen::half*>(oweights->GetValues()),
1078                ostrides);
1079       break;
1080     }
1081     default:
1082       LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got "
1083                  << DebugString(iweights.TrtDType());
1084   }
1085 }
1086 
1087 TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype,
1088                                                  const nvinfer1::Dims& dims) {
1089   TensorShape shape;
1090   DataType tf_dtype;
1091   // TODO(laigd): make it return a status.
1092   TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape));
1093   TF_CHECK_OK(TrtTypeToTfType(trt_dtype, &tf_dtype));
1094   // TODO(jie): check weights size_bytes. 0 means type error
1095   Tensor tensor(tf_dtype, shape);
1096   TRT_ShapedWeights weights(trt_dtype, dims, tensor);
1097   store_.emplace_back(std::move(tensor));
1098   return weights;
1099 }
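// Usage sketch (see CreateScalarConstant above): converters allocate scratch
// weights from the store and fill them in place, e.g.
//
//   TRT_ShapedWeights weights =
//       params->weight_store->GetTempWeights(nvinfer1::DataType::kFLOAT, dims);
//   TF_RETURN_IF_ERROR(weights.SetValues(0.0f));
//
// The backing Tensor is also kept in store_, so the buffer stays valid for the
// lifetime of the weight store while the TRT network is built.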
1100 
1101 OpConverterParams::OpConverterParams(
1102     const NodeDef& node_def, const std::vector<TRT_TensorOrWeights>& inputs,
1103     std::vector<TRT_TensorOrWeights>* outputs, TrtWeightStore* weight_store,
1104     TrtPrecisionMode precision_mode, bool use_calibration,
1105     bool use_implicit_batch)
1106     : node_def(node_def),
1107       inputs(inputs),
1108       outputs(outputs),
1109       validation_only(true),
1110       weight_store(weight_store),
1111       precision_mode(precision_mode),
1112       use_calibration(use_calibration),
1113       use_implicit_batch(use_implicit_batch) {}
1114 
1115 OpConverterParams::OpConverterParams(
1116     Converter* converter, const NodeDef& node_def,
1117     const std::vector<TRT_TensorOrWeights>& inputs,
1118     std::vector<TRT_TensorOrWeights>* outputs, TrtWeightStore* weight_store)
1119     : converter(converter),
1120       node_def(node_def),
1121       inputs(inputs),
1122       outputs(outputs),
1123       validation_only(false),
1124       weight_store(weight_store),
1125       precision_mode(converter->precision_mode()),
1126       use_calibration(converter->use_calibration()),
1127       use_implicit_batch(converter->use_implicit_batch()) {}
1128 
1129 const std::set<string>* TrtNodeValidator::quantize_ops = new std::set<string>{
1130     "QuantizeAndDequantizeV2",
1131     "QuantizeAndDequantizeV3",
1132     "FakeQuantWithMinMaxVars",
1133     "FakeQuantWithMinMaxArgs",
1134 };
1135 
1136 TrtNodeValidator::TrtNodeValidator(
1137     const grappler::GraphProperties& graph_properties,
1138     TrtPrecisionMode precision_mode, bool use_calibration,
1139     bool use_implicit_batch)
1140     : graph_properties_(graph_properties),
1141       precision_mode_(precision_mode),
1142       use_calibration_(use_calibration),
1143       use_implicit_batch_(use_implicit_batch) {
1144   RegisterOpValidators();
1145 }
1146 
1147 Status TrtNodeValidator::ConvertToTensorOrWeights(
1148     const NodeDef& node_def, int output_port,
1149     TRT_TensorOrWeights* tensor_or_weights) {
1150   if (node_def.op() == "Const") {
1151     if (output_port != 0) {
1152       return errors::InvalidArgument("Const node should only have one output.");
1153     }
1154     // The output of the conversion will be used as input to other nodes to
1155     // determine whether TRT supports those nodes. If it cannot convert the
1156     // Const, it's very likely we cannot treat it as a tensor and make it an
1157     // input to the TRT network, since TRT removes the first dimension and
1158     // treats it as batch size. Also, it's not likely that the converter can
1159     // support the op, and performance may suffer even if it can, so we just
1160     // simply return error if the conversion fails.
1161     std::vector<TRT_TensorOrWeights> inputs;
1162     return ConvertConstToWeights(node_def, inputs, tensor_or_weights);
1163   }
1164   if (!graph_properties_.HasOutputProperties(node_def.name())) {
1165     return errors::InvalidArgument("Shape and data type are unknown");
1166   }
1167 
1168   // Validate and convert shape and dtype.
1169   const auto& output_params =
1170       graph_properties_.GetOutputProperties(node_def.name());
1171   const auto& tensor_properties = output_params.at(output_port);
1172   const DataType dtype = tensor_properties.dtype();
1173   const PartialTensorShape shape = tensor_properties.shape();
1174   nvinfer1::DataType trt_dtype;
1175   nvinfer1::Dims trt_dims;
1176   int batch_size = -1;
1177   TF_RETURN_IF_ERROR(ValidateTensorProperties(
1178       node_def.op(), dtype, shape, use_implicit_batch_,
1179       /*validation_only=*/true, &trt_dtype, &trt_dims, &batch_size));
1180 
1181   // Adds a fake ITensor. This is fine since the op converter operates in
1182   // validation-only mode and it won't (and shouldn't) use the tensor to do
1183   // any TRT network operations.
1184   *tensor_or_weights = TRT_TensorOrWeights(trt_dtype, trt_dims, batch_size);
1185   return Status::OK();
1186 }
1187 
1188 Status TrtNodeValidator::IsTensorRTCandidate(const Node* node) {
1189   const string& op = node->def().op();
1190   // In INT8 mode, we will always apply the quantization ranges provided by
1191   // these ops to the relevant tensors. This happens regardless of the value of
1192   // use_calibration.
1193   bool is_supported_op = false;
1194   if (quantize_ops->count(op)) {
1195     is_supported_op = (precision_mode_ == TrtPrecisionMode::INT8);
1196   } else {
1197     is_supported_op = op_validators_.count(op);
1198   }
1199   if (!is_supported_op) {
1200     return errors::Unimplemented("Op type ", op, " is not supported.");
1201   }
1202 
1203   // Convert input NodeDef and corresponding output ports to
1204   // TRT_TensorOrWeights.
1205   std::vector<TRT_TensorOrWeights> inputs;
1206   std::vector<const Edge*> input_edges;
1207   TF_RETURN_IF_ERROR(node->input_edges(&input_edges));
1208   for (const Edge* edge : input_edges) {
1209     TRT_TensorOrWeights tensor_or_weights;
1210     const NodeDef& src_def = edge->src()->def();
1211     Status status = ConvertToTensorOrWeights(src_def, edge->src_output(),
1212                                              &tensor_or_weights);
1213     if (!status.ok()) {
1214       return errors::Internal(
1215           "Failed to convert input ", src_def.name(),
1216           " to a TRT_TensorOrWeights: ", status.error_message());
1217     }
1218     inputs.push_back(tensor_or_weights);
1219   }
1220 
1221   OpConverter validator = op_validators_[op];
1222   OpConverterParams params(node->def(), inputs, /*arg_outputs=*/nullptr,
1223                            &weight_store_, precision_mode_, use_calibration_,
1224                            use_implicit_batch_);
1225   return validator(&params);
1226 }
1227 
1228 Status TrtNodeValidator::ConvertConstToWeights(
1229     const NodeDef& const_node_def,
1230     const std::vector<TRT_TensorOrWeights>& inputs,
1231     TRT_TensorOrWeights* output) {
1232   std::vector<TRT_TensorOrWeights> outputs;
1233   OpConverterParams params(const_node_def, inputs, &outputs, &weight_store_,
1234                            precision_mode_, use_calibration_,
1235                            use_implicit_batch_);
1236   Status status = op_validators_["Const"](&params);
1237   if (status.ok() && output) *output = outputs[0];
1238   return status;
1239 }
1240 
1241 // static
1242 StatusOr<std::unique_ptr<Converter>> Converter::Create(
1243     TrtPrecisionMode precision_mode, bool use_calibration,
1244     nvinfer1::ILogger* trt_logger, const bool use_implicit_batch,
1245     absl::string_view engine_name) {
1246   std::unique_ptr<Converter> converter = absl::WrapUnique(
1247       new Converter(precision_mode, use_calibration, trt_logger,
1248                     use_implicit_batch, engine_name));
1249   TF_RETURN_IF_ERROR(converter->Init(trt_logger));
1250   return converter;
1251 }
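
// Illustrative sketch (editor note, not part of the original file): creating
// a Converter for an FP16 engine could look like the following; trt_logger
// and the engine name are hypothetical.
//
//   StatusOr<std::unique_ptr<Converter>> converter = Converter::Create(
//       TrtPrecisionMode::FP16, /*use_calibration=*/false, trt_logger,
//       /*use_implicit_batch=*/true, /*engine_name=*/"my_trt_engine");
//   TF_RETURN_IF_ERROR(converter.status());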
1252 
1253 Converter::Converter(TrtPrecisionMode precision_mode, bool use_calibration,
1254                      nvinfer1::ILogger* trt_logger,
1255                      const bool use_implicit_batch,
1256                      absl::string_view engine_name)
1257     : precision_mode_(precision_mode),
1258       use_calibration_(use_calibration),
1259       use_implicit_batch_(use_implicit_batch),
1260       engine_name_(engine_name) {
1261   MaybeInitializeTrtPlugins(trt_logger);
1262   this->RegisterOpConverters();
1263 }
1264 
1265 Status Converter::Init(nvinfer1::ILogger* trt_logger) {
1266   VLOG(1) << "Creating TensorRT builder";
1267   trt_builder_.reset(nvinfer1::createInferBuilder(*trt_logger));
1268 
1269   VLOG(1) << "Creating TensorRT network";
1270 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
1271   const uint32_t flags =
1272       use_implicit_batch_
1273           ? 0U
1274           : (1U << static_cast<int>(
1275                  nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
1276   trt_network_.reset(trt_builder_->createNetworkV2(flags));
1277 #else
1278   trt_network_.reset(trt_builder_->createNetwork());
1279 #endif
1280   if (!trt_network_) {
1281     return errors::Internal("Failed to create TensorRT network object");
1282   }
1283   return Status::OK();
1284 }
1285 
1286 Status Converter::ConvertNode(const NodeDef& node_def) {
1287   std::vector<TRT_TensorOrWeights> inputs, outputs;
1288   TF_RETURN_IF_ERROR(this->GetInputs(node_def, &inputs));
1289 
1290   OpConverterParams params(this, node_def, inputs, &outputs, &weight_store_);
1291   const string& op = node_def.op();
1292   auto itr = op_registry_.find(op);
1293   if (itr == op_registry_.end()) {
1294     return errors::Unimplemented("No converter registered for op: ", op);
1295   }
1296   OpConverter op_converter = itr->second;
1297   TF_RETURN_IF_ERROR(op_converter(&params));
1298 
1299   for (size_t i = 0; i < outputs.size(); ++i) {
1300     TRT_TensorOrWeights& output = outputs[i];
1301     string output_name = node_def.name();
1302     if (i != 0) absl::StrAppend(&output_name, ":", i);
1303     // We need to check the name before setting it. If the input is one of the
1304     // engine inputs, setting the name here would overwrite the engine input
1305     // bindings, which would cause a runtime error.
1306     // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer
1307     // in ConvertIdentity.
1308     if (output.is_tensor()) {
1309       const char* tensor_name = output.tensor()->getName();
1310       if (!IsEngineInput(tensor_name)) {
1311         // TRT initializes tensor names as "(Unnamed ITensor* N)". We rename
1312         // them to match their corresponding TensorFlow name.
1313         // Note: ITensors that we create internally within TF-TRT which are
1314         // not inputs or outputs of a node will not be renamed. This is a
1315         // potential cause of confusion if an error message or warning
1316         // mentions the unnamed tensor.
1317         output.tensor()->setName(output_name.c_str());
1318       }
1319     }
1320     VLOG(2) << "Adding out tensor " << output_name << ": "
1321             << output.DebugString();
1322     Status status = AddTensorOrWeights(output_name, output);
1323     if (!status.ok()) {
1324       return Status(status.code(),
1325                     StrCat("Failed to add output for node ", node_def.name(),
1326                            ": ", status.error_message()));
1327     }
1328   }
1329   return Status::OK();
1330 }
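
// Worked example for the output-naming loop above (editor note, not original
// code): a NodeDef named "split" with three outputs registers its ITensors as
//
//   AddTensorOrWeights("split",   outputs[0]);
//   AddTensorOrWeights("split:1", outputs[1]);
//   AddTensorOrWeights("split:2", outputs[2]);
//
// which matches how downstream NodeDefs name their inputs (see GetInputs()).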
1331 
1332 Status Converter::AddInputTensor(const string& name, nvinfer1::DataType dtype,
1333                                  const nvinfer1::Dims& dims, int batch_size) {
1334   // We verify the batch size only for the input nodes, and rely on the
1335   // individual op converters to ensure the batch size of the outputs is not
1336   // changed. TODO(laigd): we need to test this property.
1337   Status status;
1338   if (use_implicit_batch_) {
1339     status = MaybeUpdateBatchSize(batch_size);
1340     if (!status.ok()) {
1341       return Status(status.code(),
1342                     StrCat("Batch size doesn't match for tensor ", name, ": ",
1343                            status.error_message()));
1344     }
1345   }
1346   nvinfer1::ITensor* tensor = network()->addInput(name.c_str(), dtype, dims);
1347   if (tensor == nullptr) {
1348     return errors::InvalidArgument("Failed to create Input layer tensor ", name,
1349                                    " rank=", dims.nbDims);
1350   }
1351   status = AddTensorOrWeights(name, TRT_TensorOrWeights(tensor));
1352   if (!status.ok()) {
1353     return Status(status.code(), StrCat("Failed to add input tensor ", name,
1354                                         ": ", status.error_message()));
1355   }
1356   return Status::OK();
1357 }
1358 
1359 Status Converter::RenameAndMarkOutputTensors(
1360     const std::vector<Converter::EngineOutputInfo>& output_tensors) {
1361   int output_index = 0;
1362   for (const auto& output : output_tensors) {
1363     TRT_TensorOrWeights tensor_or_weights;
1364     TF_RETURN_IF_ERROR(
1365         GetTensorOrWeights(output.source_tensor_name, &tensor_or_weights));
1366     if (!tensor_or_weights.is_tensor()) {
1367       return errors::InvalidArgument("Output ", output.source_tensor_name,
1368                                      " is weights not tensor");
1369     }
1370     nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
1371     if (tensor == nullptr) {
1372       return errors::NotFound("Output tensor not found: ",
1373                               output.source_tensor_name);
1374     }
1375     // Check if this tensor has already been marked as an input or output.
1376     //
1377     // ConvertIdentity can cause the same tensor to be repeated in
1378     // output_tensors, which can cause us to overwrite the name of the output
1379     // tensor binding. For example, if we rename OutputPH_0 to OutputPH_1 then
1380     // we won't be able to locate OutputPH_0 at runtime. To fix this, we
1381     // duplicate the tensor using a no-op shuffle.
1382     //
1383     // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer
1384     // in ConvertIdentity.
1385     if (IsEngineInput(tensor->getName()) || IsEngineOutput(tensor->getName())) {
1386       // Using shuffle layer for identity by not setting reshape or transpose.
1387       nvinfer1::IShuffleLayer* layer = network()->addShuffle(*tensor);
1388       TFTRT_RETURN_ERROR_IF_NULLPTR(
1389           layer, StrCat("Output Copy for ", tensor->getName()));
1390       SetLayerName(layer, tensor->getName(), "shuffle", output_index);
1391       MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0));
1392       tensor = layer->getOutput(0);
1393     }
1394     tensor->setName(output.dest_node_name.c_str());
1395     network()->markOutput(*tensor);
1396     // Set type after marking as output. TRT only supports setType for engine
1397     // outputs and inputs (type is inferred otherwise).
1398     tensor->setType(output.trt_dtype);
1399     output_index++;
1400     VLOG(1) << "Marking output TRT tensor " << output.source_tensor_name
1401             << " with data type " << DebugString(output.trt_dtype)
1402             << ", which feeds TF node " << output.dest_node_name;
1403   }
1404   if (VLOG_IS_ON(2)) {
1405     VLOG(2) << "Created TensorRT network with the following layers:";
1406     for (int i = 0; i < network()->getNbLayers(); i++) {
1407       auto layer = network()->getLayer(i);
1408       VLOG(2) << "    " << layer->getName() << " ("
1409               << "type: " << static_cast<int>(layer->getType())
1410               << ", precision: " << static_cast<int>(layer->getPrecision())
1411               << ")";
1412     }
1413   }
1414   return Status::OK();
1415 }
1416 
1417 #if IS_TRT_VERSION_GE(7, 1, 3, 0)
1418 // An algorithm selector that always returns a specific ID for selectAlgorithms.
1419 // This is used to implement the `TF_TRT_FIXED_ALGORITHM_ID` environment
1420 // variable for debugging TensorRT.
1421 class StaticAlgorithmSelector : public nvinfer1::IAlgorithmSelector {
1422  private:
1423   int32_t algorithm_id_;
1424 
1425  public:
1426   StaticAlgorithmSelector(int32_t algorithm_id) : algorithm_id_(algorithm_id) {}
1427 
1428   // Returns value in [0, nbChoices] for a valid algorithm.
1429   int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& algoContext,
1430                            const nvinfer1::IAlgorithm* const* algoChoices,
1431                            int32_t nbChoices, int32_t* selection) override {
1432     // TensorRT always provides at least one algorithm choice
1433     // in selectAlgorithms.
1434     assert(nbChoices > 0);
1435 
1436     // Make sure that the requested TRT algorithm ID doesn't exceed the
1437     // maximum value accepted.
1438     selection[0] = std::min(algorithm_id_, nbChoices);
1439     return 1;
1440   }
1441 
1442   // Called by TensorRT to report choices it made.
1443   void reportAlgorithms(const nvinfer1::IAlgorithmContext* const* algoContexts,
1444                         const nvinfer1::IAlgorithm* const* algoChoices,
1445                         int32_t nbAlgorithms) override {}  // do nothing
1446 };
1447 #endif
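
// Illustrative usage of the debugging hook above (editor note, not original
// code): this mirrors what BuildCudaEngine() does below when the
// TF_TRT_FIXED_ALGORITHM_ID environment variable is set to a non-negative
// value.
//
//   StaticAlgorithmSelector selector(/*algorithm_id=*/2);
//   builder_config->setAlgorithmSelector(&selector);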
1448 
1449 Status Converter::BuildCudaEngine(
1450     TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, int max_batch_size,
1451     size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator,
1452     TRTInt8Calibrator* calibrator, TrtShapeOptimizationProfile* profiles) {
1453   tensorflow::profiler::AnnotatedTraceMe activity(
1454       [&]() {
1455         return tensorflow::profiler::TraceMeOpOverride("TRTEngineOp",
1456                                                        "BuildEngine");
1457       },
1458       tensorflow::profiler::TraceMeLevel::kInfo);
1459 
1460   VLOG(1) << "Configuring TensorRT builder";
1461   trt_builder_->setMaxBatchSize(max_batch_size);
1462   trt_builder_->setGpuAllocator(allocator);
1463 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
1464   // Create a network configuration and use it to build a TRT engine.
1465   TrtUniquePtrType<nvinfer1::IBuilderConfig> builder_config(
1466       trt_builder_->createBuilderConfig());
1467   builder_config->setMaxWorkspaceSize(max_workspace_size_bytes);
1468 
1469 #if IS_TRT_VERSION_GE(7, 1, 3, 0)
1470   static int32_t trt_algorithm_id = [] {
1471     int64 trt_algorithm_id;
1472     TF_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_TRT_FIXED_ALGORITHM_ID",
1473                                                 /*default_val=*/-1,
1474                                                 &trt_algorithm_id));
1475     return static_cast<int32_t>(trt_algorithm_id);
1476   }();
1477 
1478   if (trt_algorithm_id >= 0) {
1479     VLOG(1) << "Forcing TRT algorithm selection to: ID=" << trt_algorithm_id;
1480     StaticAlgorithmSelector trt_algorithm_selector(trt_algorithm_id);
1481     builder_config->setAlgorithmSelector(&trt_algorithm_selector);
1482   }
1483 #endif
1484 
1485   if (precision_mode_ == TrtPrecisionMode::FP16) {
1486     builder_config->setFlag(nvinfer1::BuilderFlag::kFP16);
1487   } else if (precision_mode_ == TrtPrecisionMode::INT8) {
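    // Also enable FP16 so that TRT can fall back to FP16 kernels where they
    // are faster than INT8 or where INT8 is not supported for a layer (same
    // rationale as in the pre-TRT6 path below).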
1488     builder_config->setFlag(nvinfer1::BuilderFlag::kFP16);
1489     builder_config->setFlag(nvinfer1::BuilderFlag::kINT8);
1490     if (use_calibration_) {
1491       builder_config->setInt8Calibrator(calibrator);
1492     } else {
1493       builder_config->setInt8Calibrator(nullptr);
1494     }
1495   }
1496   if (!use_implicit_batch_ && profiles) {
1497     TF_RETURN_IF_ERROR(profiles->ConfigureBuilder(
1498         trt_builder_.get(), builder_config.get(), network()));
1499   }
1500 
1501   string precision_mode_str;
1502   TF_RETURN_IF_ERROR(
1503       TrtPrecisionModeToName(precision_mode_, &precision_mode_str));
1504   string trt_network_name = StrCat(
1505       "TF:", TF_VERSION_STRING, ", ",
1506       "TRT:", absl::StrJoin(GetLoadedTensorRTVersion(), "."), "-",
1507       "Precision:", precision_mode_str, ", ", "Calibration:", use_calibration_,
1508       ", ", "Max-Batch-Size:", max_batch_size, ", ",
1509       "Max-Workspace-Size:", max_workspace_size_bytes);
1510   VLOG(1) << "Setting TensorRT network name to " << trt_network_name;
1511   network()->setName(trt_network_name.c_str());
1512 
1513   VLOG(1) << "Building TensorRT engine";
1514   engine->reset(
1515       trt_builder_->buildEngineWithConfig(*network(), *builder_config));
1516 #else
1517   trt_builder_->setMaxWorkspaceSize(max_workspace_size_bytes);
1518   if (precision_mode_ == TrtPrecisionMode::FP16) {
1519     trt_builder_->setFp16Mode(true);
1520   } else if (precision_mode_ == TrtPrecisionMode::INT8) {
1521     // Setting FP16 mode as well allows TRT to also consider FP16 kernels and
1522     // use them in situations where they are faster than INT8 or where INT8 is
1523     // not supported for a given layer.
1524     trt_builder_->setFp16Mode(true);
1525     trt_builder_->setInt8Mode(true);
1526     if (use_calibration_) {
1527       trt_builder_->setInt8Calibrator(calibrator);
1528     } else {
1529       trt_builder_->setInt8Calibrator(nullptr);
1530     }
1531   }
1532 
1533   VLOG(1) << "Building TensorRT engine";
1534   engine->reset(trt_builder_->buildCudaEngine(*network()));
1535 #endif
1536   if (engine->get() == nullptr) {
1537     return errors::Internal("Failed to build TensorRT engine");
1538   }
1539   return Status::OK();
1540 }
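
// Illustrative sketch (editor note, not part of the original file): once all
// nodes have been converted and the outputs marked, the engine is typically
// built as below. The allocator, calibrator and profiles variables are
// hypothetical and may be null where the API allows it.
//
//   TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
//   TF_RETURN_IF_ERROR(converter->BuildCudaEngine(
//       &engine, /*max_batch_size=*/8,
//       /*max_workspace_size_bytes=*/1 << 30, allocator, calibrator,
//       profiles));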
1541 
1542 Status Converter::MaybeUpdateBatchSize(int batch_size) {
1543   // OK iff either is unknown or the two are equal.
1544   if (this->batch_size_ < 0 || batch_size < 0 ||
1545       this->batch_size_ == batch_size) {
1546     if (this->batch_size_ < 0 && batch_size >= 0) {
1547       this->batch_size_ = batch_size;
1548     }
1549     return Status::OK();
1550   }
1551   return errors::InvalidArgument(
1552       "Provided batch size does not match converter batch size: ", batch_size,
1553       " vs ", batch_size_);
1554 }
1555 
1556 Status Converter::AddTensorOrWeights(const string& name,
1557                                      TRT_TensorOrWeights input) {
1558   // Set the batch size of the tensor, using batch size collected from the
1559   // input tensors to the TRT subgraph at the beginning of the conversion.
1560   // We rely on the individual op converter to understand the semantics of the
1561   // TF node, and make sure it doesn't change the batch size nor introduce
1562   // intra-element dependency inside the batch.
1563   if (use_implicit_batch_ && input.is_tensor()) {
1564     input.set_batch_size(batch_size_);
1565   }
1566   if (trt_tensors_.insert({name, std::move(input)}).second) return Status::OK();
1567   return errors::AlreadyExists("tensor/weights ", name, " already exist.");
1568 }
1569 
1570 Status Converter::GetTensorOrWeights(const string& name,
1571                                      TRT_TensorOrWeights* output) {
1572   if (!trt_tensors_.count(name)) {
1573     return errors::NotFound("Tensor or weights with name ", name,
1574                             " could not be found.");
1575   }
1576   *output = trt_tensors_.at(name);
1577   return Status::OK();
1578 }
1579 
1580 Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor,
1581                                   const std::vector<int>& order_with_batch_dim,
1582                                   nvinfer1::ITensor** output_tensor,
1583                                   const NodeDef& node_def,
1584                                   absl::string_view sub_op_name) {
1585   const auto dims = input_tensor->getDimensions();
1586   const int order_size = use_implicit_batch_ ? order_with_batch_dim.size() - 1
1587                                              : order_with_batch_dim.size();
1588   if (order_size != size_t(dims.nbDims)) {
1589     return errors::InvalidArgument(
1590         "Rank of perm for transpose does not match with that of the input.");
1591   }
1592   if (use_implicit_batch_ && order_with_batch_dim[0] != 0) {
1593     return errors::Unimplemented(
1594         "Transpose at batch dimension is not supported.");
1595   }
1596 
1597   nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
1598   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Transpose");
1599   SetLayerName(layer, node_def, sub_op_name);
1600 
1601   MarkQuantizationRangesAsInferrable(input_tensor, layer->getOutput(0));
1602 
1603   nvinfer1::Permutation permutation;
1604   if (use_implicit_batch_) {
1605     for (int32_t i = 0; i < dims.nbDims; ++i) {
1606       permutation.order[i] = order_with_batch_dim[i + 1] - 1;
1607     }
1608   } else {
1609     std::copy(order_with_batch_dim.begin(), order_with_batch_dim.end(),
1610               permutation.order);
1611   }
1612   VLOG(1) << "TransposeTensor permutation: "
1613           << DebugString(permutation, dims.nbDims);
1614   layer->setFirstTranspose(permutation);
1615 
1616   nvinfer1::Dims reshape_dims;
1617   reshape_dims.nbDims = dims.nbDims;
1618   for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
1619     reshape_dims.d[i] = 0;
1620     // TODO(aaroey): why not transposing the types as well?
1621     reshape_dims.type[i] = dims.type[i];
1622   }
1623   layer->setReshapeDimensions(reshape_dims);
1624 
1625   *output_tensor = layer->getOutput(0);
1626   return Status::OK();
1627 }
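
// Worked example for the implicit-batch branch above (editor note, not
// original code): transposing an NHWC tensor to NCHW passes
// order_with_batch_dim = {0, 3, 1, 2}. Since TRT dims exclude the batch
// dimension in implicit-batch mode, the permutation handed to the shuffle
// layer is order_with_batch_dim[i + 1] - 1 for each i, i.e. {2, 0, 1}.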
1628 
1629 Status Converter::GetWeightRange(const TRT_ShapedWeights& weights,
1630                                  float* out_min, float* out_max) const {
1631   switch (weights.TrtDType()) {
1632     case nvinfer1::DataType::kFLOAT: {
1633       auto inp = static_cast<float const*>(weights.GetValues());
1634       auto result = std::minmax_element(inp, inp + weights.count());
1635       *out_min = *result.first;
1636       *out_max = *result.second;
1637       break;
1638     }
1639     case nvinfer1::DataType::kHALF: {
1640       auto inp = static_cast<Eigen::half const*>(weights.GetValues());
1641       auto result = std::minmax_element(inp, inp + weights.count());
1642       *out_min = Eigen::half_impl::half_to_float(*result.first);
1643       *out_max = Eigen::half_impl::half_to_float(*result.second);
1644       break;
1645     }
1646     case nvinfer1::DataType::kINT32: {
1647       auto inp = static_cast<int const*>(weights.GetValues());
1648       auto result = std::minmax_element(inp, inp + weights.count());
1649       *out_min = static_cast<float>(*result.first);
1650       *out_max = static_cast<float>(*result.second);
1651       break;
1652     }
1653     default:
1654       return errors::Unimplemented(
1655           "Data type not supported for GetWeightRange: ",
1656           DebugString(weights.TrtDType()));
1657   }
1658   return Status::OK();
1659 }
1660 
1661 // Constructs <tf_related_part> for the ILayer name as
1662 // <tf_node_def_name>_<sub_op_name>_<sub_op_instance> and calls
1663 // SetLayerNameHelper to set the name for the ILayer.
1664 //
1665 // If the operation represented by the ILayer is generated by the converter to
1666 // support the conversion of node_def, callers need to specify a non-empty
1667 // sub_op_name to be appended to the name of node_def to avoid layer name
1668 // conflicts. If the operation is generated multiple times, callers also need
1669 // to specify sub_op_instance to be appended to the name of the layers to avoid
1670 // layer name conflicts.
1671 void Converter::SetLayerName(nvinfer1::ILayer* layer, const NodeDef& node_def,
1672                              absl::string_view sub_op_name,
1673                              absl::optional<int> sub_op_instance) {
1674   std::string sub_op_suffix = GetLayerNameSuffix(sub_op_name, sub_op_instance);
1675   if (sub_op_suffix.empty()) {
1676     SetLayerNameHelper(layer, engine_name_, node_def.name());
1677   } else {
1678     SetLayerNameHelper(layer, engine_name_,
1679                        absl::StrCat(node_def.name(), "-", sub_op_suffix));
1680   }
1681 }
1682 
1683 // Constructs <tf_related_part> for the ILayer name as
1684 // <main_op_name>_<sub_op_name>_<sub_op_instance> and calls SetLayerNameHelper
1685 // to set the name for the ILayer.
1686 void Converter::SetLayerName(nvinfer1::ILayer* layer,
1687                              absl::string_view main_op_name,
1688                              absl::string_view sub_op_name,
1689                              absl::optional<int> sub_op_instance) {
1690   std::string layer_name_suffix =
1691       GetLayerNameSuffix(sub_op_name, sub_op_instance);
1692   SetLayerNameHelper(layer, engine_name_,
1693                      absl::StrCat(main_op_name, "-", layer_name_suffix));
1694 }
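
// Example of the naming scheme above (editor note, not original code): with
// engine_name_ == "TRTEngineOp_0" and a NodeDef named "conv1",
// SetLayerName(layer, node_def, "shuffle", 2) builds a <tf_related_part> of
// roughly "conv1-shuffle_2" (the exact suffix comes from GetLayerNameSuffix)
// and passes it, together with the engine name, to SetLayerNameHelper.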
1695 
1696 // Converts 'input' of 'node_def' into 'tensor' with shape specified by 'dims'
1697 // (which doesn't contain the batch dimension).
1698 //
1699 // If validation_only is true, it doesn't do the conversion but only does some
1700 // minimal validation of the eligibility for the conversion, and *tensor will
1701 // be set to nullptr.
1702 Status PrepareTensorForShape(Converter* converter,
1703                              const TRT_TensorOrWeights& input,
1704                              const nvinfer1::Dims& dims,
1705                              const bool validation_only,
1706                              nvinfer1::ITensor** tensor,
1707                              const NodeDef& node_def,
1708                              absl::optional<int> op_instance) {
1709   const nvinfer1::Dims input_dims = input.GetTrtDims();
1710   // If either input_dims or dims doesn't have a static shape, some of the dims
1711   // are unknown or need to be inferred; in that case we don't do further checks
1712   // but rely on the caller not to make mistakes.
1713   // Otherwise we do a simple check to make sure the total sizes are the same.
1714   // If an input is a weight, it is going to become a tensor via
1715   // CreateConstantLayer. So we can treat it as a tensor for
1716   // AreDimsStaticWithDifferentSize(). This really only matters for 0-D tensors.
1717   if (AreDimsStaticWithDifferentSize(input_dims, dims, /*is_tensor=*/true)) {
1718     return errors::InvalidArgument(
1719         "Incompatible shapes: ", DebugString(input_dims), " vs. ",
1720         DebugString(dims));
1721   }
1722   // ConstantLayer requires static shapes (cannot infer -1).
1723   if (input.is_weights() && !HasStaticShape(dims)) {
1724     return errors::InvalidArgument("Shape is not fully defined: ",
1725                                    DebugString(dims));
1726   }
1727   if (validation_only) {
1728     *tensor = nullptr;
1729     return Status::OK();
1730   }
1731 
1732   TFTRT_RETURN_ERROR_IF_NULLPTR(converter, "converter is nullptr");
1733   if (input.is_tensor()) {
1734     if (DimsEqual(input_dims, dims)) {
1735       *tensor = input.tensor();
1736     } else {
1737       nvinfer1::IShuffleLayer* layer =
1738           converter->network()->addShuffle(*input.tensor());
1739       TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape");
1740       converter->SetLayerName(layer, node_def, "shuffle", op_instance);
1741       layer->setReshapeDimensions(dims);
1742       converter->MarkQuantizationRangesAsInferrable(input.tensor(),
1743                                                     layer->getOutput(0));
1744       *tensor = layer->getOutput(0);
1745     }
1746   } else {
1747     *tensor = converter->CreateConstantLayer(input.weights(), dims);
1748     TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, "TF-TRT Internal Reshape");
1749     if (converter->precision_mode() == TrtPrecisionMode::INT8 &&
1750         !converter->use_calibration()) {
1751       // If we are in int8 mode and not calibrating, we need to explicitly set a
1752       // quantization range for the output tensor of the IConstantLayer. Here we
1753       // set the range to [min(weights), max(weights)].
1754       float min_range = 0.0f;
1755       float max_range = 0.0f;
1756       TF_RETURN_IF_ERROR(
1757           converter->GetWeightRange(input.weights(), &min_range, &max_range));
1758       // Avoid setting range to 0 because TRT will throw an error. If the
1759       // weights are zero then the range doesn't matter: using 127.0f should
1760       // ensure the quantized weight will be exactly zero.
1761       if (min_range == 0.0f && max_range == 0.0f) {
1762         min_range = -127.0f;
1763         max_range = 127.0f;
1764       }
1765       converter->ProvideQuantizationRange(*tensor, min_range, max_range);
1766     }
1767   }
1768   return Status::OK();
1769 }
1770 
1771 void Converter::MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input,
1772                                                    nvinfer1::ITensor* output) {
1773   quantization_infer_.push_back({input, output});
1774   quantization_infer_.push_back({output, input});
1775 }
1776 
1777 void Converter::ProvideQuantizationRange(nvinfer1::ITensor* tensor,
1778                                          float min_range, float max_range) {
1779   float symmetric_range = std::max(std::abs(min_range), std::abs(max_range));
1780   quantization_ranges_[tensor] = symmetric_range;
1781 }
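
// Worked example (editor note, not original code): the range is symmetrized,
// so a tensor whose observed TF range is [-3.0f, 5.0f] is recorded as 5.0f and
// later applied as setDynamicRange(-5.0f, 5.0f) by
// MaybeApplyQuantizationRanges().
//
//   converter->ProvideQuantizationRange(tensor, -3.0f, 5.0f);
//   // quantization_ranges_[tensor] == 5.0f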
1782 
1783 namespace {
1784 
1785 bool IsConvolution(const nvinfer1::ILayer* layer) {
1786   return layer->getType() == nvinfer1::LayerType::kCONVOLUTION;
1787 }
1788 
1789 bool IsScale(const nvinfer1::ILayer* layer) {
1790   return layer->getType() == nvinfer1::LayerType::kSCALE;
1791 }
1792 
1793 bool IsClipOrRelu(const nvinfer1::ILayer* layer) {
1794   if (layer->getType() != nvinfer1::LayerType::kACTIVATION) {
1795     return false;
1796   }
1797   auto activation_type = static_cast<const nvinfer1::IActivationLayer*>(layer)
1798                              ->getActivationType();
1799 #if IS_TRT_VERSION_GE(5, 1, 2, 0)
1800   return activation_type == nvinfer1::ActivationType::kRELU ||
1801          activation_type == nvinfer1::ActivationType::kCLIP;
1802 #else
1803   return activation_type == nvinfer1::ActivationType::kRELU;
1804 #endif
1805 }
1806 
1807 bool IsAdd(const nvinfer1::ILayer* layer) {
1808   if (layer->getType() != nvinfer1::LayerType::kELEMENTWISE) {
1809     return false;
1810   }
1811   auto operation =
1812       static_cast<const nvinfer1::IElementWiseLayer*>(layer)->getOperation();
1813   return operation == nvinfer1::ElementWiseOperation::kSUM;
1814 }
1815 
1816 }  // namespace
1817 
1818 void Converter::MaybeApplyQuantizationRanges() {
1819   if (precision_mode() != TrtPrecisionMode::INT8) return;
1820 
1821   // Infer ranges across marked ops.
1822   PropagateQuantizationRanges();
1823   // Apply ranges.
1824 #if IS_TRT_VERSION_GE(5, 0, 0, 0)
1825   for (auto pair : quantization_ranges_) {
1826     nvinfer1::ITensor* tensor = pair.first;
1827     const float range = pair.second;
1828     VLOG(1) << "Setting range for: " << tensor->getName() << ": " << range;
1829     // TODO(laigd): if 'tensor' already has a range set which doesn't match
1830     // 'range', it should report an error.
1831     tensor->setDynamicRange(-range, range);
1832   }
1833 #endif
1834 
1835   if (use_calibration()) return;
1836 #if !IS_TRT_VERSION_GE(6, 0, 0, 0)
1837   // Attempt to find tensors that are missing ranges, and set the corresponding
1838   // layer's precision to FP16 to avoid Builder::buildCudaEngine() failing.
1839   // This is only needed for TensorRT 5 and earlier because
1840   // TensorRT 6 falls back to FP16 internally.
1841   // TensorRT doesn't need ranges for intermediate tensors when layers are
1842   // fused, so find the fused layers first.
1843   // Get all tensors from network and deduce fused ops.
1844   std::map<nvinfer1::ILayer*, std::vector<nvinfer1::ILayer*>> layer_consumers;
1845   std::map<nvinfer1::ITensor*, nvinfer1::ILayer*> tensor_layer;
1846   std::set<nvinfer1::ITensor*> all_tensors;
1847   for (int i = 0; i < this->network()->getNbLayers(); i++) {
1848     nvinfer1::ILayer* layer = this->network()->getLayer(i);
1849     layer_consumers[layer] = {};
1850     for (int j = 0; j < layer->getNbInputs(); j++) {
1851       all_tensors.insert(layer->getInput(j));
1852     }
1853     for (int j = 0; j < layer->getNbOutputs(); j++) {
1854       tensor_layer[layer->getOutput(j)] = layer;
1855       all_tensors.insert(layer->getOutput(j));
1856     }
1857   }
1858   for (int i = 0; i < this->network()->getNbLayers(); i++) {
1859     nvinfer1::ILayer* layer = this->network()->getLayer(i);
1860     layer_consumers[layer] = {};
1861     for (int j = 0; j < layer->getNbInputs(); j++) {
1862       nvinfer1::ITensor* input_tensor = layer->getInput(j);
1863       auto input_layer = tensor_layer.find(input_tensor);
1864       if (input_layer != tensor_layer.end()) {
1865         auto consumed_layer = layer_consumers.find(input_layer->second);
1866         if (consumed_layer != layer_consumers.end()) {
1867           consumed_layer->second.push_back(layer);
1868         }
1869       }
1870       all_tensors.insert(input_tensor);
1871     }
1872   }
1873   // Identify fused tensors.
1874   // Conv+BiasAdd+Add+Activation(Clip or Relu), Conv+BiasAdd+Add,
1875   // Conv+BiasAdd+Activation(Clip or Relu), Conv+BiasAdd,
1876   // Conv+Activation(Clip or Relu) are fused.
1877   std::set<nvinfer1::ITensor*> fused_tensors;
1878   typedef std::function<bool(const nvinfer1::ILayer*)> matcher;
1879   const std::vector<std::pair<string, std::vector<matcher>>> fused_patterns = {
1880       {"Fused Conv+Bias+Add+Activation",
1881        {
1882            IsConvolution,
1883            IsScale,
1884            IsAdd,
1885            IsClipOrRelu,
1886        }},
1887       {"Fused Conv+Bias+Add",
1888        {
1889            IsConvolution,
1890            IsScale,
1891            IsAdd,
1892        }},
1893       {"Fused Conv+Bias+Activation",
1894        {
1895            IsConvolution,
1896            IsScale,
1897            IsClipOrRelu,
1898        }},
1899       {"Fused Conv+Bias",
1900        {
1901            IsConvolution,
1902            IsScale,
1903        }},
1904       {"Fused Conv+Activation",
1905        {
1906            IsConvolution,
1907            IsClipOrRelu,
1908        }},
1909   };
1910   for (int i = 0; i < this->network()->getNbLayers(); i++) {
1911     for (const auto& pattern : fused_patterns) {
1912       size_t last_matcher = pattern.second.size() - 1;
1913       nvinfer1::ILayer* layer = this->network()->getLayer(i);
1914       // We should skip this layer if its outputs are already marked as fused,
1915       // but all the current patterns start with a convolution and are ordered
1916       // in decreasing pattern length, so that is not necessary (yet).
1917       std::vector<nvinfer1::ILayer*> fused_candidates;
1918       for (size_t index = 0; index <= last_matcher; ++index) {
1919         if ((!pattern.second[index](layer)) ||
1920             (index < last_matcher && layer_consumers[layer].size() != 1)) {
1921           fused_candidates.clear();
1922           break;
1923         }
1924         if (index < last_matcher) {
1925           fused_candidates.push_back(layer);
1926           layer = layer_consumers[layer].front();
1927         }
1928       }
1929       if (!fused_candidates.empty()) {
1930         VLOG(1) << pattern.first;
1931         for (const auto& fused_layer : fused_candidates) {
1932           for (int i = 0; i < fused_layer->getNbOutputs(); i++) {
1933             VLOG(1) << "  Fused output tensor:"
1934                     << fused_layer->getOutput(i)->getName();
1935             fused_tensors.insert(fused_layer->getOutput(i));
1936           }
1937         }
1938         break;  // Don't try other patterns on this layer.
1939       }
1940     }
1941   }
1942   // Find tensors with no ranges that are not fused and force their layers to
1943   // not be quantized.
1944   for (auto tensor : all_tensors) {
1945     if (!quantization_ranges_.count(tensor) &&
1946         fused_tensors.find(tensor) == fused_tensors.end()) {
1947       // Note: there may be some warnings for "(Unnamed ITensor* N)". These
1948       // are tensors which are created internally by TF-TRT. The ranges for
1949       // these unnamed ITensors are always inferred from user provided ranges,
1950       // thus there will also be a warning for the range(s) the user missed.
1951       LOG_WARNING_WITH_PREFIX << "Quantization range was not found for "
1952                               << tensor->getName() << ". "
1953                               << "Setting invalid quantization range.";
1954       // Set the range to something unusable so the engine will fail if it
1955       // tries to actually use the tensor's range.
1956       tensor->setDynamicRange(0, 0);
1957       auto layer = tensor_layer.find(tensor);
1958       // If the tensor is the output of a layer, set the layer's precision
1959       // to fp16 so that it isn't quantized.
1960       // Shuffle doesn't support setting precision.
1961       if (layer != tensor_layer.end() &&
1962           layer->second->getType() != nvinfer1::LayerType::kSHUFFLE) {
1963         VLOG(1) << "And setting layer " << layer->second->getName()
1964                 << " precision to fp16.";
1965         layer->second->setPrecision(nvinfer1::DataType::kHALF);
1966       }
1967     }
1968   }
1969 #endif
1970 }
1971 
1972 void Converter::PropagateQuantizationRanges() {
1973   // Propagate ranges across edges in quantization_infer_ until no new
1974   // information is added.
1975   // Note: this function modifies quantization_infer_, it might be better to
1976   // modify a copy instead if we for some reason need quantization_infer_
1977   // later.
1978   bool information_added = true;
1979   while (information_added) {
1980     information_added = false;
1981     for (auto it = quantization_infer_.begin();
1982          it != quantization_infer_.end();) {
1983       auto input_tensor_range = quantization_ranges_.find(it->first);
1984       auto output_tensor_range = quantization_ranges_.find(it->second);
1985       if (input_tensor_range != quantization_ranges_.end() &&
1986           output_tensor_range == quantization_ranges_.end()) {
1987         // Input has range but output doesn't: copy range
1988         // TODO(laigd): consider reporting an error if a different range is
1989         // already set.
1990         quantization_ranges_[it->second] = input_tensor_range->second;
1991         information_added = true;
1992         VLOG(1) << "Copy quantization range: " << it->first->getName() << " -> "
1993                 << it->second->getName();
1994       }
1995       // We can remove edges when the output range is known
1996       if (quantization_ranges_.find(it->second) != quantization_ranges_.end()) {
1997         it = quantization_infer_.erase(it);
1998       } else {
1999         ++it;
2000       }
2001     }
2002   }
2003 }
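
// Worked example (editor note, not original code): if a shuffle layer was
// registered via MarkQuantizationRangesAsInferrable(input, output) and only
// the input has a known range,
//
//   quantization_ranges_[input] = 6.0f;  // e.g. provided by a quantizing op
//   PropagateQuantizationRanges();
//   // quantization_ranges_[output] == 6.0f, and both directed edges between
//   // input and output are removed from quantization_infer_.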
2004 
2005 Status Converter::GetInputs(const NodeDef& node_def,
2006                             std::vector<TRT_TensorOrWeights>* inputs) const {
2007   for (auto const& input_name : node_def.input()) {
2008     /*************************************************************************
2009      * TODO(jie): handle case 1) here.
2010      * Normalizes the inputs and extracts associated metadata:
2011      * 1) Inputs can contain a colon followed by a suffix of characters.
2012      *    That suffix may be a single number (e.g. inputName:1) or several
2013      *    word characters separated from a number by a colon
2014      *    (e.g. inputName:foo:1). The
2015      *    latter case is used to denote inputs and outputs of functions.
2016      * 2) Control dependency inputs contain caret at the beginning and we
2017      *    remove this and annotate the edge as a control dependency.
2018      ************************************************************************/
2019     // Skip control inputs.
2020     if (input_name[0] == '^') continue;
2021     string name = input_name;
2022     auto last = name.find_last_of(':');
2023     // TODO(aaroey): use TensorId
2024     if (last != string::npos && last + 2 == name.size() &&
2025         name[last + 1] == '0') {
2026       name.erase(last);
2027     }
2028 
2029     if (trt_tensors_.count(name)) {
2030       TRT_TensorOrWeights input = trt_tensors_.at(name);
2031       inputs->push_back(input);
2032       VLOG(2) << "Retrieved input " << name << ": " << input.DebugString();
2033     } else {
2034       // TODO(aaroey): this should not happen, make it a CHECK.
2035       // TODO(aaroey): use StrCat for pattern like this.
2036       string msg("Node ");
2037       StrAppend(&msg, node_def.name(), " should have an input named '", name,
2038                 "' but it is not available");
2039       LOG(ERROR) << msg;
2040       return errors::InvalidArgument(msg);
2041     }
2042   }
2043   return Status::OK();
2044 }
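
// Examples of the input-name normalization above (editor note, not original
// code):
//
//   "conv/bias:0"  -> looked up as "conv/bias"  (trailing ":0" stripped)
//   "split:2"      -> looked up as "split:2"    (matches ConvertNode naming)
//   "^assert_op"   -> skipped                   (control dependency)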
2045 
2046 enum class TrtInputArg { kTensor = 1, kWeight = 2, kBoth = 3 };
2047 
2048 // Checks that the number of inputs match, and enforces that the inputs marked
2049 // as weights are constant. Inputs are allowed to be both weight and tensor.
2050 Status CheckInputsWeights(
2051     const OpConverterParams& params,
2052     const std::vector<std::pair<string, TrtInputArg>>& expected_inputs) {
2053   const auto& inputs = params.inputs;
2054   const auto& node_def = params.node_def;
2055   if (inputs.size() != expected_inputs.size()) {
2056     return errors::InvalidArgument(
2057         node_def.op(), " got ", inputs.size(), " inputs but expected ",
2058         expected_inputs.size(), ", at ", node_def.name());
2059   }
2060   for (int i = 0; i < inputs.size(); i++) {
2061     if (expected_inputs[i].second == TrtInputArg::kWeight &&
2062         inputs.at(i).is_tensor()) {
2063       return errors::Unimplemented("The input \"", expected_inputs[i].first,
2064                                    "\" for ", node_def.op(),
2065                                    " must be a constant, at ", node_def.name());
2066     }
2067     // TODO(tfeher): Remove this check and provide a method to automatically
2068     // retrieve an input as a tensor, converting via CreateConstantLayer if it
2069     // was originally a weight. We will want a caching mechanism to prevent many
2070     // duplicate constants from being created.
2071     if (expected_inputs[i].second == TrtInputArg::kTensor &&
2072         inputs.at(i).is_weights()) {
2073       return errors::Unimplemented("The input \"", expected_inputs[i].first,
2074                                    "\" for ", node_def.op(),
2075                                    " must be a tensor, at ", node_def.name());
2076     }
2077   }
2078   return Status::OK();
2079 }
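
// Illustrative usage (editor note, not original code): an op taking a runtime
// tensor, a constant axis, and a third input that may be either could be
// validated as below; the input names here are hypothetical.
//
//   TF_RETURN_IF_ERROR(CheckInputsWeights(
//       *params, {{"input", TrtInputArg::kTensor},
//                 {"axis", TrtInputArg::kWeight},
//                 {"values", TrtInputArg::kBoth}}));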
2080 
2081 // Checks that the number of inputs match, and enforces that the inputs marked
2082 // as true are constant weights. true means that the input must be a weight,
2083 // while false means the input must be a tensor.
2084 Status CheckInputsWeights(
2085     const OpConverterParams& params,
2086     const std::vector<std::pair<string, bool>>& inputs_is_weight) {
2087   std::vector<std::pair<string, TrtInputArg>> expected_inputs;
2088   expected_inputs.reserve(inputs_is_weight.size());
2089   std::transform(
2090       inputs_is_weight.begin(), inputs_is_weight.end(),
2091       std::back_inserter(expected_inputs), [](std::pair<string, bool> x) {
2092         return std::make_pair(
2093             x.first, x.second ? TrtInputArg::kWeight : TrtInputArg::kTensor);
2094       });
2095   return CheckInputsWeights(params, expected_inputs);
2096 }
2097 
2098 Status GetNodeDefTfType(const NodeDef& node_def, DataType* tf_type,
2099                         const char* type_attr_name) {
2100   TFAttrs attrs(node_def);
2101   if (!attrs.count(type_attr_name)) {
2102     return errors::InvalidArgument("Attribute with name ", type_attr_name,
2103                                    " not found.");
2104   }
2105   *tf_type = attrs.get<DataType>(type_attr_name);
2106   return Status::OK();
2107 }
2108 
2109 Status GetInputTfType(const OpConverterParams& params, DataType* tf_type,
2110                       int pos) {
2111   const std::vector<TRT_TensorOrWeights>& inputs = params.inputs;
2112   if (inputs.size() <= pos) {
2113     return errors::Internal("Invalid input position");
2114   }
2115 
2116   return inputs[pos].GetTfType(tf_type);
2117 }
2118 
2119 constexpr const char kOutputTypeAttrName[] = "T";
2120 
2121 Status GetOutputTfType(const OpConverterParams& params, DataType* tf_type) {
2122   return GetNodeDefTfType(params.node_def, tf_type, kOutputTypeAttrName);
2123 }
2124 
2125 Status AllowDataTypes(const OpConverterParams& params,
2126                       const std::set<DataType>& allowed_types,
2127                       const char* type_attr_name = kOutputTypeAttrName) {
2128   const auto& node_def = params.node_def;
2129   DataType tf_type;
2130   TF_RETURN_IF_ERROR(GetNodeDefTfType(node_def, &tf_type, type_attr_name));
2131   if (!allowed_types.count(tf_type)) {
2132     string allowed_types_string = absl::StrJoin(
2133         allowed_types, ", ", [](string* out, const DataType& type) {
2134           absl::StrAppendFormat(out, "%s", DataTypeString(type));
2135         });
2136     return errors::Unimplemented("Data type ", DataTypeString(tf_type),
2137                                  " is not supported for ", node_def.op(),
2138                                  ", must be one of [", allowed_types_string,
2139                                  "], at ", node_def.name());
2140   }
2141   return Status::OK();
2142 }
2143 
2144 // ****************************************************************************
2145 // Constant folding functions for weights.
2146 // TODO(laigd): we should probably use eigen directly.
2147 // *****************************************************************************
2148 struct LambdaFactory {
2149   enum class OP_CATEGORY : int { RSQRT = 0, NEG, RECIP };
2150   OP_CATEGORY op;
2151 
2152   template <typename T>
2153   std::function<T(T)> unary() {
2154     switch (op) {
2155       case OP_CATEGORY::RSQRT: {
2156         VLOG(2) << "RSQRT GETS DONE";
2157         return [](T t) -> T { return 1.0 / std::sqrt(t); };
2158       }
2159       case OP_CATEGORY::NEG:
2160         return [](T t) -> T { return -t; };
2161       case OP_CATEGORY::RECIP:
2162         return [](T t) -> T { return 1.0 / t; };
2163       default:
2164         LOG(ERROR) << "Not supported op for unary: " << static_cast<int>(op);
2165         return nullptr;
2166     }
2167   }
2168 };
2169 
2170 template <>
2171 std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
2172   switch (op) {
2173     case OP_CATEGORY::RSQRT: {
2174       VLOG(2) << "RSQRT GETS DONE";
2175       return [](Eigen::half t) {
2176         return Eigen::half(1.0 / std::sqrt(static_cast<float>(t)));
2177       };
2178     }
2179     case OP_CATEGORY::NEG:
2180       return [](Eigen::half t) { return -t; };
2181     case OP_CATEGORY::RECIP:
2182       return [](Eigen::half t) {
2183         return Eigen::half(1.0 / static_cast<float>(t));
2184       };
2185     default:
2186       LOG(ERROR) << "Not supported op for unary: " << static_cast<int>(op);
2187       return nullptr;
2188   }
2189 }
2190 
2191 Status UnaryCompute(const TRT_ShapedWeights& iweights,
2192                     TRT_ShapedWeights* oweights, LambdaFactory unary_op) {
2193   CHECK(iweights.TrtDType() == oweights->TrtDType());
2194   switch (iweights.TrtDType()) {
2195     case nvinfer1::DataType::kFLOAT: {
2196       auto inp = static_cast<float const*>(iweights.GetValues());
2197       auto oup = static_cast<float*>(oweights->GetValues());
2198       std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
2199       break;
2200     }
2201     case nvinfer1::DataType::kHALF: {
2202       auto inp = static_cast<Eigen::half const*>(iweights.GetValues());
2203       auto oup = static_cast<Eigen::half*>(oweights->GetValues());
2204       std::transform(inp, inp + iweights.count(), oup,
2205                      unary_op.unary<Eigen::half>());
2206       break;
2207     }
2208     default:
2209       return errors::Unimplemented("Data type not supported: ",
2210                                    DebugString(iweights.TrtDType()));
2211   }
2212   return Status::OK();
2213 }
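
// Illustrative sketch of the constant-folding helpers above (editor note, not
// original code): folding Rsqrt over constant weights allocates an output of
// the same dtype/shape and applies the lambda elementwise. The weights and
// params variables are hypothetical.
//
//   LambdaFactory unary_op;
//   unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
//   TRT_ShapedWeights folded = params->weight_store->GetTempWeights(weights);
//   TF_RETURN_IF_ERROR(UnaryCompute(weights, &folded, unary_op));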
2214 
2215 // Before TRT 5.1.3, we have to calculate padding for convolutions ourselves.
2216 Status Conv2DPaddingHelper(OpConverterParams* params, const TFAttrs& attrs,
2217                            const nvinfer1::DimsHW& kernel_size,
2218                            const nvinfer1::DimsHW& dilation,
2219                            const nvinfer1::DimsHW& stride,
2220                            const std::vector<int64_t>& input_dims,
2221                            nvinfer1::ITensor* tensor,
2222                            std::vector<std::pair<int, int>>* padding,
2223                            nvinfer1::ITensor** padded_tensor) {
2224   if (attrs.get<string>("padding") == "SAME") {
2225     nvinfer1::DimsHW effective_kernel_size = kernel_size;
2226     effective_kernel_size.h() += (kernel_size.h() - 1) * (dilation.h() - 1);
2227     effective_kernel_size.w() += (kernel_size.w() - 1) * (dilation.w() - 1);
2228     *padding = CreateSamePadding(stride, effective_kernel_size, input_dims);
2229   } else {
2230     *padding = {{0, 0}, {0, 0}};
2231   }
2232 
2233   // Handle asymmetric padding. TensorRT 5.1 added support for asymmetric
2234   // padding via setPrePadding and setPostPadding. Due to a bug in 5.1.2, we can
2235   // only use asymmetric padding in convolutions with 5.1.3+. But in 5.1.3, we
2236   // will always use setPaddingMode for simplicity.
2237   if ((*padding)[0].first != (*padding)[0].second ||
2238       (*padding)[1].first != (*padding)[1].second) {
2239     auto pad_layer = params->converter->network()->addPadding(
2240         *tensor, nvinfer1::DimsHW((*padding)[0].first, (*padding)[1].first),
2241         nvinfer1::DimsHW((*padding)[0].second, (*padding)[1].second));
2242     TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, params->node_def.name());
2243     params->converter->SetLayerName(pad_layer, params->node_def, "pad");
2244     params->converter->MarkQuantizationRangesAsInferrable(
2245         tensor, pad_layer->getOutput(0));
2246     *padding = {{0, 0}, {0, 0}};
2247     tensor = pad_layer->getOutput(0);
2248   }
2249   *padded_tensor = tensor;
2250   return Status::OK();
2251 }
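
// Worked example for the SAME-padding math above (editor note, not original
// code): with a 3x3 kernel, dilation 2 and stride 1, the effective kernel size
// becomes 3 + (3 - 1) * (2 - 1) = 5 in each spatial dimension, so
// CreateSamePadding() distributes a total of 4 padding pixels per dimension
// (2 before and 2 after when the padding splits evenly).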
2252 
2253 namespace {
2254 // Extracts the spatial dimensions from `output_sizes` and returns them as a
2255 // vector of size 2.
2256 std::vector<int64_t> GetSpatialDimsFromOutputSizes(
2257     const TRT_TensorOrWeights& output_sizes, const int h_index,
2258     const int w_index) {
2259   // We use h_index and w_index instead of 1 and 2 because we haven't
2260   // transposed output_sizes along with the input.
2261   const TRT_ShapedWeights& weights = output_sizes.weights();
2262   const int output_sizes_length = weights.count();
2263   auto output_sizes_values = static_cast<int*>(weights.GetValues());
2264   // The length of output_sizes can be 2 or 4: when it is 2 it holds just
2265   // <height, width>; when it is 4 it is the full shape, indexed by h/w_index.
2266   return {output_sizes_values[output_sizes_length == 4 ? h_index : 0],
2267           output_sizes_values[output_sizes_length == 4 ? w_index : 1]};
2268 }
2269 }  // namespace
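
// Worked example (editor note, not original code): for an NHWC graph,
// h_index == 1 and w_index == 2, so
//
//   output_sizes == {1, 28, 28, 64}  ->  returns {28, 28}
//   output_sizes == {28, 28}         ->  returns {28, 28}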
2270 
2271 Status ConvertConv2DHelper(OpConverterParams* params, int group,
2272                            bool is_conv2d_backprop_input) {
2273   const auto& inputs = params->inputs;
2274   const auto& node_def = params->node_def;
2275   TRT_TensorOrWeights backprop_output_size;
2276   nvinfer1::ITensor* tensor = nullptr;
2277   if (is_conv2d_backprop_input) {
2278     // In the case when Conv2dBackpropInput is used for conv2d_transpose, these
2279     // inputs correspond to: output size, filter, and input.
2280     TF_RETURN_IF_ERROR(CheckInputsWeights(
2281         *params,
2282         {{"input_sizes", true}, {"filter", true}, {"out_backprop", false}}));
2283     backprop_output_size = inputs.at(0);
2284     tensor = inputs.at(2).tensor();
2285   } else {
2286     TF_RETURN_IF_ERROR(
2287         CheckInputsWeights(*params, {{"input", false}, {"filter", true}}));
2288     tensor = inputs.at(0).tensor();
2289   }
2290   TF_RETURN_IF_ERROR(
2291       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
2292   TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
2293   if (weights_rsck.shape_.nbDims != 4) {
2294     return errors::InvalidArgument("Conv2D expects kernel of dimension 4, at " +
2295                                    node_def.name());
2296   }
2297   TFAttrs attrs(node_def);
2298   auto data_format = attrs.get<string>("data_format");
2299   int c_index = (data_format == "NHWC") ? 3 : 1;
2300   int h_index = (data_format == "NHWC") ? 1 : 2;
2301   int w_index = (data_format == "NHWC") ? 2 : 3;
2302   auto tf_dilations = attrs.get<std::vector<int64>>("dilations");
2303   if (tf_dilations.size() != 4) {
2304     return errors::InvalidArgument(
2305         "Convolution dilations field must specify 4 dimensions, at ",
2306         node_def.name());
2307   }
2308   if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) {
2309     return errors::Unimplemented(
2310         "Dilation rate must be 1 for batch and channel dimensions, at ",
2311         node_def.name());
2312   }
2313   const nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]);
2314   if (is_conv2d_backprop_input && (dilation.d[0] != 1 || dilation.d[1] != 1)) {
2315     return errors::Unimplemented(
2316         "Dilation with Conv2DBackpropInput (conv2d_transpose) is not supported",
2317         ", at ", node_def.name());
2318   }
2319 
2320   const auto tf_stride = attrs.get<std::vector<int64>>("strides");
2321   if (tf_stride.size() != 4) {
2322     return errors::InvalidArgument(
2323         "Convolution strides field must specify 4 dimensions, at ",
2324         node_def.name());
2325   }
2326   if (tf_stride[0] != 1 || tf_stride[c_index] != 1) {
2327     return errors::Unimplemented(
2328         "Stride must be 1 for batch and channel dimensions, at ",
2329         node_def.name());
2330   }
2331   // Channel dim must be static for DepthwiseConv2dNative since we use that
2332   // value for num_groups at build time.
2333   if (!params->use_implicit_batch && tensor->getDimensions().d[c_index] == -1) {
2334     return errors::InvalidArgument("Channel dimension must be static, at ",
2335                                    node_def.name());
2336   }
2337   const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
2338   if (params->validation_only) return Status::OK();
2339 
2340   // Transpose to NCHW (NCHW is required for IConvLayer).
2341   const bool need_transpose = (data_format == "NHWC");
2342   if (need_transpose) {
2343     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
2344         tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW"));
2345   }
2346   // Dimensions of transposed tensor.
2347   const auto tensor_dim = tensor->getDimensions();
2348   const int c_dim_size = tensor_dim.d[params->use_implicit_batch ? 0 : 1];
2349 
2350   // group == 0 signifies that this is a depthwise convolution, so set
2351   // num_groups to size of input's channel dim. For a non-depthwise conv,
2352   // num_groups will be 1.
2353   const int num_groups = (group == 0) ? c_dim_size : group;
2354 
2355   // For conv, TF weights are RSCK, and TRT expects KCRS.
2356   // For backprop, TF weights are RSKC, and TRT expects CKRS.
2357   // Therefore, this reorder will work for both cases.
2358   TRT_ShapedWeights weights =
2359       params->weight_store->GetTempWeights(weights_rsck);
2360   ReorderRSCKToKCRS(weights_rsck, &weights, num_groups);
2361   TRT_ShapedWeights biases(weights.TrtDType());
2362   const int output_axis = is_conv2d_backprop_input ? 1 : 0;
2363   const int noutput = weights.shape_.d[output_axis] * num_groups;
2364   nvinfer1::DimsHW kernel_size;
2365   kernel_size.h() = weights.shape_.d[2];
2366   kernel_size.w() = weights.shape_.d[3];
2367 
2368 // Before TRT 5.1.3, we have to calculate padding ourselves.
2369 #if !IS_TRT_VERSION_GE(5, 1, 3, 0)
2370   std::vector<std::pair<int, int>> padding;
2371   std::vector<int64_t> input_dims;
2372   if (is_conv2d_backprop_input) {
2373     // For backprop, calculate padding based on "input_sizes" input, which
2374     // actually corresponds to output size. ("input_sizes" makes sense in the
2375     // context of Conv2DBackpropInput).
2376     input_dims =
2377         GetSpatialDimsFromOutputSizes(backprop_output_size, h_index, w_index);
2378   } else {
2379     // Use 1 and 2 because tensor_dim has the dimensions of the transposed
2380     // input.
2381     input_dims = {static_cast<int>(tensor_dim.d[1]),
2382                   static_cast<int>(tensor_dim.d[2])};
2383   }
2384   nvinfer1::ITensor* padded_tensor = nullptr;
2385   TF_RETURN_IF_ERROR(Conv2DPaddingHelper(params, attrs, kernel_size, dilation,
2386                                          stride, input_dims, tensor, &padding,
2387                                          &padded_tensor));
2388   tensor = padded_tensor;
2389 #endif
2390 
2391   // Add convolution.
2392   nvinfer1::ILayer* conv_layer = nullptr;
2393   if (is_conv2d_backprop_input) {
2394     nvinfer1::IDeconvolutionLayer* layer =
2395         params->converter->network()->addDeconvolution(
2396             *tensor, noutput, kernel_size, weights.GetTrtWeights(),
2397             biases.GetTrtWeights());
2398     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
2399     layer->setStride(stride);
2400 // TensorRT 5.1.3 added support for padding modes.
2401 #if IS_TRT_VERSION_GE(5, 1, 3, 0)
2402     // VALID padding is the default TRT behavior.
2403     if (attrs.get<string>("padding") == "SAME") {
2404       // SAME_UPPER means that post padding is preferred.
2405       layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
2406     }
2407 #else
2408     layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
2409 #endif
2410     layer->setNbGroups(num_groups);
2411     conv_layer = layer;
2412   } else {
2413     nvinfer1::IConvolutionLayer* layer =
2414         params->converter->network()->addConvolution(
2415             *tensor, noutput, kernel_size, weights.GetTrtWeights(),
2416             biases.GetTrtWeights());
2417     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
2418     layer->setStride(stride);
2419 #if IS_TRT_VERSION_GE(5, 1, 3, 0)
2420     if (attrs.get<string>("padding") == "SAME") {
2421       layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
2422     }
2423 #else
2424     layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
2425 #endif
2426     layer->setNbGroups(num_groups);
2427     layer->setDilation(dilation);
2428     conv_layer = layer;
2429   }
2430   params->converter->SetLayerName(conv_layer, node_def, "conv");
2431   nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0);
2432   // Add extra padding for Deconv because TRT doesn't accept the
2433   // output_shape argument, and thus the TRT output shape could be wrong
2434   // when strides > 1.
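  // Illustrative example (hypothetical sizes): if input_sizes requests a 9x9
  // spatial output but TRT computes an 8x8 output for the chosen stride, the
  // one-pixel difference per dimension is added back as post-padding below.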
2435   if (is_conv2d_backprop_input) {
2436     std::vector<int64_t> output_spatial_dims =
2437         GetSpatialDimsFromOutputSizes(backprop_output_size, h_index, w_index);
2438     const int output_height = output_spatial_dims[0];
2439     const int output_width = output_spatial_dims[1];
2440     nvinfer1::Dims trt_output_shape = output_tensor->getDimensions();
2441     // The padding size is determined by the difference between the given
2442     // input_sizes (tf_output_shape) and the TRT-computed output size.
2443     const int height_diff = output_height - trt_output_shape.d[1];
2444     const int width_diff = output_width - trt_output_shape.d[2];
2445     if ((height_diff < 0) || (width_diff < 0)) {
2446       return errors::InvalidArgument(
2447           "input_sizes argument of Conv2DBackprop (i.e. output_shape argument "
2448           "of conv2d_transpose) ",
2449           "is too small for the given out_backprop argument of Conv2DBackprop "
2450           "(i.e. input argument of conv2d_transpose). Expect: ",
2451           "(", output_height, ", ", output_width, ") >= ", "(",
2452           trt_output_shape.d[1], ", ", trt_output_shape.d[2], ") for op ",
2453           node_def.name());
2454     }
2455     // Only add a padding layer if padding sizes are larger than 0
2456     if ((height_diff > 0) || (width_diff > 0)) {
2457       nvinfer1::DimsHW pre_padding(0, 0);
2458       nvinfer1::DimsHW post_padding(height_diff, width_diff);
2459       nvinfer1::IPaddingLayer* padding_layer =
2460           params->converter->network()->addPadding(*output_tensor, pre_padding,
2461                                                    post_padding);
2462       output_tensor = padding_layer->getOutput(0);
2463       params->converter->SetLayerName(padding_layer, node_def, "pad");
2464     }
2465   }
2466   // Restore transpose.
2467   if (need_transpose) {
2468     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
2469         output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC"));
2470   }
2471   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2472   return Status::OK();
2473 }
2474 
2475 bool AllowInefficientTranspose() {
2476   static bool result = [] {
2477     bool value;
2478     Status status =
2479         ReadBoolFromEnvVar("TF_DEBUG_TRT_ALLOW_INEFFICIENT_TRANSPOSE",
2480                            /*default_value=*/false, &value);
2481     if (!status.ok()) {
2482       LOG(ERROR) << status;
2483     }
2484     return value;
2485   }();
2486 
2487   return result;
2488 }
2489 
2490 Status ConvertTranspose(OpConverterParams* params) {
2491   const auto& inputs = params->inputs;
2492   TF_RETURN_IF_ERROR(
2493       CheckInputsWeights(*params, {{"x", false}, {"perm", true}}));
2494   TF_RETURN_IF_ERROR(AllowDataTypes(
2495       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2496   // Get the permutation from weights.
2497   TRT_ShapedWeights weights = inputs.at(1).weights();
2498   const int* weights_ptr = static_cast<int*>(weights.GetValues());
2499   std::vector<int> perm(weights_ptr, weights_ptr + weights.count());
2500 
2501   // Verify the permutation.
2502   nvinfer1::ITensor* input_tensor = inputs.at(0).tensor();
2503   const int perm_size =
2504       params->use_implicit_batch ? perm.size() - 1 : perm.size();
2505   if (perm_size != size_t(input_tensor->getDimensions().nbDims)) {
2506     return errors::InvalidArgument(
2507         "Rank of perm for transpose does not match with that of the input.");
2508   }
2509   if (params->use_implicit_batch && perm[0] != 0) {
2510     return errors::Unimplemented(
2511         "Transpose at batch dimension is not supported.");
2512   }
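  // Illustrative example (hypothetical perm): in implicit batch mode a perm of
  // [0, 3, 1, 2] (NHWC -> NCHW) keeps the batch dimension in place and is
  // accepted, whereas [1, 0, 2, 3] would move the batch dimension and is
  // rejected above.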
2513 
2514   // TensorRT (as of version 7.0.0.11) is slow at transposing large tensors,
2515   // so check the tensor size and don't convert if it is too large.
2516   constexpr int64_t kMaxEfficientTranspose = 2500000;
2517   int64_t tensor_size = TrtTensorDimsNumElements(input_tensor->getDimensions());
2518   if (!AllowInefficientTranspose() && tensor_size > kMaxEfficientTranspose) {
2519     return errors::Unimplemented(StrCat("Transpose too large:", tensor_size));
2520   }
2521 
2522   if (params->validation_only) return Status::OK();
2523 
2524   // Start conversion.
2525   nvinfer1::ITensor* output_tensor = nullptr;
2526   TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
2527       input_tensor, perm, &output_tensor, params->node_def));
2528   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2529   return Status::OK();
2530 }
2531 
2532 Status ConvertShape(OpConverterParams* params) {
2533   const auto& inputs = params->inputs;
2534   TF_RETURN_IF_ERROR(
2535       CheckInputsWeights(*params, {{"input", TrtInputArg::kBoth}}));
2536   if (params->use_implicit_batch) {
2537     return errors::Unimplemented(
2538         "Shape is only supported for explicit batch mode.");
2539   }
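  // Illustrative example (hypothetical dims): a static input of shape
  // [8, 3, 224, 224] becomes a constant int32 tensor of shape [4] holding
  // {8, 3, 224, 224}; dynamic inputs fall through to the IShapeLayer path
  // below.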
2540   if (HasStaticShape(inputs.at(0).GetTrtDims())) {
2541     if (params->validation_only) return Status::OK();
2542     nvinfer1::Dims input_dims = inputs.at(0).GetTrtDims();
2543     nvinfer1::Dims output_dims{1, {input_dims.nbDims}};
2544     // Create a const node with the values of output_dims
2545     TRT_ShapedWeights weight = params->weight_store->GetTempWeights(
2546         nvinfer1::DataType::kINT32, output_dims);
2547     int32* values_ptr = static_cast<int32*>(weight.GetValues());
2548     std::copy(input_dims.d, input_dims.d + input_dims.nbDims, values_ptr);
2549     auto output = params->converter->CreateConstantLayer(weight, output_dims);
2550     params->outputs->push_back(TRT_TensorOrWeights(output));
2551     return Status::OK();
2552   }
2553 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
2554   if (params->validation_only) return Status::OK();
2555   nvinfer1::IShapeLayer* shape_layer =
2556       params->converter->network()->addShape(*inputs.at(0).tensor());
2557   TFTRT_RETURN_ERROR_IF_NULLPTR(shape_layer, params->node_def.name());
2558   params->converter->SetLayerName(shape_layer, params->node_def, "shape");
2559   params->outputs->push_back(TRT_TensorOrWeights(shape_layer->getOutput(0)));
2560   return Status::OK();
2561 #else
2562   return errors::Unavailable(
2563       "Shape op conversion requires TensorRT 6 or above");
2564 #endif
2565 }
2566 
2567 Status ConvertReshape(OpConverterParams* params) {
2568   const auto& inputs = params->inputs;
2569   TF_RETURN_IF_ERROR(
2570       CheckInputsWeights(*params, {{"tensor", false}, {"shape", true}}));
2571   TF_RETURN_IF_ERROR(AllowDataTypes(
2572       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2573   const TRT_TensorOrWeights& input_tensor = inputs.at(0);
2574 
2575   // TODO(bixia): we can't use inputs.at(1).weights().ToVector<int>() for two
2576   // reasons: (1) When weights.count()==0, TRT_ShapedWeights::tensor_ dtype is
2577   // not properly set to INT32. (2) I tried a fix for the first problem, I got
2578   // shared pointer related error in convert_nodes_test. We should fix the
2579   // problems and switch to use inputs.at(1).weights().ToVector<int>(), a type
2580   // safe method to access the content of the tensor.
2581   TRT_ShapedWeights weights = inputs.at(1).weights();
2582   if (weights.count() == 0) {
2583     return errors::Unimplemented("Reshape to shape=[] is not supported, at ",
2584                                  params->node_def.name());
2585   }
2586 
2587   const int* output_shape_dims = static_cast<int*>(weights.GetValues());
2588   size_t output_shape_dims_count = weights.count();
2589 
2590   const int input_batch_dim = input_tensor.batch_size();
2591   const int output_batch_dim = output_shape_dims[0];
2592 
2593   const nvinfer1::Dims input_nonbatch_dims = input_tensor.GetTrtDims();
2594   nvinfer1::Dims output_nonbatch_dims;
2595   output_nonbatch_dims.nbDims = output_shape_dims_count - 1;
2596   for (int i = 1; i < output_shape_dims_count; i++) {
2597     output_nonbatch_dims.d[i - 1] = output_shape_dims[i];
2598   }
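  // Illustrative example (hypothetical shapes): reshaping an [8, 2, 10] tensor
  // with shape=[8, 4, 5] keeps batch dim 8 and reshapes the non-batch dims
  // [2, 10] -> [4, 5], which is accepted; shape=[16, 10] would change the
  // batch dim and is rejected below.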
2599 
2600   VLOG(1) << "input_batch_dim=" << input_batch_dim
2601           << ", input_nonbatch_dims=" << DebugString(input_nonbatch_dims)
2602           << "\nresult_batch_dim=" << output_batch_dim
2603           << ", result_nonbatch_dims=" << DebugString(output_nonbatch_dims);
2604 
2605   // Check whether input_batch_dim and output_batch_dim will have the same
2606   // static value.
2607   bool reshape_may_change_batch_dim = false;
2608   if (input_batch_dim != -1 && output_batch_dim != -1) {
2609     reshape_may_change_batch_dim = (input_batch_dim != output_batch_dim);
2610   } else {
2611     reshape_may_change_batch_dim =
2612         !AreDimsStaticWithSameSize(input_nonbatch_dims, output_nonbatch_dims,
2613                                    /*is_tensor=*/true);
2614   }
2615   if (reshape_may_change_batch_dim) {
2616     const string msg =
2617         StrCat("Reshape on batch dimension is not supported, at ",
2618                params->node_def.name(), ". input_batch_dim=", input_batch_dim,
2619                ", ", DebugString(input_nonbatch_dims),
2620                "; output_batch_dim=", output_batch_dim, ", ",
2621                DebugString(output_nonbatch_dims));
2622     return errors::Unimplemented(msg);
2623   }
2624 
2625   // Perform the conversion.
2626   nvinfer1::ITensor* output_tensor = nullptr;
2627   TF_RETURN_IF_ERROR(PrepareTensorForShape(
2628       params->converter, input_tensor, output_nonbatch_dims,
2629       params->validation_only, &output_tensor, params->node_def));
2630   if (params->validation_only) return Status::OK();
2631 
2632   // Record the conversion result.
2633   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2634   return Status::OK();
2635 }
2636 
2637 Status ConvertExpandDims(OpConverterParams* params) {
2638   const auto& inputs = params->inputs;
2639   const auto& node_def = params->node_def;
2640   TF_RETURN_IF_ERROR(
2641       CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
2642   TF_RETURN_IF_ERROR(AllowDataTypes(
2643       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2644   // Get input shape as vector.
2645   const TRT_TensorOrWeights& input_tensor = inputs.at(0);
2646   const nvinfer1::Dims dims = input_tensor.GetTrtDims();
2647   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
2648   // Get axis to expand on.
2649   auto axis = inputs.at(1).weights().GetSpan<int>();
2650   if (axis.size() != 1) {
2651     return errors::InvalidArgument("ExpandDims axis must be a scalar, at ",
2652                                    node_def.name());
2653   }
2654   // Use rank = nbDims + 1 for ConvertAxis's bounds checking to account for
2655   // ExpandDims's ability to add an axis at the end of the shape.
2656   int trt_axis;
2657   TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(),
2658                                  params->use_implicit_batch, &trt_axis));
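  // Illustrative example (hypothetical shapes): in implicit batch mode,
  // expanding a tensor of TF shape [N, 2, 3] at axis=1 maps to trt_axis=0 on
  // the non-batch dims [2, 3], giving [N, 1, 2, 3]; axis=0 (the batch dim) is
  // rejected by ConvertAxis in that mode.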
2659   if (params->validation_only) return Status::OK();
2660   nvinfer1::ITensor* output_tensor = nullptr;
2661 
2662   if (!params->use_implicit_batch && !HasStaticShape(input_dims)) {
2663     TF_RETURN_IF_ERROR(params->converter->DynamicExpandDims(
2664         input_tensor.tensor(), dims, trt_axis, params, &output_tensor));
2665   } else {
2666     // ExpandDims: Insert new dim of size 1.
2667     input_dims.insert(input_dims.begin() + trt_axis, 1);
2668     // Reshape tensor.
2669     nvinfer1::Dims new_dims;
2670     TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims));
2671     TF_RETURN_IF_ERROR(PrepareTensorForShape(
2672         params->converter, input_tensor, new_dims, /*validation_only=*/false,
2673         &output_tensor, params->node_def));
2674   }
2675   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2676   return Status::OK();
2677 }
2678 
2679 Status Converter::DynamicReshape(nvinfer1::ITensor* input,
2680                                  std::vector<std::pair<int, int>> slices,
2681                                  OpConverterParams* params,
2682                                  nvinfer1::ITensor** output,
2683                                  std::vector<int> size_for_added_dims,
2684                                  absl::optional<int> op_instance) {
2685   *output = nullptr;
2686   // DynamicReshape relies on INetworkDefinition::addShape that was introduced
2687   // in TensorRT 6.
2688 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
2689   if (params->validation_only) {
2690     return errors::Internal(
2691         "DynamicReshape should not be used during validation");
2692   }
2693   nvinfer1::ITensor* shape = network()->addShape(*input)->getOutput(0);
2694   // Build new shape = shape[:trt_axis] + [1] + shape[trt_axis:]
2695   std::vector<nvinfer1::ITensor const*> concat_inputs;
2696   int max_num_slices = std::max(slices.size(), size_for_added_dims.size());
2697   int op_instance_value = op_instance.has_value() ? op_instance.value() : 0;
2698   for (int i = 0; i < max_num_slices; i++) {
2699     nvinfer1::ITensor* tensor;
2700     int slice_instance = i * max_num_slices + op_instance_value;
2701     // maybe_add_a_dimension(i);
2702     if (i < size_for_added_dims.size() && size_for_added_dims[i] >= 0) {
2703       TF_RETURN_IF_ERROR(
2704           CreateScalarConstant(params, size_for_added_dims[i], &tensor));
2705       concat_inputs.push_back(tensor);
2706     }
2707     if (i < slices.size()) {
2708       nvinfer1::ISliceLayer* slice_layer = network()->addSlice(
2709           *shape, {1, {slices[i].first}},
2710           {1, {slices[i].second - slices[i].first}}, {1, {1}});
2711       concat_inputs.push_back(slice_layer->getOutput(0));
2712       SetLayerName(slice_layer, params->node_def, "slice", slice_instance);
2713     }
2714   }
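  // At this point concat_inputs holds, in order, the requested constant dims
  // and shape slices. Illustrative example (hypothetical): for
  // DynamicExpandDims at axis 1 on a rank-3 tensor, concat_inputs is
  // {shape[0:1], 1, shape[1:3]}, which the concatenation below assembles into
  // the new shape tensor.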
2715   nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation(
2716       const_cast<nvinfer1::ITensor* const*>(concat_inputs.data()),
2717       concat_inputs.size());
2718   SetLayerName(concat_layer, params->node_def, "concat", op_instance);
2719   concat_layer->setAxis(0);
2720   nvinfer1::ITensor* new_shape = concat_layer->getOutput(0);
2721   // Reshape input using new shape
2722   nvinfer1::IShuffleLayer* shuffle = network()->addShuffle(*input);
2723   SetLayerName(shuffle, params->node_def, "shuffle", op_instance);
2724   shuffle->setInput(1, *new_shape);
2725   *output = shuffle->getOutput(0);
2726   return Status::OK();
2727 #else
2728   return errors::Unavailable(
2729       "Dynamic shape input requires TensorRT 6 or above");
2730 #endif
2731 }
2732 
2733 Status Converter::DynamicExpandDims(nvinfer1::ITensor* input,
2734                                     const nvinfer1::Dims& dims, int axis,
2735                                     OpConverterParams* params,
2736                                     nvinfer1::ITensor** output,
2737                                     absl::optional<int> op_instance) {
2738   if (params->validation_only) {
2739     *output = nullptr;
2740     return errors::Internal(
2741         "DynamicExpandDims should not be used during validation");
2742   }
2743   std::vector<std::pair<int, int>> slices;
2744   std::vector<int> extra_dims;
2745   if (axis != 0) {
2746     slices.push_back(std::pair<int, int>{0, axis});
2747     extra_dims.push_back(-1);
2748   }
2749   extra_dims.push_back(1);
2750   if (axis != dims.nbDims) {
2751     slices.push_back(std::pair<int, int>{axis, dims.nbDims});
2752   }
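  // Illustrative examples (hypothetical ranks): axis=0 on a rank-3 tensor
  // yields slices={(0, 3)} and extra_dims={1}, i.e. new shape = [1] +
  // shape[0:3]; axis=3 yields slices={(0, 3)} and extra_dims={-1, 1}, i.e.
  // new shape = shape[0:3] + [1].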
2753   return DynamicReshape(input, slices, params, output, extra_dims, op_instance);
2754 }
2755 
2756 Status Converter::SqueezeTensor(nvinfer1::ITensor* input,
2757                                 std::vector<int>* input_dims,
2758                                 OpConverterParams* params,
2759                                 nvinfer1::ITensor** output) {
2760   // If the remaining dimensions of a squeeze operation have dynamic sizes, we
2761   // need to use TRT ops to build the result shape for the squeeze operation.
2762   // This is because IShuffleLayer::setReshapeDimensions treats -1 as a special
2763   // value.
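  // Illustrative example (hypothetical dims): with dims marked for removal as
  // 0 in [2, 0, 3, 0], the dynamic path keeps slices {(0, 1), (2, 3)} of the
  // shape tensor, while the static path below simply erases the zeros to get
  // [2, 3].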
2764   if (!params->use_implicit_batch && !HasStaticShape(*input_dims)) {
2765     std::vector<std::pair<int, int>> slices;
2766     for (int i = 0; i < input_dims->size(); i++) {
2767       if (input_dims->at(i) != 0) {
2768         slices.push_back(std::pair<int, int>(i, i + 1));
2769       }
2770     }
2771     return DynamicReshape(input, slices, params, output);
2772   }
2773   // Remove all dims which are equal to 0.
2774   input_dims->erase(std::remove(input_dims->begin(), input_dims->end(), 0),
2775                     input_dims->end());
2776   // Reshape tensor.
2777   nvinfer1::Dims new_dims;
2778   TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(*input_dims, &new_dims));
2779   VLOG(2) << "Squeezed dims: " << DebugString(new_dims);
2780   TF_RETURN_IF_ERROR(PrepareTensorForShape(
2781       params->converter, TRT_TensorOrWeights(input), new_dims,
2782       /*validation_only=*/false, output, params->node_def));
2783   return Status::OK();
2784 }
2785 
2786 Status ConvertSqueeze(OpConverterParams* params) {
2787   const auto& inputs = params->inputs;
2788   const auto& node_def = params->node_def;
2789   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
2790   TF_RETURN_IF_ERROR(AllowDataTypes(
2791       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2792   // Get input shape.
2793   const TRT_TensorOrWeights& input_tensor = inputs.at(0);
2794   const nvinfer1::Dims dims = input_tensor.GetTrtDims();
2795   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
2796   TFAttrs attrs(node_def);
2797   auto squeeze_dims = attrs.get<std::vector<int64>>("squeeze_dims");
2798   if (squeeze_dims.empty()) {
2799     if (params->use_implicit_batch || !HasStaticShape(dims)) {
2800       return errors::Unimplemented(
2801           "Squeeze is not implemented for empty squeeze_dims, at ",
2802           node_def.name());
2803     } else {
2804       // In explicit batch mode with a static input shape, we squeeze all
2805       // singleton dimensions.
2806       for (int& dim : input_dims) {
2807         if (dim == 1) {
2808           // Mark it for removal by setting it to 0
2809           dim = 0;
2810         }
2811       }
2812     }
2813   } else {
2814     std::vector<int> trt_axes;
2815     trt_axes.reserve(squeeze_dims.size());
2816     for (int tf_axis : squeeze_dims) {
2817       // If the axis is valid, then convert it to TRT axis, otherwise abort
2818       // conversion.
2819       int trt_axis;
2820       TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
2821                                      params->use_implicit_batch, &trt_axis));
2822       // Make sure target dimension is size 1 or unknown size (-1)
2823       if (input_dims[trt_axis] != -1 && input_dims[trt_axis] != 1) {
2824         return errors::InvalidArgument(
2825             "Dimension ", tf_axis, " with size ", input_dims[trt_axis],
2826             " cannot be squeezed because it must be size 1, at ",
2827             node_def.name());
2828       }
2829       trt_axes.push_back(trt_axis);
2830     }
2831     // Mark axes to remove by setting them to 0.
2832     for (int axis : trt_axes) {
2833       input_dims[axis] = 0;
2834     }
2835   }
2836   if (params->validation_only) return Status::OK();
2837 
2838   nvinfer1::ITensor* output_tensor = nullptr;
2839   TF_RETURN_IF_ERROR(params->converter->SqueezeTensor(
2840       input_tensor.tensor(), &input_dims, params, &output_tensor));
2841   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2842   return Status::OK();
2843 }
2844 
2845 template <typename Container>
2846 Status ConvertStridedSliceHelper(
2847     OpConverterParams* params, const TRT_TensorOrWeights& input,
2848     Container begin, Container size, const Container& stride,
2849     const nvinfer1::Dims* final_shape = nullptr,
2850     absl::optional<int> op_instance = absl::nullopt) {
2851   const auto& node_def = params->node_def;
2852   // Get input dims.
2853   nvinfer1::Dims dims = input.GetTrtDims();
2854   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
2855   // Temporarily add batch dimension so that indexes line up properly.
2856   input_dims.insert(input_dims.begin(), -1);
2857   // Check bounds.
2858   for (int i = 1; i < input_dims.size(); i++) {
2859     if (begin[i] < 0 || begin[i] > input_dims[i]) {
2860       return errors::InvalidArgument("\"begin\" for dimension ",
2861                                      std::to_string(i), " in ", node_def.op(),
2862                                      " is out of range, at ", node_def.name());
2863     }
2864     const int end = begin[i] + size[i];
2865     if (end < 0 || end > input_dims[i]) {
2866       return errors::InvalidArgument("\"begin\" + \"size\" for dimension ",
2867                                      std::to_string(i), " in ", node_def.op(),
2868                                      " is out of range, at ", node_def.name());
2869     }
2870     if (size[i] <= 0) {
2871       return errors::InvalidArgument("\"size\" cannot be negative or zero for ",
2872                                      node_def.op(), ", at ", node_def.name());
2873     }
2874   }
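  // Illustrative example (hypothetical values): with non-batch dims [10, 10]
  // (so input_dims == [-1, 10, 10] above), begin == [0, 2, 3] and
  // size == [-1, 4, 5] select rows 2..5 and columns 3..7; each begin[i] and
  // begin[i] + size[i] must stay within input_dims[i], and size[i] must be
  // positive.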
2875 
2876 // TRT 5.1 adds ISliceLayer. For older versions, we attempt to use the
2877 // padding layer with negative padding.
2878 #if IS_TRT_VERSION_GE(5, 1, 3, 1)
2879   nvinfer1::Dims begin_dims, size_dims, stride_dims;
2880   TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(begin, &begin_dims,
2881                                                /*ignore_first_dim=*/true));
2882   TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(size, &size_dims,
2883                                                /*ignore_first_dim=*/true));
2884   TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(stride, &stride_dims,
2885                                                /*ignore_first_dim=*/true));
2886   if (params->validation_only) return Status::OK();
2887 
2888   nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice(
2889       *input.tensor(), begin_dims, size_dims, stride_dims);
2890   params->converter->SetLayerName(layer, params->node_def, "slice",
2891                                   op_instance);
2892   nvinfer1::ITensor* tensor = layer->getOutput(0);
2893   // Reshape for shrink_axis.
2894   if (final_shape) {
2895     TF_RETURN_IF_ERROR(PrepareTensorForShape(
2896         params->converter, TRT_TensorOrWeights(tensor), *final_shape,
2897         /*validation_only=*/false, &tensor, node_def, op_instance));
2898   }
2899   params->outputs->push_back(TRT_TensorOrWeights(tensor));
2900   return Status::OK();
2901 #else
2902   // Use IPaddingLayer.
2903   // Strides must be 1 in this case.
2904   for (int x : stride) {
2905     if (x != 1) {
2906       return errors::Unimplemented(
2907           "Strides other than 1 are not supported with this version of TRT, "
2908           "at ",
2909           node_def.name());
2910     }
2911   }
2912   // Rank must be 2, 3 or 4.
2913   if (input_dims.size() > 4) {
2914     return errors::Unimplemented(node_def.op(),
2915                                  " for tensors with rank > 4 is not supported "
2916                                  "in this version of TRT, at ",
2917                                  node_def.name());
2918   }
2919   // Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input.
2920   const bool need_reshape = (input_dims.size() != 4);
2921   int reshape_dims_added = 0;
2922   nvinfer1::Dims reshape_dims;
2923   if (need_reshape) {
2924     // Add new dims after batch dim until tensor is 4D.
2925     while (input_dims.size() < 4) {
2926       input_dims.insert(input_dims.begin() + 1, 1);
2927       begin.insert(begin.begin() + 1, 0);
2928       size.insert(size.begin() + 1, 1);
2929       reshape_dims_added++;
2930     }
2931     TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &reshape_dims,
2932                                                  /*ignore_first_dim=*/true));
2933   }
2934   // Find dimensions which need to be sliced.
2935   std::vector<int> pad_dims;
2936   for (int i = 1; i < input_dims.size(); i++) {
2937     if ((begin[i] != 0) || (begin[i] + size[i] != input_dims[i])) {
2938       pad_dims.push_back(i);
2939     }
2940   }
2941   if (pad_dims.empty()) {
2942     // No dimensions are changed, so this is a no-op. We could just return the
2943     // input without creating a new layer. However, TRT crashes when asked to
2944     // build an engine with no layers, so we add a no-op shuffle to keep our
2945     // unit tests from breaking.
2946     // TODO(tmorris): Allow empty engines in the unit tests and return the input
2947     // as output here.
2948     if (params->validation_only) return Status::OK();
2949     nvinfer1::IShuffleLayer* layer =
2950         params->converter->network()->addShuffle(*input.tensor());
2951     SetLayerName(layer, params->node_def, "shuffle", op_instance);
2952     params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
2953     return Status::OK();
2954   } else if (pad_dims.size() == 1) {
2955     // Only one dim is modified, but IPaddingLayer requires two, so mark a
2956     // second dim which will have a padding of 0. The dim we add is chosen to
2957     // avoid an unnecessary transpose.
2958     if (pad_dims[0] != 2) {
2959       pad_dims.push_back(2);
2960     } else {
2961       pad_dims.push_back(3);
2962     }
2963   } else if (pad_dims.size() > 2) {
2964     return errors::Unimplemented(
2965         node_def.op(),
2966         " can only modify up to 2 dimensions in this version of TRT, at ",
2967         node_def.name());
2968   }
2969   std::sort(pad_dims.begin(), pad_dims.end());
2970   // Convert to pre/post padding values. Since TRT does not have a StridedSlice
2971   // or Slice layer prior to 5.1, we instead create an IPaddingLayer with
2972   // negative padding.
2973   nvinfer1::DimsHW pre_padding, post_padding;
2974   for (int i = 0; i < pad_dims.size(); i++) {
2975     const int axis = pad_dims[i];
2976     pre_padding.d[i] = -begin[axis];
2977     post_padding.d[i] = (begin[axis] + size[axis]) - input_dims[axis];
2978   }
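  // Illustrative example (hypothetical values): slicing a dim of size 10 with
  // begin=2 and size=5 yields pre_padding = -2 and
  // post_padding = (2 + 5) - 10 = -3, i.e. 2 elements trimmed at the front
  // and 3 at the back by the IPaddingLayer.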
2979 
2980   // IPaddingLayer will always apply the padding to dims 2,3 (input format is
2981   // NCHW).
2982   const bool need_transpose = !(pad_dims[0] == 2 && pad_dims[1] == 3);
2983   std::vector<int> transpose_order(input_dims.size());
2984   std::vector<int> inv_transpose_order(input_dims.size());
2985   if (need_transpose) {
2986     if (pad_dims[0] == 1 && pad_dims[1] == 3) {
2987       transpose_order = {0, 2, 1, 3};
2988       inv_transpose_order = {0, 2, 1, 3};
2989     } else if (pad_dims[0] == 1 && pad_dims[1] == 2) {
2990       transpose_order = {0, 3, 1, 2};
2991       inv_transpose_order = {0, 2, 3, 1};
2992     }
2993   }
2994   if (params->validation_only) return Status::OK();
2995 
2996   // Start conversion.
2997   nvinfer1::ITensor* tensor = input.tensor();
2998   if (need_reshape) {
2999     TF_RETURN_IF_ERROR(PrepareTensorForShape(
3000         params->converter, input, reshape_dims, /*validation_only=*/false,
3001         &tensor, node_def, op_instance));
3002   }
3003   if (need_transpose) {
3004     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3005         tensor, transpose_order, &tensor, node_def, "for_pad", op_instance));
3006   }
3007   // Add padding layer
3008   nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding(
3009       *tensor, pre_padding, post_padding);
3010   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3011   SetLayerName(layer, params->node_def, "pad");
3012   params->converter->MarkQuantizationRangesAsInferrable(tensor,
3013                                                         layer->getOutput(0));
3014   tensor = layer->getOutput(0);
3015   // Restore transpose
3016   if (need_transpose) {
3017     TF_RETURN_IF_ERROR(
3018         params->converter->TransposeTensor(tensor, inv_transpose_order, &tensor,
3019                                            node_def, "after_pad", op_instance));
3020   }
3021   // Reshape for shrink_axis.
3022   if (final_shape) {
3023     TF_RETURN_IF_ERROR(PrepareTensorForShape(
3024         params->converter, TRT_TensorOrWeights(tensor), *final_shape,
3025         /*validation_only=*/false, &tensor, node_def, op_instance));
3026   } else if (need_reshape) {
3027     // Restore reshape.
3028     // Calculate output dimensions
3029     for (int i = 0; i < pad_dims.size(); i++) {
3030       const int axis = pad_dims[i];
3031       input_dims[axis] = size[axis];
3032     }
3033     // Remove added 1 dimensions
3034     for (int i = 0; i < reshape_dims_added; i++) {
3035       int value = input_dims[1];
3036       if (value != 1) {
3037         return errors::Internal("StridedSlice error when reshaping, at ",
3038                                 node_def.name());
3039       }
3040       input_dims.erase(input_dims.begin() + 1);
3041     }
3042 
3043     nvinfer1::Dims new_dims;
3044     TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims,
3045                                                  /*ignore_first_dim=*/true));
3046     TF_RETURN_IF_ERROR(PrepareTensorForShape(
3047         params->converter, TRT_TensorOrWeights(tensor), new_dims,
3048         /*validation_only=*/false, &tensor, node_def, op_instance));
3049   }
3050 
3051   params->outputs->push_back(TRT_TensorOrWeights(tensor));
3052   return Status::OK();
3053 #endif
3054 }
3055 
3056 Status ConvertSlice(OpConverterParams* params) {
3057   const auto& inputs = params->inputs;
3058   const auto& node_def = params->node_def;
3059   TF_RETURN_IF_ERROR(CheckInputsWeights(
3060       *params, {{"input", false}, {"begin", true}, {"size", true}}));
3061   TF_RETURN_IF_ERROR(AllowDataTypes(
3062       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
3063   std::vector<int> begin = inputs.at(1).weights().ToVector<int>();
3064   std::vector<int> size = inputs.at(2).weights().ToVector<int>();
3065   // Get input dims.
3066   nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
3067   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
3068   // Add batch dimension so that indexes line up properly.
3069   input_dims.insert(input_dims.begin(), inputs.at(0).batch_size());
3070   if (!AllLengthsEqual({input_dims, begin, size})) {
3071     return errors::InvalidArgument(
3072         "Length of begin and size arguments must equal rank of input for "
3073         "Slice, at ",
3074         node_def.name());
3075   }
3076   // Check that batch dimension is unmodified.
3077   const bool begin_is_modified = begin[0] != 0;
3078   // If size[0] is not -1, we can only know if the batch dimension is
3079   // unmodified when the batch size is defined. When the batch size is
3080   // undefined, we don't convert to be safe.
3081   const bool batch_size_is_defined = input_dims[0] > 0;
3082   const bool size_is_modified =
3083       size[0] != -1 && (!batch_size_is_defined || size[0] != input_dims[0]);
3084   if (begin_is_modified || size_is_modified) {
3085     return errors::Unimplemented(
3086         "TensorRT does not allow modifications to the batch dimension, at ",
3087         node_def.name());
3088   }
3089   // Size of -1 signifies to take all remaining elements.
3090   for (int i = 1; i < input_dims.size(); i++) {
3091     if (size[i] == -1) {
3092       size[i] = input_dims[i] - begin[i];
3093     }
3094   }
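  // Illustrative example (hypothetical dims): for input dims [N, 10, 8],
  // begin = [0, 2, 0] and size = [-1, -1, 8] resolve to size = [-1, 8, 8],
  // i.e. take everything from index 2 onward in dim 1 and all of dim 2.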
3095   // Stride is 1 for all dims.
3096   std::vector<int> stride(begin.size(), 1);
3097   return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride);
3098 }
3099 
3100 Status ConvertStridedSlice(OpConverterParams* params) {
3101   const auto& inputs = params->inputs;
3102   const auto& node_def = params->node_def;
3103   TF_RETURN_IF_ERROR(CheckInputsWeights(
3104       *params,
3105       {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}}));
3106   TF_RETURN_IF_ERROR(AllowDataTypes(
3107       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
3108 
3109   TFAttrs attrs(node_def);
3110   // new_axis_mask is not supported.
3111   const int32 new_axis_mask = attrs.get<int64>("new_axis_mask");
3112   if (new_axis_mask != 0) {
3113     return errors::Unimplemented(
3114         "new_axis_mask is not supported for StridedSlice, at ",
3115         node_def.name());
3116   }
3117   const int32 begin_mask = attrs.get<int64>("begin_mask");
3118   const int32 end_mask = attrs.get<int64>("end_mask");
3119   const int32 ellipsis_mask = attrs.get<int64>("ellipsis_mask");
3120   const int32 shrink_axis_mask = attrs.get<int64>("shrink_axis_mask");
3121 
3122   // Get input dims.
3123   nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
3124   std::vector<int64> input_dims(dims.d, dims.d + dims.nbDims);
3125   // Add batch dimension so that indexes line up properly. Set it to -1 if it's
3126   // unknown, so ValidateStridedSliceOp() can handle it correctly below.
3127   input_dims.insert(input_dims.begin(),
3128                     std::max(-1, inputs.at(0).batch_size()));
3129 
3130   const TRT_ShapedWeights& begin_weights = inputs.at(1).weights();
3131   const TRT_ShapedWeights& end_weights = inputs.at(2).weights();
3132   const TRT_ShapedWeights& stride_weights = inputs.at(3).weights();
3133   if (!AllLengthsEqual({begin_weights.ToVector<int>(),
3134                         end_weights.ToVector<int>(),
3135                         stride_weights.ToVector<int>()})) {
3136     return errors::InvalidArgument(
3137         "Length of begin, end, and stride must be equal, at ", node_def.name());
3138   }
3139 
3140   PartialTensorShape input_shape(input_dims);
3141   PartialTensorShape processing_shape;
3142   PartialTensorShape final_shape;
3143   bool is_identity;
3144   bool is_simple_slice;
3145   bool slice_dim0;
3146   absl::InlinedVector<int64, 4> begin;
3147   absl::InlinedVector<int64, 4> end;
3148   absl::InlinedVector<int64, 4> strides;
3149   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
3150       &begin_weights.GetTensor(), &end_weights.GetTensor(),
3151       stride_weights.GetTensor(), input_shape, begin_mask, end_mask,
3152       ellipsis_mask, new_axis_mask, shrink_axis_mask, &processing_shape,
3153       &final_shape, &is_identity, &is_simple_slice, &slice_dim0, &begin, &end,
3154       &strides));
3155 
3156   // Negative or zero strides currently not supported.
3157   for (int stride : strides) {
3158     if (stride <= 0) {
3159       return errors::Unimplemented(
3160           "Negative or zero stride values are not supported for StridedSlice, "
3161           "at ",
3162           node_def.name());
3163     }
3164   }
3165 
3166   // If the batch dimension is covered by the ellipsis mask, it is left
3167   // untouched. Otherwise we check here whether the slice modifies it.
3168   if (!(ellipsis_mask & 1) ||
3169       begin_weights.shape_.nbDims >= input_dims.size()) {
3170     // Check that batch dimension is unmodified. We need to use the expanded
3171     // begin/end/strides array since the original array may be incorrect when
3172     // (ellipsis_mask&1)==1.
3173     const bool begin_is_modified = !(begin_mask & 1) && (begin[0] != 0);
3174     const bool stride_is_modified = (strides[0] != 1);
3175     // If the batch size is -1 and the end mask is not set, we can only know if
3176     // the batch dimension is unmodified when the batch size is defined. When
3177     // the batch size is undefined, we don't convert to be safe.
3178     const bool batch_size_is_defined = (input_dims[0] > 0);
3179     const bool end_is_modified =
3180         !(end_mask & 1) && (!batch_size_is_defined || end[0] != input_dims[0]);
3181     if (begin_is_modified || stride_is_modified || end_is_modified) {
3182       return errors::Unimplemented(
3183           "TensorRT does not allow modifications to the batch dimension, at ",
3184           node_def.name());
3185     }
3186   }
3187   // Can't shrink axis on batch dimension.
3188   if (shrink_axis_mask & 1) {
3189     return errors::Unimplemented(
3190         "TensorRT does not allow modifications to the batch dimension, at ",
3191         node_def.name());
3192   }
3193   // TRT Slice layer uses (begin, size) instead of (begin, end)
3194   absl::InlinedVector<int64, 4> size(input_dims.size());
3195   for (int i = 0; i < input_dims.size(); i++) {
3196     // Divide by stride (round up)
3197     size[i] = (end[i] - begin[i] + strides[i] - 1) / strides[i];
3198   }
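  // Illustrative example (hypothetical values): begin=1, end=7, stride=2 gives
  // size = (7 - 1 + 2 - 1) / 2 = 3, i.e. elements 1, 3 and 5.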
3199 
3200   // shrink_axis_mask requires a reshape after the slice.
3201   nvinfer1::Dims final_shape_dims;
3202   nvinfer1::Dims* final_shape_dims_ptr = nullptr;
3203   if (shrink_axis_mask) {
3204     final_shape_dims =
3205         TensorShapeToTrtDims(final_shape, /*ignore_first_dim=*/true);
3206     final_shape_dims_ptr = &final_shape_dims;
3207   }
3208   return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, strides,
3209                                    final_shape_dims_ptr);
3210 }
3211 
3212 Status ConvertConv2D(OpConverterParams* params) {
3213   return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/false);
3214 }
3215 
3216 Status ConvertConv2DDepthwise(OpConverterParams* params) {
3217   return ConvertConv2DHelper(params, 0, /*is_conv2d_backprop_input=*/false);
3218 }
3219 
3220 Status ConvertConv2DBackpropInput(OpConverterParams* params) {
3221   return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/true);
3222 }
3223 
3224 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
3225 Status ConvertConv3DHelper(OpConverterParams* params, int group,
3226                            bool is_conv3d_backprop_input = false) {
3227   const int kNumDims = 5;
3228   const auto& inputs = params->inputs;
3229   const auto& node_def = params->node_def;
3230   TRT_TensorOrWeights backprop_output_size;
3231   nvinfer1::ITensor* tensor = nullptr;
3232   if (is_conv3d_backprop_input) {
3233     // In the case when Conv3dBackpropInput is used for conv3d_transpose, these
3234     // inputs correspond to: output size, filter, and input.
3235     TF_RETURN_IF_ERROR(CheckInputsWeights(
3236         *params,
3237         {{"input_sizes", true}, {"filter", true}, {"out_backprop", false}}));
3238     backprop_output_size = inputs.at(0);
3239     tensor = inputs.at(2).tensor();
3240   } else {
3241     TF_RETURN_IF_ERROR(
3242         CheckInputsWeights(*params, {{"input", false}, {"filter", true}}));
3243     tensor = inputs.at(0).tensor();
3244   }
3245   TF_RETURN_IF_ERROR(
3246       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3247   const TRT_ShapedWeights weights_drsck = inputs.at(1).weights();
3248   if (weights_drsck.shape_.nbDims != kNumDims) {
3249     return errors::InvalidArgument("Conv3D expects kernel of dimension 5, at ",
3250                                    node_def.name());
3251   }
3252   TFAttrs attrs(node_def);
3253   auto data_format = attrs.get<string>("data_format");
3254   const bool is_ndhwc = (data_format == "NDHWC");  // Or NCDHW 01234 -> 02341
3255   const int d_index = is_ndhwc ? 1 : 2;
3256   const int h_index = is_ndhwc ? 2 : 3;
3257   const int w_index = is_ndhwc ? 3 : 4;
3258   const int c_index = is_ndhwc ? 4 : 1;
3259   auto tf_dilations = attrs.get<std::vector<int64>>("dilations");
3260   if (tf_dilations.size() != kNumDims) {
3261     return errors::InvalidArgument(
3262         "Convolution dilations field must specify 5 dimensions, at ",
3263         node_def.name());
3264   }
3265   if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) {
3266     return errors::Unimplemented(
3267         "Dilation rate must be 1 for batch and channel dimensions, at ",
3268         node_def.name());
3269   }
3270 
3271   const nvinfer1::Dims3 dilation_dhw(
3272       tf_dilations[d_index], tf_dilations[h_index], tf_dilations[w_index]);
3273   if (is_conv3d_backprop_input &&
3274       (dilation_dhw.d[0] != 1 || dilation_dhw.d[1] != 1 ||
3275        dilation_dhw.d[2] != 1)) {
3276     return errors::Unimplemented(
3277         "Dilation with Conv3DBackpropInputV2 (conv3d_transpose) is not "
3278         "supported",
3279         ", at ", node_def.name());
3280   }
3281 
3282   const auto tf_stride = attrs.get<std::vector<int64>>("strides");
3283   if (tf_stride.size() != kNumDims) {
3284     return errors::InvalidArgument(
3285         "Convolution strides field must specify 5 dimensions, at ",
3286         node_def.name());
3287   }
3288   if (tf_stride[0] != 1 || tf_stride[c_index] != 1) {
3289     return errors::Unimplemented(
3290         "Stride must be 1 for batch and channel dimensions, at ",
3291         node_def.name());
3292   }
3293 
3294   const nvinfer1::Dims3 stride_dhw(tf_stride[d_index], tf_stride[h_index],
3295                                    tf_stride[w_index]);
3296   const auto tensor_dim = tensor->getDimensions();
3297 
3298   // Asymmetric padding on Deconv not supported for now
3299   if (is_conv3d_backprop_input && attrs.get<string>("padding") == "SAME") {
3300     TRT_ShapedWeights weights =
3301         params->weight_store->GetTempWeights(weights_drsck);
3302 
3303     nvinfer1::Dims3 effective_kernel_size(
3304         weights.shape_.d[0] +
3305             (weights.shape_.d[0] - 1) * (dilation_dhw.d[0] - 1),  // D
3306         weights.shape_.d[1] +
3307             (weights.shape_.d[1] - 1) * (dilation_dhw.d[1] - 1),  // R
3308         weights.shape_.d[2] +
3309             (weights.shape_.d[2] - 1) * (dilation_dhw.d[2] - 1)  // S
3310     );
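    // Illustrative example (hypothetical values): a 3x3x3 kernel with dilation
    // 2 in every spatial dimension has an effective extent of
    // 3 + (3 - 1) * (2 - 1) = 5 per dimension, which is what SAME padding must
    // account for.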
3311 
3312     const auto output_size_weights =
3313         static_cast<int*>(backprop_output_size.weights().GetValues());
3314     const std::vector<int64_t> input_dims = {output_size_weights[d_index],
3315                                              output_size_weights[h_index],
3316                                              output_size_weights[w_index]};
3317 
3318     const std::vector<std::pair<int, int>> padding =
3319         CreateSamePadding(stride_dhw, effective_kernel_size, input_dims);
3320 
3321     if (padding[0].first != padding[0].second ||
3322         padding[1].first != padding[1].second ||
3323         padding[2].first != padding[2].second) {
3324       return errors::Unimplemented(
3325           "Asymmetric padding with Conv3DBackpropInputV2 (conv3d_transpose) is "
3326           "not supported, at ",
3327           node_def.name());
3328     }
3329   }
3330 
3331   // Finished validation checks
3332   if (params->validation_only) return Status::OK();
3333 
3334   // Transpose to NCDHW (NCDHW is required for IConvLayer).
3335   const bool need_transpose = is_ndhwc;
3336   if (need_transpose) {
3337     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3338         tensor, {0, 4, 1, 2, 3}, &tensor, node_def, "to_NCDHW"));
3339   }
3340 
3341   // group == 0 signifies that this is a depthwise convolution, so set
3342   // num_groups to size of input's channel dim. For a non-depthwise conv,
3343   // num_groups will be 1.
3344   const int num_groups = (group == 0) ? tensor_dim.d[0] : group;
3345 
3346   // For conv, TF weights are DRSCK, and TRT expects KCDRS.
3347   // For backprop, TF weights are DRSKC, and TRT expects KCDRS.
3348   // Therefore, this reorder will work for both cases.
3349   TRT_ShapedWeights weights =
3350       params->weight_store->GetTempWeights(weights_drsck);
3351   ReorderDRSCKToKCDRS(weights_drsck, &weights, num_groups);
3352   TRT_ShapedWeights biases(weights.TrtDType());
3353   const int output_axis = is_conv3d_backprop_input ? 1 : 0;
3354   const int noutput = weights.shape_.d[output_axis] * num_groups;
3355   nvinfer1::Dims3 kernel_size_drs(weights.shape_.d[2],  // D
3356                                   weights.shape_.d[3],  // R
3357                                   weights.shape_.d[4]   // S
3358   );
3359 
3360   // Add convolution.
3361   nvinfer1::ILayer* conv_layer = nullptr;
3362   if (is_conv3d_backprop_input) {
3363     nvinfer1::IDeconvolutionLayer* layer =
3364         params->converter->network()->addDeconvolutionNd(
3365             *tensor, noutput, kernel_size_drs, weights.GetTrtWeights(),
3366             biases.GetTrtWeights());
3367     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3368     layer->setStrideNd(stride_dhw);  // change to nd set stride
3369 
3370     // TensorRT 5.1.3 added support for padding modes.
3371     if (attrs.get<string>("padding") == "SAME") {
3372       VLOG(2) << "Using SAME padding";
3373       // SAME_UPPER means that post padding is preferred.
3374       layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
3375     }
3376 
3377     layer->setNbGroups(num_groups);
3378     conv_layer = layer;
3379   } else {
3380     nvinfer1::IConvolutionLayer* layer =
3381         params->converter->network()->addConvolutionNd(
3382             *tensor, noutput, kernel_size_drs, weights.GetTrtWeights(),
3383             biases.GetTrtWeights());
3384     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3385     layer->setStrideNd(stride_dhw);
3386 
3387     if (attrs.get<string>("padding") == "SAME") {
3388       VLOG(2) << "Using SAME padding";
3389       layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
3390     }
3391 
3392     layer->setNbGroups(num_groups);
3393     layer->setDilationNd(dilation_dhw);
3394     conv_layer = layer;
3395   }
3396   params->converter->SetLayerName(conv_layer, node_def, "conv");
3397   nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0);
3398 
3399   // Restore transpose.
3400   if (need_transpose) {
3401     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3402         output_tensor, {0, 2, 3, 4, 1}, &output_tensor, node_def, "to_NDHWC"));
3403   }
3404   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3405   return Status::OK();
3406 }
3407 
3408 Status ConvertConv3D(OpConverterParams* params) {
3409   return ConvertConv3DHelper(params, 1, /*is_conv3d_backprop_input=*/false);
3410 }
3411 
3412 Status ConvertConv3DBackpropInputV2(OpConverterParams* params) {
3413   return ConvertConv3DHelper(params, 1, /*is_conv3d_backprop_input=*/true);
3414 }
3415 
3416 Status ConvertPool3D(OpConverterParams* params) {
3417   const int kNumDims = 5;
3418   const auto& inputs = params->inputs;
3419   const auto& node_def = params->node_def;
3420   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
3421   TF_RETURN_IF_ERROR(
3422       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3423   nvinfer1::PoolingType type;
3424   if (node_def.op() == "MaxPool3D") {
3425     type = nvinfer1::PoolingType::kMAX;
3426   } else if (node_def.op() == "AvgPool3D") {
3427     type = nvinfer1::PoolingType::kAVERAGE;
3428   } else {
3429     return errors::Unimplemented("Unsupported pooling type: ", node_def.op(),
3430                                  ", at ", node_def.name());
3431   }
3432   TFAttrs attrs(node_def);
3433   const string padding_type = attrs.get<string>("padding");
3434   if ((padding_type != "SAME") && (padding_type != "VALID")) {
3435     return errors::Unimplemented("Unsupported padding type: ", padding_type,
3436                                  ", at ", node_def.name());
3437   }
3438   const auto data_format = attrs.get<string>("data_format");
3439   const bool is_ndhwc = (data_format == "NDHWC");
3440   const int c_index = is_ndhwc ? 4 : 1;
3441   const int d_index = is_ndhwc ? 1 : 2;
3442   const int h_index = is_ndhwc ? 2 : 3;
3443   const int w_index = is_ndhwc ? 3 : 4;
3444   const auto tf_stride = attrs.get<std::vector<int64>>("strides");
3445   if (tf_stride.size() != kNumDims) {
3446     return errors::InvalidArgument(
3447         "Pooling strides field must specify 5 dimensions, at ",
3448         node_def.name());
3449   }
3450   if (tf_stride[0] != 1 || tf_stride[c_index] != 1) {
3451     return errors::Unimplemented(
3452         "stride must be 1 for batch and channel dimensions, at ",
3453         node_def.name());
3454   }
3455   const auto tf_kernel = attrs.get<std::vector<int64>>("ksize");
3456   if (tf_kernel.size() != kNumDims) {
3457     return errors::InvalidArgument(
3458         "Pooling ksize field must specify 5 dimensions, at ", node_def.name());
3459   }
3460   if (tf_kernel[0] != 1 || tf_kernel[c_index] != 1) {
3461     return errors::Unimplemented(
3462         "ksize must be 1 for batch and channel dimensions, at ",
3463         node_def.name());
3464   }
3465   if (params->validation_only) return Status::OK();
3466 
3467   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
3468   if (data_format == "NDHWC") {
3469     // NDHWC => NCDHW
3470     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3471         tensor, {0, 4, 1, 2, 3}, &tensor, node_def, "to_NCDHW"));
3472   }
3473 
3474   const nvinfer1::Dims3 stride(tf_stride[d_index], tf_stride[h_index],
3475                                tf_stride[w_index]);
3476   const nvinfer1::Dims3 ksize(tf_kernel[d_index], tf_kernel[h_index],
3477                               tf_kernel[w_index]);
3478 
3479   nvinfer1::IPoolingLayer* layer =
3480       params->converter->network()->addPoolingNd(*tensor, type, ksize);
3481   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3482 
3483   params->converter->MarkQuantizationRangesAsInferrable(tensor,
3484                                                         layer->getOutput(0));
3485 
3486   layer->setStrideNd(stride);
3487   // VALID padding is the default TRT behavior.
3488   if (padding_type == "SAME") {
3489     // SAME_UPPER means that post padding is preferred.
3490     layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
3491   }
3492   params->converter->SetLayerName(layer, node_def, "pooling");
3493   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
3494 
3495   if (data_format == "NDHWC") {
3496     // NCDHW => NDHWC
3497     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3498         output_tensor, {0, 2, 3, 4, 1}, &output_tensor, node_def, "to_NDHWC"));
3499   }
3500 
3501   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3502   return Status::OK();
3503 }
3504 #endif  // #if IS_TRT_VERSION_GE(6, 0, 0, 0)
3505 
3506 Status ConvertFusedConv2DBiasActivation(OpConverterParams* params) {
3507   const auto& inputs = params->inputs;
3508   const auto& node_def = params->node_def;
3509 
3510   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false},
3511                                                   {"filter", true},
3512                                                   {"bias", true},
3513                                                   {"side_input", true},
3514                                                   {"conv_input_scale", true},
3515                                                   {"side_input_scale", true}}));
3516   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
3517   TF_RETURN_IF_ERROR(
3518       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3519   TRT_ShapedWeights weights = inputs.at(1).weights();
3520   if (weights.shape_.nbDims != 4) {
3521     return errors::InvalidArgument(
3522         "FusedConv2DBiasActivation expects kernel of dimension 4, at " +
3523         node_def.name());
3524   }
3525   TFAttrs attrs(node_def);
3526   auto data_format = attrs.get<string>("data_format");
3527   if (data_format != "NHWC" && data_format != "NCHW") {
3528     return errors::InvalidArgument("Unsupported data_format:", data_format,
3529                                    " at ", node_def.name());
3530   }
3531 
3532   int c_index = (data_format == "NHWC") ? 3 : 1;
3533   int h_index = (data_format == "NHWC") ? 1 : 2;
3534   int w_index = (data_format == "NHWC") ? 2 : 3;
3535   auto tf_dilations = attrs.get<std::vector<int64>>("dilations");
3536   if (tf_dilations.size() != 4) {
3537     return errors::InvalidArgument(
3538         "Convolution dilations field must specify 4 dimensions, at ",
3539         node_def.name());
3540   }
3541   if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) {
3542     return errors::Unimplemented(
3543         "Dilation rate must be 1 for batch and channel dimensions, at ",
3544         node_def.name());
3545   }
3546   const nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]);
3547 
3548   const auto tf_stride = attrs.get<std::vector<int64>>("strides");
3549   if (tf_stride.size() != 4) {
3550     return errors::InvalidArgument(
3551         "Convolution strides field must specify 4 dimensions, at ",
3552         node_def.name());
3553   }
3554   if (tf_stride[0] != 1 || tf_stride[c_index] != 1) {
3555     return errors::Unimplemented(
3556         "Stride must be 1 for batch and channel dimensions, at ",
3557         node_def.name());
3558   }
3559   const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
3560   const auto activation_mode = attrs.get<string>("activation_mode");
3561   auto op_pair = ActivationTypeMap()->find(activation_mode);
3562   if (op_pair == ActivationTypeMap()->end() && activation_mode != "None") {
3563     return errors::Unimplemented("Activation mode: ", activation_mode,
3564                                  " not supported at: ", node_def.name());
3565   }
3566 
3567   const auto filter_format = attrs.get<string>("filter_format");
3568   if (filter_format != "HWIO" && filter_format != "OIHW") {
3569     return errors::InvalidArgument("Unsupported filter_format:", filter_format,
3570                                    " at ", node_def.name());
3571   }
3572   // Check that there is no side_input and that conv_input_scale is 1.
3573   TRT_ShapedWeights side_input = inputs.at(3).weights();
3574   if (side_input.count() != 0) {
3575     return errors::InvalidArgument(
3576         "FusedConv2DBiasActivation doesn't yet support side_input, at " +
3577         node_def.name());
3578   }
3579   TRT_ShapedWeights conv_input_scale = inputs.at(4).weights();
3580   if (conv_input_scale.count() != 1 ||
3581       conv_input_scale.TrtDType() != nvinfer1::DataType::kFLOAT ||
3582       conv_input_scale.GetSpan<float>()[0] != 1.0) {
3583     return errors::InvalidArgument(
3584         "FusedConv2DBiasActivation doesn't yet support conv_input_scale, at " +
3585         node_def.name());
3586   }
3587   if (params->validation_only) return Status::OK();
3588 
3589   // Transpose to NCHW (NCHW is required for IConvLayer).
3590   const bool need_transpose = (data_format == "NHWC");
3591   if (need_transpose) {
3592     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3593         tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW"));
3594   }
3595 
3596   nvinfer1::DimsHW kernel_size;
3597   if (filter_format == "OIHW") {
3598     kernel_size.h() = weights.shape_.d[2];
3599     kernel_size.w() = weights.shape_.d[3];
3600   } else {
3601     // HWIO.
3602     DCHECK_EQ(filter_format, "HWIO");
3603     kernel_size.h() = weights.shape_.d[0];
3604     kernel_size.w() = weights.shape_.d[1];
3605   }
3606 // Before TRT 5.1.3, we have to calculate padding ourselves.
3607 #if !IS_TRT_VERSION_GE(5, 1, 3, 0)
3608   const auto tensor_dim = tensor->getDimensions();
3609   std::vector<int64_t> input_dims;
3610   // Use 1 and 2 because tensor_dim has the dimensions of the transposed
3611   // input.
3612   input_dims = {static_cast<int>(tensor_dim.d[1]),
3613                 static_cast<int>(tensor_dim.d[2])};
3614   std::vector<std::pair<int, int>> padding;
3615   nvinfer1::ITensor* padded_tensor = nullptr;
3616   TF_RETURN_IF_ERROR(Conv2DPaddingHelper(params, attrs, kernel_size, dilation,
3617                                          stride, input_dims, tensor, &padding,
3618                                          &padded_tensor));
3619   tensor = padded_tensor;
3620 #endif
3621 
3622   // Add convolution.
3623   TRT_ShapedWeights biases = inputs.at(2).weights();
3624   nvinfer1::IConvolutionLayer* conv_layer = nullptr;
3625   if (filter_format == "OIHW") {
3626     // Weights are already in the right order.
3627     conv_layer = params->converter->network()->addConvolution(
3628         *tensor, weights.shape_.d[0], kernel_size, weights.GetTrtWeights(),
3629         biases.GetTrtWeights());
3630   } else {
3631     // For conv, TF weights are RSCK, and TRT expects KCRS.
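    // (Here R = filter height, S = filter width, C = input channels, and
    // K = output channels, so RSCK is the same layout as TF's HWIO.)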
3632     DCHECK_EQ(filter_format, "HWIO");
3633     TRT_ShapedWeights weights_kcrs =
3634         params->weight_store->GetTempWeights(weights);
3635     ReorderRSCKToKCRS(weights, &weights_kcrs, 1);
3636     conv_layer = params->converter->network()->addConvolution(
3637         *tensor, weights.shape_.d[3], kernel_size, weights_kcrs.GetTrtWeights(),
3638         biases.GetTrtWeights());
3639   }
3640   TFTRT_RETURN_ERROR_IF_NULLPTR(conv_layer, node_def.name());
3641   conv_layer->setStride(stride);
3642 #if IS_TRT_VERSION_GE(5, 1, 3, 0)
3643   if (attrs.get<string>("padding") == "SAME") {
3644     conv_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
3645   }
3646 #else
3647   conv_layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
3648 #endif
3649   params->converter->SetLayerName(conv_layer, node_def, "conv");
3650   conv_layer->setNbGroups(1);
3651   conv_layer->setDilation(dilation);
3652   nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0);
3653 
3654   // Add activation if there is one.
3655   if (op_pair != ActivationTypeMap()->end()) {
3656     nvinfer1::IActivationLayer* activation_layer =
3657         params->converter->network()->addActivation(*output_tensor,
3658                                                     op_pair->second);
3659     TFTRT_RETURN_ERROR_IF_NULLPTR(activation_layer, node_def.name());
3660     params->converter->SetLayerName(activation_layer, node_def, "activation");
3661     output_tensor = activation_layer->getOutput(0);
3662   }
3663   // Restore transpose.
3664   if (need_transpose) {
3665     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3666         output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC"));
3667   }
3668   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3669   return Status::OK();
3670 }
3671 
3672 Status ConvertPool(OpConverterParams* params) {
3673   const auto& inputs = params->inputs;
3674   const auto& node_def = params->node_def;
3675   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
3676 #if IS_TRT_VERSION_GE(5, 1, 0, 0)
3677   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF,
3678                                    DataType::DT_INT8};
3679 #else
3680   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF};
3681 #endif
3682   TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types));
3683   nvinfer1::PoolingType type;
3684   if (node_def.op() == "MaxPool") {
3685     type = nvinfer1::PoolingType::kMAX;
3686   } else if (node_def.op() == "AvgPool") {
3687     type = nvinfer1::PoolingType::kAVERAGE;
3688   } else {
3689     return errors::Unimplemented("Unsupported pooling type: ", node_def.op(),
3690                                  ", at ", node_def.name());
3691   }
3692   TFAttrs attrs(node_def);
3693   const string padding_type = attrs.get<string>("padding");
3694   if ((padding_type != "SAME") && (padding_type != "VALID")) {
3695     return errors::Unimplemented("Unsupported padding type: ", padding_type,
3696                                  ", at ", node_def.name());
3697   }
3698   if (params->validation_only) return Status::OK();
3699 
3700   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
3701   int h_index = 2;
3702   int w_index = 3;
3703   const auto data_format = attrs.get<string>("data_format");
3704   if (data_format == "NHWC") {
3705     h_index = 1;
3706     w_index = 2;
3707     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3708         tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW"));
3709   }
3710 
3711   const auto tf_stride = attrs.get<std::vector<int64>>("strides");
3712   const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
3713 
3714   const auto tf_kernel = attrs.get<std::vector<int64>>("ksize");
3715   const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
3716 
3717 // Before TRT 5.1.3, we have to calculate padding ourselves.
3718 #if !IS_TRT_VERSION_GE(5, 1, 3, 0)
3719   auto tensor_dim = tensor->getDimensions();
3720   std::vector<std::pair<int, int>> padding;
3721   if (padding_type == "SAME") {
3722     // This is NCHW tensor with no batch dimension.
3723     //  1 -> h
3724     //  2 -> w
3725     padding = CreateSamePadding(
3726         stride, ksize,
3727         {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
3728   } else if (padding_type == "VALID") {
3729     padding = {{0, 0}, {0, 0}};
3730   }
3731 #endif
3732 // TensorRT 5.1 added support for asymmetric padding. Before that, we need an
3733 // extra padding layer.
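// As an illustrative worked example (not from the original comments): SAME
// pooling of an input of width 6 with ksize 3 and stride 2 produces
// ceil(6 / 2) = 3 output columns, which needs (3 - 1) * 2 + 3 = 7 input
// columns, i.e. 1 column of padding in total. TF splits this as 0 before and
// 1 after, which is asymmetric and therefore needs the explicit padding layer
// added below.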
3734 #if !IS_TRT_VERSION_GE(5, 1, 0, 0)
3735   // Asymmetric padding case.
3736   if (padding[0].first != padding[0].second ||
3737       padding[1].first != padding[1].second) {
3738     auto pad_layer = params->converter->network()->addPadding(
3739         *tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first),
3740         nvinfer1::DimsHW(padding[0].second, padding[1].second));
3741     TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
3742     params->converter->SetLayerName(pad_layer, node_def, "pad");
3743     params->converter->MarkQuantizationRangesAsInferrable(
3744         tensor, pad_layer->getOutput(0));
3745     padding = {{0, 0}, {0, 0}};
3746     tensor = pad_layer->getOutput(0);
3747   }
3748 #endif
3749 
3750   nvinfer1::IPoolingLayer* layer =
3751       params->converter->network()->addPooling(*tensor, type, ksize);
3752   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3753   // TODO(tmorris): Average pooling may not be entirely safe to infer
3754   // quantization range through (at least forwards - backwards should be fine).
3755   // Max pooling is okay.
3756   params->converter->MarkQuantizationRangesAsInferrable(tensor,
3757                                                         layer->getOutput(0));
3758 
3759   layer->setStride(stride);
3760 #if IS_TRT_VERSION_GE(5, 1, 3, 0)
3761   // VALID padding is the default TRT behavior.
3762   if (attrs.get<string>("padding") == "SAME") {
3763     // SAME_UPPER means that post padding is preferred.
3764     layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
3765   }
3766 #elif IS_TRT_VERSION_GE(5, 1, 0, 0)
3767   layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
3768   layer->setPostPadding(nvinfer1::DimsHW{padding[0].second, padding[1].second});
3769 #else
3770   layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first});
3771 #endif
3772   params->converter->SetLayerName(layer, node_def, "pooling");
3773   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
3774 
3775   if (data_format == "NHWC") {
3776     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3777         output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC"));
3778   }
3779   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3780   return Status::OK();
3781 }
3782 
3783 Status ConvertLeakyRelu(OpConverterParams* params) {
3784   const auto& inputs = params->inputs;
3785   const auto& node_def = params->node_def;
3786   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
3787   TF_RETURN_IF_ERROR(
3788       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3789   TFAttrs attrs(node_def);
3790   const float alpha = attrs.get<float>("alpha");
3791 
3792 #if IS_TRT_VERSION_GE(5, 1, 2, 0)
3793   // Use IActivationLayer when available.
3794   if (params->validation_only) return Status::OK();
3795 
3796   nvinfer1::IActivationLayer* layer =
3797       params->converter->network()->addActivation(
3798           *inputs.at(0).tensor(), nvinfer1::ActivationType::kLEAKY_RELU);
3799   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3800   params->converter->SetLayerName(layer, node_def, "activation");
3801   layer->setAlpha(alpha);
3802   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
3803   return Status::OK();
3804 #else
3805   // Use elementwise ops when IActivationLayer is not available.
3806   if (alpha < 0.0f || alpha > 1.0f) {
3807     return errors::Unimplemented(
3808         "Alpha value for LeakyRelu must be between 0 and 1, at ",
3809         node_def.name());
3810   }
3811   if (params->validation_only) return Status::OK();
3812 
3813   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
3814   // Create const for alpha.
3815   nvinfer1::ITensor* const_alpha_tensor = nullptr;
3816   TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant(
3817       params, alpha, tensor->getDimensions(), &const_alpha_tensor));
3818   // alpha * x
3819   nvinfer1::IElementWiseLayer* mul_layer =
3820       params->converter->network()->addElementWise(
3821           *tensor, *const_alpha_tensor, nvinfer1::ElementWiseOperation::kPROD);
3822   TFTRT_RETURN_ERROR_IF_NULLPTR(mul_layer, node_def.name());
3823   params->converter->SetLayerName(mul_layer, node_def, "mul");
3824   // max(x, alpha * x)
3825   nvinfer1::IElementWiseLayer* max_layer =
3826       params->converter->network()->addElementWise(
3827           *tensor, *mul_layer->getOutput(0),
3828           nvinfer1::ElementWiseOperation::kMAX);
3829   TFTRT_RETURN_ERROR_IF_NULLPTR(max_layer, node_def.name());
3830   params->converter->SetLayerName(max_layer, node_def, "max");
3831   nvinfer1::ITensor* output_tensor = max_layer->getOutput(0);
3832   params->converter->MarkQuantizationRangesAsInferrable(
3833       output_tensor, mul_layer->getOutput(0));
3834 
3835   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3836   return Status::OK();
3837 #endif
3838 }
3839 
3840 #if IS_TRT_VERSION_GE(5, 1, 2, 0)
3841 Status ConvertClipByValue(OpConverterParams* params) {
3842   const auto& inputs = params->inputs;
3843   const auto& node_def = params->node_def;
3844   // TODO(tmorris): We can also allow the case where min and max are tensors by
3845   // using elementwise min and max layers.
3846   TF_RETURN_IF_ERROR(CheckInputsWeights(
3847       *params,
3848       {{"t", false}, {"clip_value_min", true}, {"clip_value_max", true}}));
3849   TF_RETURN_IF_ERROR(
3850       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3851   if (params->validation_only) return Status::OK();
3852 
3853   TFAttrs attrs(node_def);
3854   const DataType dtype = attrs.get<DataType>("T");
3855   float clip_value_min = 0.0f;
3856   float clip_value_max = 0.0f;
3857   // TODO(tmorris): Add a templated helper function to get scalar weights of
3858   // InType casted to OutType.
3859   if (dtype == DataType::DT_FLOAT) {
3860     clip_value_min = inputs.at(1).weights().GetSpan<float>()[0];
3861     clip_value_max = inputs.at(2).weights().GetSpan<float>()[0];
3862   } else if (dtype == DataType::DT_HALF) {
3863     clip_value_min = Eigen::half_impl::half_to_float(
3864         inputs.at(1).weights().GetSpan<Eigen::half>()[0]);
3865     clip_value_max = Eigen::half_impl::half_to_float(
3866         inputs.at(2).weights().GetSpan<Eigen::half>()[0]);
3867   }
3868 
3869   nvinfer1::IActivationLayer* layer =
3870       params->converter->network()->addActivation(
3871           *inputs.at(0).tensor(), nvinfer1::ActivationType::kCLIP);
3872   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3873   layer->setAlpha(clip_value_min);
3874   layer->setBeta(clip_value_max);
3875   params->converter->SetLayerName(layer, node_def, "activation");
3876   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
3877   params->converter->ProvideQuantizationRange(output_tensor, clip_value_min,
3878                                               clip_value_max);
3879   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3880   return Status::OK();
3881 }
3882 #endif
3883 
3884 const std::unordered_map<string, nvinfer1::ActivationType>*
3885 ActivationTypeMap() {
3886   static auto* const m =
3887       new std::unordered_map<string, nvinfer1::ActivationType>({
3888         {"Relu", nvinfer1::ActivationType::kRELU},
3889             {"Sigmoid", nvinfer1::ActivationType::kSIGMOID},
3890             {"Tanh", nvinfer1::ActivationType::kTANH},
3891 #if IS_TRT_VERSION_GE(5, 1, 2, 0)
3892             {"Elu", nvinfer1::ActivationType::kELU},
3893             {"Selu", nvinfer1::ActivationType::kSELU},
3894             {"Softsign", nvinfer1::ActivationType::kSOFTSIGN},
3895             {"Softplus", nvinfer1::ActivationType::kSOFTPLUS},
3896 #endif
3897       });
3898   return m;
3899 }
3900 
3901 Status ConvertActivation(OpConverterParams* params) {
3902   const auto& inputs = params->inputs;
3903   const auto& node_def = params->node_def;
3904   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
3905   TF_RETURN_IF_ERROR(
3906       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3907   auto op_pair = ActivationTypeMap()->find(node_def.op());
3908   if (op_pair == ActivationTypeMap()->end()) {
3909     return errors::Unimplemented("Activation op: ", node_def.op(),
3910                                  " not supported at: ", node_def.name());
3911   }
3912   if (params->validation_only) return Status::OK();
3913 
3914   // Start conversion.
3915   nvinfer1::IActivationLayer* layer =
3916       params->converter->network()->addActivation(*inputs.at(0).tensor(),
3917                                                   op_pair->second);
3918   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3919   params->converter->SetLayerName(layer, node_def, "activation");
3920   // Set parameters.
3921 #if IS_TRT_VERSION_GE(5, 1, 2, 0)
3922   if (node_def.op() == "Elu") {
3923     layer->setAlpha(1.0f);
3924   } else if (node_def.op() == "Selu") {
3925     // From tensorflow/core/kernels/relu_op_functor.h
3926     layer->setAlpha(1.7580993408473768599402175208123f);
3927     layer->setBeta(1.0507009873554804934193349852946f);
3928   } else if (node_def.op() == "Softplus") {
3929     layer->setAlpha(1.0f);
3930     layer->setBeta(1.0f);
3931   }
3932 #endif
3933   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
3934   // Set quantization range for output when known.
3935   if (node_def.op() == "Sigmoid") {
3936     params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 1.0f);
3937   } else if (node_def.op() == "Tanh") {
3938     params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f);
3939   } else if (node_def.op() == "Softsign") {
3940     params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f);
3941   }
3942   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3943   return Status::OK();
3944 }
3945 
3946 Status ConvertQuantize(OpConverterParams* params) {
3947   const auto& inputs = params->inputs;
3948   const auto& node_def = params->node_def;
3949   if (node_def.op() == "FakeQuantWithMinMaxArgs") {
3950     TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
3951   } else if (node_def.op() == "FakeQuantWithMinMaxVars") {
3952     TF_RETURN_IF_ERROR(CheckInputsWeights(
3953         *params, {{"input", false}, {"min", true}, {"max", true}}));
3954   } else if (node_def.op() == "QuantizeAndDequantizeV2") {
3955     TF_RETURN_IF_ERROR(CheckInputsWeights(
3956         *params, {{"input", false}, {"input_min", true}, {"input_max", true}}));
3957   } else if (node_def.op() == "QuantizeAndDequantizeV3") {
3958     TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false},
3959                                                     {"input_min", true},
3960                                                     {"input_max", true},
3961                                                     {"num_bits", true}}));
3962   }
3963   float min_range = 0.0f;
3964   float max_range = 0.0f;
3965   if (node_def.op() == "FakeQuantWithMinMaxArgs") {
3966     // Get ranges via node attributes.
3967     TFAttrs attrs(node_def);
3968     if (attrs.count("min") == 0 || attrs.count("max") == 0) {
3969       return errors::InvalidArgument("Min or max attribute not found for ",
3970                                      node_def.op(), " at ", node_def.name());
3971     }
3972     min_range = attrs.get<float>("min");
3973     max_range = attrs.get<float>("max");
3974   } else if (node_def.op() == "FakeQuantWithMinMaxVars" ||
3975              node_def.op() == "QuantizeAndDequantizeV2" ||
3976              node_def.op() == "QuantizeAndDequantizeV3") {
3977     // Get ranges via inputs.
3978     auto get_weights_value = [&inputs](int index) {
3979       auto raw_weights =
3980           static_cast<float*>(inputs.at(index).weights().GetValues());
3981       return raw_weights[0];
3982     };
3983     min_range = get_weights_value(1);
3984     max_range = get_weights_value(2);
3985   } else {
3986     return errors::InvalidArgument("Unknown quantization op ", node_def.op(),
3987                                    ", at ", node_def.name());
3988   }
3989   if (params->validation_only) return Status::OK();
3990 
3991   // Store ranges for tensor
3992   params->converter->ProvideQuantizationRange(inputs.at(0).tensor(), min_range,
3993                                               max_range);
3994   // Sometimes, TRT may not quantize a tensor, either because it chooses to
3995   // execute a higher precision kernel or because of op fusion. In these cases,
3996   // accuracy will suffer if the model was trained to expect quantization at
3997   // that tensor. We should consider adding a clip(tensor, min_range, max_range)
3998   // operation here to ensure that any arbitrarily placed quantize node will
3999   // execute as expected. However, this will negatively affect performance. If
4000   // users train their models in a way which models inference as close as
4001   // possible (i.e. not quantizing in place where fusion will occur), then there
4002   // is no problem with the current implementation.
4003   params->outputs->push_back(inputs.at(0));
4004   return Status::OK();
4005 }
4006 
4007 Status ConvertRelu6(OpConverterParams* params) {
4008   const auto& inputs = params->inputs;
4009   const auto& node_def = params->node_def;
4010   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
4011   TF_RETURN_IF_ERROR(
4012       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4013   if (params->validation_only) return Status::OK();
4014 
4015 #if IS_TRT_VERSION_GE(5, 1, 2, 0)
4016   // Use IActivationLayer for TRT >= 5.1
4017   nvinfer1::IActivationLayer* layer =
4018       params->converter->network()->addActivation(
4019           *inputs.at(0).tensor(), nvinfer1::ActivationType::kCLIP);
4020   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4021   layer->setAlpha(0.0f);
4022   layer->setBeta(6.0f);
4023   params->converter->SetLayerName(layer, node_def, "activation");
4024   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
4025   params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f);
4026   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
4027   return Status::OK();
4028 #else
4029   // Convert using min(Relu(x), 6) before TRT 5.1
4030   // Input Tensor
4031   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
4032 
4033   // Relu operation i.e. Relu(x) = max(0, x)
4034   nvinfer1::IActivationLayer* relu_layer =
4035       params->converter->network()->addActivation(
4036           *tensor, nvinfer1::ActivationType::kRELU);
4037   TFTRT_RETURN_ERROR_IF_NULLPTR(relu_layer, node_def.name());
4038   params->converter->SetLayerName(relu_layer, node_def, "activation");
4039 
4040   // Large range of relu is problematic during quantization in INT8 precision
4041   // mode. Setting dynamic range of relu = [0.f, 6.0f] helps with quantization.
4042   // TRT only uses dynamic ranges in INT8 precision mode,
4043   // and this does not affect the FP32 path.
4044   params->converter->ProvideQuantizationRange(relu_layer->getOutput(0), 0.0f,
4045                                               6.0f);
4046 
4047   // Create a constant layer to store the floating point weight i.e. 6.0f
4048   nvinfer1::ITensor* const6_tensor = nullptr;
4049   TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant(
4050       params, 6.0f, relu_layer->getOutput(0)->getDimensions(), &const6_tensor));
4051 
4052   // ElementWise Min Operation
4053   // Min op is a nop for INT8 execution path, as the input tensor
4054   // to this layer will only have values in range [0.f, 6.0f].
4055   nvinfer1::IElementWiseLayer* relu6_layer =
4056       params->converter->network()->addElementWise(
4057           *relu_layer->getOutput(0), *const6_tensor,
4058           nvinfer1::ElementWiseOperation::kMIN);
4059   TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name());
4060   params->converter->SetLayerName(relu6_layer, node_def, "min");
4061   nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0);
4062   params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f);
4063 
4064   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
4065   return Status::OK();
4066 #endif
4067 }
4068 
4069 Status ConvertBiasAddInt8WithoutCalibration(OpConverterParams* params) {
4070   const auto& inputs = params->inputs;
4071   const auto& node_def = params->node_def;
4072   TF_RETURN_IF_ERROR(
4073       CheckInputsWeights(*params, {{"value", false}, {"bias", true}}));
4074   TF_RETURN_IF_ERROR(
4075       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4076   if (params->validation_only) return Status::OK();
4077 
4078   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
4079   const nvinfer1::Dims original_dims = tensor->getDimensions();
4080   TFAttrs attrs(node_def);
4081   const string data_format = attrs.get<string>("data_format");
4082   const int channel_index =
4083       (data_format == "NHWC" ? original_dims.nbDims - 1 : 0);
4084 
4085   nvinfer1::Permutation permutation;
4086   if (channel_index != 0) {
4087     // Permute the dimensions so that the channel dimension is the first
4088     // dimension.
4089     for (int i = 0; i < original_dims.nbDims; ++i) {
4090       permutation.order[i] = i;
4091     }
4092     permutation.order[0] = channel_index;
4093     permutation.order[channel_index] = 0;
4094     VLOG(1) << "ConvertBiasAdd permutation: "
4095             << DebugString(permutation, original_dims.nbDims);
4096   }
4097 
4098   // TensorRT addScale requires the input to have rank 3, so we need to apply
4099   // a transpose as well as a reshape.
4100   // TODO(laigd): this doesn't match what the TRT doc says, fix the doc?
4101   if (channel_index != 0 || original_dims.nbDims != 3) {
4102     nvinfer1::IShuffleLayer* shuffle_layer =
4103         params->converter->network()->addShuffle(*tensor);
4104     TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name());
4105     params->converter->SetLayerName(shuffle_layer, node_def, "shuffle",
4106                                     /*op_instance=*/0);
4107     params->converter->MarkQuantizationRangesAsInferrable(
4108         tensor, shuffle_layer->getOutput(0));
4109 
4110     // NOTE(laigd): for some reason we need to apply the reshape
4111     // unconditionally. The default shape has nbDims==-1 and it seems the
4112     // behavior is undefined in some cases.
4113     nvinfer1::Dims reshape_dims;
4114     reshape_dims.nbDims = 3;
4115     // 0 means copying from input; -1 means inferring from the rest.
4116     reshape_dims.d[0] = 0;
4117     reshape_dims.d[1] = original_dims.nbDims >= 2 ? 0 : 1;
4118     reshape_dims.d[2] = original_dims.nbDims >= 3 ? -1 : 1;
4119     shuffle_layer->setReshapeDimensions(reshape_dims);
4120 
4121     if (channel_index != 0) {
4122       shuffle_layer->setFirstTranspose(permutation);
4123     }
4124     tensor = shuffle_layer->getOutput(0);
4125   }
4126 
4127   TRT_ShapedWeights weights = inputs.at(1).weights();
4128   nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL;
4129   if (weights.shape_.d[0] == 1) {
4130     mode = nvinfer1::ScaleMode::kUNIFORM;
4131   }
4132 
4133   TRT_ShapedWeights empty_weights(weights.TrtDType());
4134   nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
4135       *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(),
4136       empty_weights.GetTrtWeights());
4137   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4138   params->converter->SetLayerName(layer, node_def, "scale");
4139 
4140   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
4141 
4142   // Restore transpose & reshape.
4143   if (channel_index != 0 || original_dims.nbDims != 3) {
4144     nvinfer1::IShuffleLayer* shuffle_layer =
4145         params->converter->network()->addShuffle(*output_tensor);
4146     TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name());
4147     params->converter->SetLayerName(shuffle_layer, node_def, "shuffle",
4148                                     /*op_instance=*/1);
4149     // NOTE: for the same reason as mentioned above, we need to apply the
4150     // reshape unconditionally.
4151     nvinfer1::Dims reshape_dims = original_dims;
4152     if (channel_index != 0) {
4153       // NOTE: according to NVIDIA, dimension types are deprecated, so we don't
4154       // need to copy them back.
4155       reshape_dims.d[channel_index] = original_dims.d[0];
4156       reshape_dims.d[0] = original_dims.d[channel_index];
4157     }
4158     shuffle_layer->setReshapeDimensions(reshape_dims);
4159 
4160     if (channel_index != 0) {
4161       shuffle_layer->setSecondTranspose(permutation);
4162     }
4163     params->converter->MarkQuantizationRangesAsInferrable(
4164         output_tensor, shuffle_layer->getOutput(0));
4165     output_tensor = shuffle_layer->getOutput(0);
4166   }
4167 
4168   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
4169   return Status::OK();
4170 }
4171 
4172 Status ConvertBiasAdd(OpConverterParams* params) {
4173   if (params->precision_mode == TrtPrecisionMode::INT8 &&
4174       !params->use_calibration) {
4175     // NOTE(laigd): based on some observation, it seems TensorRT cannot fuse
4176     // IConvolutionLayer and IElementwiseLayer and will require range
4177     // information for the output of Conv2D. Using IScaleLayer will fix the
4178     // problem.
4179     return ConvertBiasAddInt8WithoutCalibration(params);
4180   }
4181   const auto& inputs = params->inputs;
4182   const auto& node_def = params->node_def;
4183 
4184   if (inputs.size() != 2) {
4185     return errors::InvalidArgument(
4186         "BiasAdd expects exactly 2 inputs, but received ", inputs.size());
4187   }
4188 
4189   if (inputs[0].is_weights() && inputs[1].is_weights()) {
4190     return errors::InvalidArgument(
4191         "All inputs are weights, but Grappler is expected to fold them.");
4192   }
4193 
4194   TF_RETURN_IF_ERROR(
4195       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4196 
4197   TFAttrs attrs(node_def);
4198   const string& data_format = attrs.get<string>("data_format");
4199 
4200   nvinfer1::Dims input_shape = inputs.at(0).GetTrtDims();
4201   nvinfer1::Dims bias_shape = inputs.at(1).GetTrtDims();
4202   // The bias input arg is a 1-D tensor with length C. If the input is NCHW,
4203   // then we need to unsqueeze the bias such that its shape is [1, C, 1, 1].
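  // For example (illustrative): an NCHW input with C = 16 has a bias of shape
  // [16]; in implicit batch mode it is reshaped to [16, 1, 1] to match the
  // [C, H, W] tensor dims, and in explicit batch mode to [1, 16, 1, 1].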
4204   if (data_format == "NCHW") {
4205     if (params->use_implicit_batch) {
4206       // The batch dim is not included in implicit batch mode, so the shape of
4207       // the bias tensor is [C, 1, 1].
4208       bias_shape.nbDims = input_shape.nbDims;
4209       std::fill(bias_shape.d + 1, bias_shape.d + bias_shape.nbDims, 1);
4210     } else {
4211       // In explicit batch mode we create a tensor with shape [1, C, 1, 1].
4212       std::vector<int> bias_shape_vec(bias_shape.d,
4213                                       bias_shape.d + bias_shape.nbDims);
4214       // Insert a 1 at the front for the batch dim.
4215       bias_shape_vec.insert(bias_shape_vec.begin(), 1);
4216       // Trail with 1s to match input_shape size
4217       bias_shape_vec.insert(bias_shape_vec.end(),
4218                             input_shape.nbDims - bias_shape_vec.size(), 1);
4219       TF_RETURN_IF_ERROR(
4220           TensorShapeArrayToTrtDims(bias_shape_vec, &bias_shape));
4221     }
4222   } else {
4223     // Next, broadcast the bias across the input.
4224     TF_RETURN_IF_ERROR(GetTrtBroadcastShape(inputs.at(0), inputs.at(1),
4225                                             /*check_feasibility=*/true,
4226                                             params->use_implicit_batch,
4227                                             &input_shape, &bias_shape));
4228   }
4229 
4230   // Convert input to a TRT tensor
4231   nvinfer1::ITensor* input_tensor{nullptr};
4232   TF_RETURN_IF_ERROR(PrepareTensorForShape(params->converter, inputs.at(0),
4233                                            input_shape, params->validation_only,
4234                                            &input_tensor, node_def,
4235                                            /*op_instance=*/0));
4236 
4237   // Finally, reshape bias. Since the bias is usually a constant, this will
4238   // normally happen at conversion-time.
4239   nvinfer1::ITensor* bias_tensor{nullptr};
4240   TF_RETURN_IF_ERROR(PrepareTensorForShape(params->converter, inputs.at(1),
4241                                            bias_shape, params->validation_only,
4242                                            &bias_tensor, node_def,
4243                                            /*op_instance=*/1));
4244   VLOG(2) << "Bias shape adjusted to " << DebugString(bias_shape);
4245 
4246   if (params->validation_only) return Status::OK();
4247 
4248   nvinfer1::IElementWiseLayer* layer =
4249       params->converter->network()->addElementWise(
4250           *input_tensor, *bias_tensor, nvinfer1::ElementWiseOperation::kSUM);
4251   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4252   params->converter->SetLayerName(layer, node_def, "sum");
4253   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
4254 
4255   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
4256   return Status::OK();
4257 }
4258 
4259 void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) {
4260   if (tensor.dims() > 0) {
4261     *dims = GetTrtDimsForTensor(tensor);
4262   } else {
4263     dims->nbDims = 1;
4264     // No dimension provided. Flatten it.
4265     dims->d[0] = tensor.NumElements();
4266     dims->type[0] = nvinfer1::DimensionType::kSPATIAL;
4267     for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; ++i) {
4268       dims->d[i] = 0;
4269     }
4270   }
4271 }
4272 
4273 template <typename Input>
4274 inline bool IsIntegerInInt32Bounds(const Input& inp) {
4275   static_assert(std::is_integral<Input>::value,
4276                 "This function is only implemented for integral types.");
4277   // If Input is always within the range of int32, return true.
4278   if (sizeof(Input) < sizeof(int32) || std::is_same<Input, int32>::value) {
4279     return true;
4280   }
4281   // Otherwise, we need to check the value of the input. If the input is
4282   // unsigned, we only check the upper bound.
4283   if (!std::numeric_limits<Input>::is_signed) {
4284     return inp <= static_cast<Input>(std::numeric_limits<int32>::max());
4285   }
4286   // We can safely cast lowest() here since we now know that Input is signed and
4287   // sizeof(Input) >= sizeof(int32)
4288   return (inp >= static_cast<Input>(std::numeric_limits<int32>::lowest()) &&
4289           inp <= static_cast<Input>(std::numeric_limits<int32>::max()));
4290 }
4291 
4292 template <DataType dtype>
4293 Status CopyToTrtInt32Array(const Tensor& tensor, int32* dst) {
4294   typedef typename EnumToDataType<dtype>::Type CType;
4295   const CType* src = tensor.flat<CType>().data();
4296   for (int i = 0; i < tensor.NumElements(); ++i) {
4297     // This becomes a no-op if CType is within bounds of int32
4298     if (!IsIntegerInInt32Bounds(src[i])) {
4299       return errors::InvalidArgument("Value at index ", i,
4300                                      " is outside the range of int32");
4301     }
4302     dst[i] = static_cast<int32>(src[i]);
4303   }
4304   return Status::OK();
4305 }
4306 
4307 Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
4308                             TRT_ShapedWeights* weights) {
4309   const DataType dtype = tensor.dtype();
4310 
4311   // We always convert the integer constants to INT32.
4312   //
4313   // TODO(aaroey): FP16 will remain in half format and is not converted to
4314   // FP32, but the converter currently uses all float weights as FP32. Fix
4315   // this.
4316   DataType converted_dtype = DataTypeIsInteger(dtype) ? DT_INT32 : dtype;
4317 
4318   // Verify that the dtype is supported by TensorRT. Otherwise, return an error.
4319   nvinfer1::DataType trt_dtype;
4320   TF_RETURN_IF_ERROR(TfTypeToTrtType(converted_dtype, &trt_dtype));
4321 
4322   if (tensor.NumElements() == 0) {
4323     // Return empty weights.
4324     *weights = TRT_ShapedWeights(trt_dtype);
4325     return Status::OK();
4326   }
4327 
4328   nvinfer1::Dims weight_dims;
4329   GetTensorDimsWithProtoShape(tensor, &weight_dims);
4330   *weights = weight_store->GetTempWeights(trt_dtype, weight_dims);
4331 
4332   // Copy the tensor directly if the tensor does not require cast to the
4333   // supported type.
4334   if (converted_dtype == dtype) {
4335     char* dst = static_cast<char*>(weights->GetValues());
4336     memcpy(dst, tensor.tensor_data().data(), tensor.TotalBytes());
4337     return Status::OK();
4338   }
4339 
4340   Status status = Status::OK();
4341   // Copy tensor elements after casting them to the converted DataType.
4342   int32* dst = static_cast<int32*>(weights->GetValues());
4343   switch (dtype) {
4344     case DT_INT8:
4345       status = CopyToTrtInt32Array<DT_INT8>(tensor, dst);
4346       break;
4347     case DT_UINT8:
4348       status = CopyToTrtInt32Array<DT_UINT8>(tensor, dst);
4349       break;
4350     case DT_INT16:
4351       status = CopyToTrtInt32Array<DT_INT16>(tensor, dst);
4352       break;
4353     case DT_UINT16:
4354       status = CopyToTrtInt32Array<DT_UINT16>(tensor, dst);
4355       break;
4356     case DT_UINT32:
4357       status = CopyToTrtInt32Array<DT_UINT32>(tensor, dst);
4358       break;
4359     case DT_INT64:
4360       status = CopyToTrtInt32Array<DT_INT64>(tensor, dst);
4361       break;
4362     case DT_UINT64:
4363       status = CopyToTrtInt32Array<DT_UINT64>(tensor, dst);
4364       break;
4365     default:
4366       return errors::Internal("Unexpected DataType: ", DataTypeString(dtype));
4367   }
4368   return status;
4369 }
4370 
4371 // Convert a Const NodeDef to TRT_ShapedWeights. This is a special converter: it
4372 // always ignores the params->validation_only parameter and adds the converted
4373 // weights to params->outputs. We do this because TrtNodeValidator needs the
4374 // weights as inputs to other nodes, and uses them to determine whether those
4375 // nodes are supported by TRT.
4376 Status ConvertConst(OpConverterParams* params) {
4377   const auto& inputs = params->inputs;
4378   const auto& node_def = params->node_def;
4379   if (!inputs.empty()) {
4380     return errors::InvalidArgument(
4381         "Constant node is expected to have empty input list: ",
4382         node_def.name());
4383   }
4384 
4385   // Create shaped weights as output
4386   const auto& tensor_proto = node_def.attr().at("value").tensor();
4387   Tensor tensor;
4388   if (!tensor.FromProto(tensor_proto)) {
4389     return errors::Internal("Cannot parse weight tensor proto: ",
4390                             node_def.name());
4391   }
4392 
4393   TFAttrs attrs(node_def);
4394   const DataType dtype = attrs.get<DataType>("dtype");
4395   if (dtype != tensor.dtype()) {
4396     return errors::InvalidArgument("DataType mismatch between attr (",
4397                                    DataTypeString(dtype), ") and tensor (",
4398                                    DataTypeString(tensor.dtype()), ")");
4399   }
4400 
4401   TRT_ShapedWeights weights;
4402   TF_RETURN_IF_ERROR(
4403       TfTensorToTrtWeights(tensor, params->weight_store, &weights));
4404 
4405   if (params->outputs != nullptr) {
4406     params->outputs->push_back(TRT_TensorOrWeights(weights));
4407   }
4408   return Status::OK();
4409 }
4410 
4411 Status ConvertIdentity(OpConverterParams* params) {
4412   // TODO(tmorris): TRT's Identity layer does not get optimized away as of TRT
4413   // 5.0; once we know that it does, it would be nice to use that layer
4414   // instead.
4415   if (params->validation_only) return Status::OK();
4416   params->outputs->push_back(params->inputs.at(0));
4417   return Status::OK();
4418 }
4419 
4420 const std::unordered_map<string, nvinfer1::ElementWiseOperation>*
4421 BinaryOperationMap() {
4422   static auto* const m =
4423       new std::unordered_map<string, nvinfer1::ElementWiseOperation> {
4424     {"Add", nvinfer1::ElementWiseOperation::kSUM},
4425         {"AddV2", nvinfer1::ElementWiseOperation::kSUM},
4426         {"Mul", nvinfer1::ElementWiseOperation::kPROD},
4427         {"Sub", nvinfer1::ElementWiseOperation::kSUB},
4428         {"Div", nvinfer1::ElementWiseOperation::kDIV},
4429 #if IS_TRT_VERSION_GE(6, 0, 1, 0)
4430         // Use TensorRT native FloorDiv.
4431         {"FloorDiv", nvinfer1::ElementWiseOperation::kFLOOR_DIV},
4432 #elif IS_TRT_VERSION_GE(5, 1, 0, 0)
4433         // Emulate FloorDiv by doing Div then Floor.
4434         {"FloorDiv", nvinfer1::ElementWiseOperation::kDIV},
4435 #endif
4436         {"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
4437         {"Minimum", nvinfer1::ElementWiseOperation::kMIN},
4438         {"Maximum", nvinfer1::ElementWiseOperation::kMAX},
4439         {"Pow", nvinfer1::ElementWiseOperation::kPOW},
4440   };
4441   return m;
4442 }
4443 
4444 Status ConvertBinary(OpConverterParams* params) {
4445   const auto& inputs = params->inputs;
4446   const auto& node_def = params->node_def;
4447   if (inputs.size() != 2) {
4448     return errors::InvalidArgument(node_def.op(), " got ", inputs.size(),
4449                                    " inputs but expected 2, at ",
4450                                    node_def.name());
4451   }
4452 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
4453   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF,
4454                                    DataType::DT_INT32};
4455 #else
4456   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF};
4457 #endif
4458   TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types));
4459 
4460   // Constant folding should have been done by TensorFlow
4461   if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
4462     return errors::Unimplemented(
4463         "Constant folding is falled back to TensorFlow, binary op received "
4464         "both input as constant at: ",
4465         node_def.name());
4466   }
4467   const TRT_TensorOrWeights& operand_l = inputs.at(0);
4468   const TRT_TensorOrWeights& operand_r = inputs.at(1);
4469 
4470   auto op_pair = BinaryOperationMap()->find(node_def.op());
4471   if (op_pair == BinaryOperationMap()->end()) {
4472     return errors::Unimplemented("Binary op ", node_def.op(),
4473                                  " not supported at: ", node_def.name());
4474   }
4475 
4476   nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
4477   TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
4478       operand_l, operand_r, /*check_feasibility=*/true,
4479       params->use_implicit_batch, &broadcasted_dims_l, &broadcasted_dims_r));
4480   nvinfer1::ITensor* tensor_l = nullptr;
4481   nvinfer1::ITensor* tensor_r = nullptr;
4482   // This will also convert constants to tensors, and set quantization ranges.
4483   TF_RETURN_IF_ERROR(PrepareTensorForShape(
4484       params->converter, operand_l, broadcasted_dims_l, params->validation_only,
4485       &tensor_l, node_def, /*op_instance=*/0));
4486   TF_RETURN_IF_ERROR(PrepareTensorForShape(
4487       params->converter, operand_r, broadcasted_dims_r, params->validation_only,
4488       &tensor_r, node_def, /*op_instance=*/1));
4489   if (params->validation_only) return Status::OK();
4490 
4491   // Add ElementWise layer.
4492   nvinfer1::ILayer* layer = params->converter->network()->addElementWise(
4493       *tensor_l, *tensor_r, op_pair->second);
4494   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4495   params->converter->SetLayerName(layer, node_def);
4496   nvinfer1::ITensor* trt_tensor = layer->getOutput(0);
4497 
4498 #if IS_TRT_VERSION_GE(5, 1, 0, 0) and !IS_TRT_VERSION_GE(6, 0, 1, 0)
4499   if (node_def.op() == "FloorDiv") {
4500     layer = params->converter->network()->addUnary(
4501         *trt_tensor, nvinfer1::UnaryOperation::kFLOOR);
4502     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4503     params->converter->SetLayerName(layer, node_def, "floor");
4504     trt_tensor = layer->getOutput(0);
4505   }
4506 #endif
4507   params->outputs->push_back(TRT_TensorOrWeights(trt_tensor));
4508   return Status::OK();
4509 }
4510 
4511 Status ConvertRsqrt(OpConverterParams* params) {
4512   const auto& inputs = params->inputs;
4513   const auto& node_def = params->node_def;
4514   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
4515   TF_RETURN_IF_ERROR(
4516       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4517   if (params->validation_only) return Status::OK();
4518 
4519   // TODO(tmorris): params->converter is null during validation. Allow
4520   // precision_mode and use_calibration to be accessed during validation and
4521   // include this check in validation.
4522   // We will need a quantization range for intermediate tensor if not using
4523   // calibration.
4524   //
4525   //   x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x)
4526   //                     ^
4527   //               need range here
4528   if (params->converter->precision_mode() == TrtPrecisionMode::INT8 &&
4529       !params->converter->use_calibration()) {
4530     return errors::Unimplemented(
4531         "Intermediate quantization range cannot be determined without"
4532         " calibration for Rsqrt, consider replacing with "
4533         "Sqrt -> FakeQuant -> Reciprocal ops, at ",
4534         node_def.name());
4535   }
4536   // Start conversion.
4537   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
4538   // Sqrt
4539   nvinfer1::IUnaryLayer* sqrt_layer = params->converter->network()->addUnary(
4540       *tensor, nvinfer1::UnaryOperation::kSQRT);
4541   TFTRT_RETURN_ERROR_IF_NULLPTR(sqrt_layer, node_def.name());
4542   params->converter->SetLayerName(sqrt_layer, node_def, "sqrt");
4543   // Recip
4544   nvinfer1::IUnaryLayer* recip_layer = params->converter->network()->addUnary(
4545       *sqrt_layer->getOutput(0), nvinfer1::UnaryOperation::kRECIP);
4546   TFTRT_RETURN_ERROR_IF_NULLPTR(recip_layer, node_def.name());
4547   params->converter->SetLayerName(recip_layer, node_def, "recip");
4548   params->outputs->push_back(TRT_TensorOrWeights(recip_layer->getOutput(0)));
4549   return Status::OK();
4550 }
4551 
4552 const std::unordered_map<string, nvinfer1::UnaryOperation>*
4553 UnaryOperationMap() {
4554   static auto* const m =
4555       new std::unordered_map<string, nvinfer1::UnaryOperation>({
4556         {"Neg", nvinfer1::UnaryOperation::kNEG},
4557             {"Exp", nvinfer1::UnaryOperation::kEXP},
4558             {"Log", nvinfer1::UnaryOperation::kLOG},
4559             {"Sqrt", nvinfer1::UnaryOperation::kSQRT},
4560             {"Abs", nvinfer1::UnaryOperation::kABS},
4561             {"Reciprocal", nvinfer1::UnaryOperation::kRECIP},
4562 #if IS_TRT_VERSION_GE(5, 1, 0, 0)
4563             {"Sin", nvinfer1::UnaryOperation::kSIN},
4564             {"Cos", nvinfer1::UnaryOperation::kCOS},
4565             {"Tan", nvinfer1::UnaryOperation::kTAN},
4566             {"Sinh", nvinfer1::UnaryOperation::kSINH},
4567             {"Cosh", nvinfer1::UnaryOperation::kCOSH},
4568             {"Asin", nvinfer1::UnaryOperation::kASIN},
4569             {"Acos", nvinfer1::UnaryOperation::kACOS},
4570             {"Atan", nvinfer1::UnaryOperation::kATAN},
4571             {"Asinh", nvinfer1::UnaryOperation::kASINH},
4572             {"Acosh", nvinfer1::UnaryOperation::kACOSH},
4573             {"Atanh", nvinfer1::UnaryOperation::kATANH},
4574             {"Ceil", nvinfer1::UnaryOperation::kCEIL},
4575             {"Floor", nvinfer1::UnaryOperation::kFLOOR},
4576 #endif
4577       });
4578   return m;
4579 }
4580 
4581 Status ConvertUnary(OpConverterParams* params) {
4582   const auto& inputs = params->inputs;
4583   const auto& node_def = params->node_def;
4584   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
4585   TF_RETURN_IF_ERROR(
4586       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4587   auto op_pair = UnaryOperationMap()->find(node_def.op());
4588   if (op_pair == UnaryOperationMap()->end()) {
4589     return errors::Unimplemented("Unary op: ", node_def.op(),
4590                                  " not supported at: ", node_def.name());
4591   }
4592   if (params->validation_only) return Status::OK();
4593 
4594   // Start conversion.
4595   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
4596   nvinfer1::IUnaryLayer* layer =
4597       params->converter->network()->addUnary(*tensor, op_pair->second);
4598   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4599   params->converter->SetLayerName(layer, node_def);
4600   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
4601 
4602   // Set quantization ranges.
4603   if (node_def.op() == "Sin" || node_def.op() == "Cos") {
4604     params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f);
4605   } else if (node_def.op() == "Asin" || node_def.op() == "Atan") {
4606     params->converter->ProvideQuantizationRange(output_tensor, -M_PI_2, M_PI_2);
4607   } else if (node_def.op() == "Acos") {
4608     params->converter->ProvideQuantizationRange(output_tensor, 0.0f, M_PI);
4609   } else if (node_def.op() == "Neg" || node_def.op() == "Abs") {
4610     // Neg and Abs will have same range as input since TRT uses symmetric
4611     // quantization.
4612     // TODO(tmorris): Should we infer ranges for Ceil and Floor as well?
4613     params->converter->MarkQuantizationRangesAsInferrable(tensor,
4614                                                           output_tensor);
4615   }
4616   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
4617   return Status::OK();
4618 }
4619 
4620 Status ConvertSquare(OpConverterParams* params) {
4621   const auto& inputs = params->inputs;
4622   const auto& node_def = params->node_def;
4623   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
4624 #if IS_TRT_VERSION_GE(6, 0, 1, 0)
4625   TF_RETURN_IF_ERROR(AllowDataTypes(
4626       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
4627 #else
4628   TF_RETURN_IF_ERROR(
4629       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4630 #endif
4631   if (params->validation_only) return Status::OK();
4632 
4633   // Constant 2 with same rank as input
4634   nvinfer1::ITensor* const2_tensor = nullptr;
4635   TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant(
4636       params, 2.0f, inputs.at(0).GetTrtDims(), &const2_tensor));
4637 
4638   // ElementWise Pow Operation
4639   nvinfer1::IElementWiseLayer* layer =
4640       params->converter->network()->addElementWise(
4641           *inputs.at(0).tensor(), *const2_tensor,
4642           nvinfer1::ElementWiseOperation::kPOW);
4643   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4644   params->converter->SetLayerName(layer, node_def);
4645   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
4646 
4647   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
4648   return Status::OK();
4649 }
4650 
4651 Status ConvertReduce(OpConverterParams* params) {
4652   const auto& inputs = params->inputs;
4653   const auto& node_def = params->node_def;
4654   TF_RETURN_IF_ERROR(
4655       CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
4656   TF_RETURN_IF_ERROR(
4657       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4658 
4659   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
4660   auto tf_axes_list = inputs.at(1).weights().GetSpan<int>();
4661 
4662   TFAttrs attrs(node_def);
4663   // Only expect to handle INT32 as attributes for now
4664   if (attrs.get<DataType>("Tidx") != DataType::DT_INT32) {
4665     return errors::Unimplemented("Tidx supports only DT_INT32");
4666   }
4667 
4668   int axes = 0;
4669   if (tf_axes_list.size() == 0) {
4670     return errors::InvalidArgument(
4671         "TRT cannot support reduce on all (batch) dimensions, at",
4672         node_def.name());
4673   }
4674   for (int i = 0; i < tf_axes_list.size(); i++) {
4675     int trt_axis;
4676     TF_RETURN_IF_ERROR(
4677         ConvertAxis(tf_axes_list[i], tensor->getDimensions().nbDims,
4678                     node_def.name(), params->use_implicit_batch, &trt_axis));
4679     axes |= (1 << trt_axis);
4680   }
4681 
4682   nvinfer1::ReduceOperation reduce_operation;
4683   if (node_def.op() == "Sum") {
4684     reduce_operation = nvinfer1::ReduceOperation::kSUM;
4685   } else if (node_def.op() == "Prod") {
4686     reduce_operation = nvinfer1::ReduceOperation::kPROD;
4687   } else if (node_def.op() == "Max") {
4688     reduce_operation = nvinfer1::ReduceOperation::kMAX;
4689   } else if (node_def.op() == "Min") {
4690     reduce_operation = nvinfer1::ReduceOperation::kMIN;
4691   } else if (node_def.op() == "Mean") {
4692     reduce_operation = nvinfer1::ReduceOperation::kAVG;
4693   } else {
4694     return errors::Unimplemented("Op not supported ", node_def.op(), ", at ",
4695                                  node_def.name());
4696   }
4697   if (params->validation_only) return Status::OK();
4698 
4699   const auto keep_dims = attrs.get<bool>("keep_dims");
4700   nvinfer1::ILayer* layer = params->converter->network()->addReduce(
4701       *tensor, reduce_operation, axes, keep_dims);
4702   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4703   params->converter->SetLayerName(layer, node_def);
4704 
4705   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
4706   return Status::OK();
4707 }
4708 
4709 // TensorRT does not support the Pack op natively. Therefore, Pack op is
4710 // converted by first expanding input tensors by adding a new dimension of size
4711 // one at the specified axis and then concatenating the tensors at the same
4712 // axis.
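// For example (illustrative): packing three tensors of shape [8, 16] at
// axis 1 first expands each to [8, 1, 16] and then concatenates them along
// that axis into a single [8, 3, 16] tensor.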
4713 Status ConvertPack(OpConverterParams* params) {
4714   const auto& inputs = params->inputs;
4715   const auto& node_def = params->node_def;
4716 
4717   TFAttrs attrs(node_def);
4718   const int num_inputs = attrs.get<int64>("N");
4719   if (num_inputs != inputs.size()) {
4720     return errors::InvalidArgument(
4721         "Number of inputs for Pack is inconsistent with N attribute, at ",
4722         node_def.name());
4723   }
4724 
4725   // In implicit batch mode we do not allow weight input. An input tensor with
4726   // dims NCHW is represented with dims CHW during conversion time, and N is
4727   // defined only during runtime. A weight is represented with dims NCHW. We
4728   // cannot be sure that the runtime N will agree with the conversion time N,
4729   // therefore we do not convert the pack op if it has both tensor and weight
4730   // inputs. This restriction does not apply in explicit batch mode, where the
4731   // input tensors are also represented with full dims that include the batch
4732   // size.
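  // For example (illustrative): a constant input of shape [2, 8] keeps its
  // batch dimension of 2 at conversion time, while a tensor input with runtime
  // shape [2, 8] is represented as just [8] in implicit batch mode, so the two
  // cannot be packed consistently.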
4733   TrtInputArg expected_arg =
4734       params->use_implicit_batch ? TrtInputArg::kTensor : TrtInputArg::kBoth;
4735 
4736   std::vector<std::pair<string, TrtInputArg>> inputs_is_weight;
4737   inputs_is_weight.reserve(num_inputs);
4738   for (int i = 0; i < num_inputs; ++i) {
4739     inputs_is_weight.push_back({StrCat("values_", i), expected_arg});
4740   }
4741   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, inputs_is_weight));
4742 
4743 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
4744   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF,
4745                                    DataType::DT_INT32};
4746 #else
4747   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF};
4748 #endif
4749   TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types));
4750   if (num_inputs > 1) {
4751     // Verify that inputs are compatible for concatenation after the expansion.
4752     TF_RETURN_IF_ERROR(
4753         VerifyShapesMatch(inputs, /*masked_dim=*/-1, node_def.name()));
4754   }
4755 
4756   // Find the dimensions of the inputs. In general inputs can have dynamic
4757   // shape; in that case we have to use DynamicExpandDims to calculate the
4758   // expanded dimensions. To avoid that, we try to find an input with a known
4759   // static shape (weights always have one).
4760   int idx = 0;
4761   for (int i = 1; i < inputs.size(); i++) {
4762     if (HasStaticShape(inputs.at(i).GetTrtDims())) {
4763       idx = i;
4764     }
4765   }
4766   const nvinfer1::Dims dims = inputs.at(idx).GetTrtDims();
4767   // Convert axis from the TensorFlow format to TensorRT format.
4768   const int64 tf_axis = attrs.get<int64>("axis");
4769   int trt_axis;
4770   TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(),
4771                                  params->use_implicit_batch, &trt_axis));
4772 
4773   // Compute expanded dimensions and then reshape input tensors.
4774   std::vector<int> tensor_dims(dims.d, dims.d + dims.nbDims);
4775   tensor_dims.insert(tensor_dims.begin() + trt_axis, 1);
4776   nvinfer1::Dims expanded_dims;
4777   TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(tensor_dims, &expanded_dims));
4778   std::vector<nvinfer1::ITensor*> expanded_tensors;
4779   int input_index = 0;
4780   for (const TRT_TensorOrWeights& input : inputs) {
4781     nvinfer1::ITensor* expanded_tensor = nullptr;
4782     if (input.is_tensor() && !params->use_implicit_batch &&
4783         !HasStaticShape(dims)) {
4784       if (!params->validation_only) {
4785         TF_RETURN_IF_ERROR(params->converter->DynamicExpandDims(
4786             input.tensor(), dims, trt_axis, params, &expanded_tensor,
4787             input_index));
4788       }
4789     } else {
4790       TF_RETURN_IF_ERROR(PrepareTensorForShape(
4791           params->converter, input, expanded_dims, params->validation_only,
4792           &expanded_tensor, node_def, input_index));
4793     }
4794     if (!params->validation_only) {
4795       expanded_tensors.push_back(expanded_tensor);
4796     }
4797     input_index++;
4798   }
4799   if (params->validation_only) return Status::OK();
4800 
4801   // If there is only one tensor in the input, return the expanded tensor.
4802   if (num_inputs == 1) {
4803     params->outputs->push_back(TRT_TensorOrWeights(expanded_tensors[0]));
4804     return Status::OK();
4805   }
4806 
4807   // Otherwise, concatenate expanded tensors.
4808   nvinfer1::IConcatenationLayer* layer =
4809       params->converter->network()->addConcatenation(
4810           const_cast<nvinfer1::ITensor**>(expanded_tensors.data()),
4811           expanded_tensors.size());
4812   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4813   params->converter->SetLayerName(layer, node_def, "concat");
4814   // Note that trt_axis stays the same even after expanding tensors at the axis.
4815   layer->setAxis(trt_axis);
4816   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
4817   return Status::OK();
4818 }
4819 
4820 Status ConvertPad(OpConverterParams* params) {
4821   const auto& inputs = params->inputs;
4822   const auto& node_def = params->node_def;
4823   TF_RETURN_IF_ERROR(
4824       CheckInputsWeights(*params, {{"tensor", false}, {"paddings", true}}));
4825   TF_RETURN_IF_ERROR(AllowDataTypes(
4826       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT8}));
4827 
4828   // Get the input tensor; the paddings are provided as weights.
4829   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
4830   const auto dims = tensor->getDimensions();
4831   // Restore implicit batch dimension
4832   const int nb_dims =
4833       params->use_implicit_batch ? dims.nbDims + 1 : dims.nbDims;
4834 
4835   // TODO(tfeher): Support nb_dims < 4 by inserting extra dimensions to the
4836   // original input.
4837   if (nb_dims < 4) {
4838     return errors::InvalidArgument("Convertpad requires at least 4D input, at ",
4839                                    node_def.name());
4840   }
4841   TRT_ShapedWeights pads = inputs.at(1).weights();
4842 
4843   TFAttrs attrs(node_def);
4844   // The padding dtype is read as a TF DataType so that EnumToDataType can
4845   // be leveraged for the cast.
4846   auto padding_type = attrs.get<DataType>("Tpaddings");
4847   // TODO(jie): handle data type conversion for TRT?
4848 
4849   if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) {
4850     return errors::InvalidArgument("Paddings at ", node_def.name(),
4851                                    " must be a weight with shape [n, 2], "
4852                                    "where n is the rank of input tensor");
4853   }
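  // For example, for a 4-D NCHW input, a paddings weight of
  // [[0, 0], [0, 0], [1, 1], [2, 2]] pads H by 1 and W by 2 on both sides.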
4854 
4855   // Only expect to handle INT32 as attributes for now
4856   if (padding_type != DataType::DT_INT32) {
4857     return errors::Unimplemented("Tpaddings supports only DT_INT32");
4858   }
4859   auto pad_data = static_cast<int*>(pads.GetValues());
4860 
4861   std::vector<int32_t> tf_pad_index;
4862   for (int i = 0; i < nb_dims; i++) {
4863     if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) {
4864       tf_pad_index.push_back(i);
4865     }
4866   }
4867 
4868   // No padding at all, we should exit
4869   if (tf_pad_index.empty()) {
4870     params->outputs->push_back(inputs.at(0));
4871     return Status::OK();
4872   }
4873 
4874   // TRT pad layer can only support padding on up to 2 dimensions (TRT-2579).
4875   // TODO(tfeher): Use multiple TRT pad layers to support padding on more than 2
4876   // dimensions.
4877   if (tf_pad_index.size() > 2) {
4878     return errors::InvalidArgument(
4879         "Padding layer does not support padding on > 2");
4880   }
4881 
4882   // Padding on batch dimension is not supported
4883   if (params->use_implicit_batch && tf_pad_index[0] == 0) {
4884     return errors::InvalidArgument(
4885         "Padding layer does not support padding on batch dimension");
4886   }
4887 
4888   if (params->validation_only) return Status::OK();
4889 
4890   // TRT can only do the padding at the last two dimensions. We transpose the
4891   // input tensor if needed.
4892   bool transposed_pad = false;
4893   std::vector<int> transpose_idx(nb_dims);
4894   std::iota(transpose_idx.begin(), transpose_idx.end(), 0);
4895 
4896   // trt_pad_index denotes the actual idx where the padding is performed by TRT.
4897   std::vector<int> trt_pad_index{nb_dims - 2, nb_dims - 1};
4898 
4899   // How many zeros are padded at the last two dimensions.
4900   nvinfer1::DimsHW pre_padding(0, 0);
4901   nvinfer1::DimsHW post_padding(0, 0);
4902 
4903   // Dimension to set in the pre_padding and post_padding array.
4904   std::vector<int> trt_pre_post_padding_index{0, 1};
4905 
4906   // Two special cases where we can avoid permutations.
4907   if (tf_pad_index.size() == 1 && tf_pad_index[0] == nb_dims - 1) {
4908     // Only one dimension needs to be padded. We store its index at
4909     // trt_pad_index[0]. We ignore trt_pad_index[1].
4910     trt_pad_index[0] = nb_dims - 1;
4911     trt_pre_post_padding_index[0] = 1;
4912   }
4913   if (tf_pad_index.size() == 2 && tf_pad_index[1] == nb_dims - 2) {
4914     // tf_pad_index only has two values that are in ascending order. If
4915     // tf_pad_index[1] is nb_dims-2, then swapping the two values in
4916     // trt_pad_index here makes it possible to only swap one pair of dimensions
4917     // (swap tf_pad_index[0] with nb_dims-1) in the input tensor. Otherwise, we
4918     // would have to swap two pairs of dimensions in the input tensor:
4919     // (tf_pad_index[0] with nb_dims-2) and (tf_pad_index[1], with nb_dims-1).
4920     // Here is an example for a 4D input tensor:
4921     // tf_pad_index = [1, 2]
4922     // trt_pad_index = [3, 2]
4923     // transpose_idx = [0, 3, 2, 1]
4924     std::swap(trt_pad_index[0], trt_pad_index[1]);
4925     std::swap(trt_pre_post_padding_index[0], trt_pre_post_padding_index[1]);
4926   }
4927 
4928   for (int i = 0; i < tf_pad_index.size(); i++) {
4929     const int tf_index = tf_pad_index[i];
4930     const int trt_index = trt_pad_index[i];
4931     const int k = trt_pre_post_padding_index[i];
4932     pre_padding.d[k] = pad_data[tf_index * 2];
4933     post_padding.d[k] = pad_data[tf_index * 2 + 1];
4934     if (tf_index != trt_index) {
4935       transposed_pad = true;
4936       std::swap(transpose_idx[tf_index], transpose_idx[trt_index]);
4937     }
4938   }
4939 
4940   if (transposed_pad) {
4941     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
4942         tensor, transpose_idx, &tensor, node_def, "to_pad"));
4943   }
4944 
4945   nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding(
4946       *tensor, pre_padding, post_padding);
4947   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4948   params->converter->SetLayerName(layer, node_def);
4949   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
4950   params->converter->MarkQuantizationRangesAsInferrable(tensor, output_tensor);
4951 
4952   if (transposed_pad) {
4953     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
4954         output_tensor, transpose_idx, &output_tensor, node_def, "from_pad"));
4955   }
4956 
4957   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
4958   return Status::OK();
4959 }
4960 
4961 Status ConvertSplitHelper(OpConverterParams* params,
4962                           const TRT_TensorOrWeights& input, int tf_axis,
4963                           int num_splits, bool squeeze_after) {
4964   const auto& node_def = params->node_def;
4965   const nvinfer1::Dims dims = input.GetTrtDims();
4966   // Convert axis.
4967   int trt_axis;
4968   TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
4969                                  params->use_implicit_batch, &trt_axis));
4970   // Dimension must equal num_splits for Unstack (when squeeze_after is true)
4971   if (squeeze_after && dims.d[trt_axis] != num_splits) {
4972     return errors::InvalidArgument(
4973         "Dimension ", tf_axis, " has size ", dims.d[trt_axis],
4974         " which is not equal to num of ", num_splits, ", at ", node_def.name());
4975   }
4976   // Dimension must be evenly divisible by num_splits.
4977   if (dims.d[trt_axis] % num_splits != 0) {
4978     return errors::InvalidArgument(
4979         "Dimension ", tf_axis, " of size ", dims.d[trt_axis],
4980         " is not evenly divisble by ", num_splits, ", at ", node_def.name());
4981   }
4982 
4983   // Create parameters for StridedSliceHelper.
4984   // Slice will begin on zero for all dims, except the one being split which
4985   // will change.
4986   std::vector<int> begin(dims.nbDims, 0);
4987   // Determine size of split. Slice will get the full length of all dims, except
4988   // the one being split.
4989   std::vector<int> size(dims.d, dims.d + dims.nbDims);
4990   const int split_size_on_axis = dims.d[trt_axis] / num_splits;
4991   size[trt_axis] = split_size_on_axis;
4992   // Stride will always be 1
4993   std::vector<int> stride(dims.nbDims, 1);
4994   // Add dummy batch dimension
4995   begin.insert(begin.begin(), 0);
4996   size.insert(size.begin(), 1);
4997   stride.insert(stride.begin(), 1);
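  // For example, splitting a tensor with TRT dims [6, 4] along trt_axis 0 into
  // num_splits=3 slices uses size {1, 2, 4} (including the dummy batch dim) and
  // begin {0, 0, 0}, {0, 2, 0}, {0, 4, 0} for the three slices.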
4998   // Create final shape for Unpack/Unstack, where split axis is squeezed.
4999   nvinfer1::Dims final_shape_for_unpack;
5000   nvinfer1::Dims* final_shape_for_unpack_ptr = nullptr;
5001   if (squeeze_after) {
5002     std::vector<int> size_after_squeeze(size);
5003     size_after_squeeze.erase(size_after_squeeze.begin() + trt_axis + 1);
5004     TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(
5005         size_after_squeeze, &final_shape_for_unpack, /*ignore_first_dim=*/true));
5006     final_shape_for_unpack_ptr = &final_shape_for_unpack;
5007   }
5008 
5009   // Slice the input. ConvertStridedSliceHelper will push the outputs onto
5010   // params->outputs.
5011   for (int i = 0; i < num_splits; ++i) {
5012     begin[trt_axis + 1] = i * split_size_on_axis;
5013     TF_RETURN_IF_ERROR(ConvertStridedSliceHelper(
5014         params, input, begin, size, stride, final_shape_for_unpack_ptr, i));
5015   }
5016   return Status::OK();
5017 }
5018 
5019 Status ConvertSplit(OpConverterParams* params) {
5020   const auto& inputs = params->inputs;
5021   const auto& node_def = params->node_def;
5022   TF_RETURN_IF_ERROR(
5023       CheckInputsWeights(*params, {{"axis", true}, {"value", false}}));
5024   TF_RETURN_IF_ERROR(AllowDataTypes(*params, {
5025     DataType::DT_FLOAT, DataType::DT_HALF,
5026 #if IS_TRT_VERSION_GE(5, 1, 3, 1)
5027         DataType::DT_INT32,
5028 #endif
5029   }));
5030   int tf_axis = inputs.at(0).weights().GetSpan<int>()[0];
5031   TFAttrs attrs(node_def);
5032   const int num_split = attrs.get<int64>("num_split");
5033 
5034   return ConvertSplitHelper(params, inputs.at(1), tf_axis, num_split, false);
5035 }
5036 
5037 Status ConvertUnpack(OpConverterParams* params) {
5038   const auto& inputs = params->inputs;
5039   const auto& node_def = params->node_def;
5040   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"value", false}}));
5041   TF_RETURN_IF_ERROR(AllowDataTypes(*params, {
5042     DataType::DT_FLOAT, DataType::DT_HALF,
5043 #if IS_TRT_VERSION_GE(5, 1, 3, 1)
5044         DataType::DT_INT32,
5045 #endif
5046   }));
5047   // The input must have at least one non-batch dimension; we can't unpack along the batch axis.
5048   if (inputs.at(0).GetTrtDims().nbDims == 0) {
5049     return errors::Unimplemented(
5050         "Input \"value\" for Unpack must be rank 2 or greater, at ",
5051         node_def.name());
5052   }
5053   TFAttrs attrs(node_def);
5054   const int tf_axis = attrs.get<int64>("axis");
5055   const int num = attrs.get<int64>("num");
5056 
5057   return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true);
5058 }
5059 
5060 // Supports cast fp16=>fp32 through IIdentityLayer.
5061 Status ConvertCast(OpConverterParams* params) {
5062   const NodeDef& node_def = params->node_def;
5063   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
5064   auto unsupport_cast_error = [&]() {
5065     return errors::Unimplemented("Cast op: ", node_def.op(),
5066                                  " not supported at: ", node_def.name());
5067   };
5068 
5069   DataType input_type;
5070   TF_RETURN_IF_ERROR(GetInputTfType(*params, &input_type, 0));
5071   if (input_type != DataType::DT_HALF) {
5072     return unsupport_cast_error();
5073   }
5074 
5075   DataType output_type;
5076   TF_RETURN_IF_ERROR(GetNodeDefTfType(params->node_def, &output_type,
5077                                       kCastOutputTypeAttrName));
5078 
5079   if (output_type != DataType::DT_FLOAT) {
5080     return unsupport_cast_error();
5081   }
5082 
5083   if (params->validation_only) return Status::OK();
5084 
5085   nvinfer1::ITensor* input = params->inputs.at(0).tensor();
5086   nvinfer1::IIdentityLayer* layer =
5087       params->converter->network()->addIdentity(*input);
5088   params->converter->SetLayerName(layer, node_def);
5089   layer->setPrecision(nvinfer1::DataType::kFLOAT);
5090 
5091   if (layer->getOutput(0)->getType() != nvinfer1::DataType::kFLOAT) {
5092     return errors::Internal("IIdentityLayer doesn't work as expected");
5093   }
5094 
5095   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
5096   return Status::OK();
5097 }
5098 
5099 Status ConvertConcat(OpConverterParams* params) {
5100   const auto& inputs = params->inputs;
5101   const auto& node_def = params->node_def;
5102   TFAttrs attrs(node_def);
5103   // Get number of tensor inputs.
5104   const int num_inputs = attrs.get<int64>("N");
5105   if (num_inputs != static_cast<int>(inputs.size()) - 1) {
5106     return errors::InvalidArgument(
5107         "Number of inputs for ConcatV2 is inconsistent with N attribute, at ",
5108         node_def.name());
5109   }
5110   // Validate inputs. Values must be tensors for now, although it would be
5111   // possible to accept weights in explicit batch mode. See CheckInputsWeights
5112   // for details. TODO(tfeher): Allow weight input in explicit batch mode.
5113   std::vector<std::pair<string, TrtInputArg>> inputs_kinds;
5114   TrtInputArg expected_input = TrtInputArg::kTensor;
5115   inputs_kinds.reserve(num_inputs);
5116   for (int i = 0; i < num_inputs; ++i) {
5117     inputs_kinds.push_back({StrCat("values_", i), expected_input});
5118   }
5119   inputs_kinds.push_back({"axis", TrtInputArg::kWeight});
5120   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, inputs_kinds));
5121 
5122 #if IS_TRT_VERSION_GE(7, 0, 0, 0)
5123   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF,
5124                                    DataType::DT_INT32};
5125 #else
5126   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF};
5127 #endif
5128   TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types));
5129   const auto axis = inputs.at(num_inputs).weights().GetSpan<int>();
5130   if (axis.size() != 1) {
5131     return errors::InvalidArgument("Axis for ConcatV2 must be a scalar, at ",
5132                                    node_def.name());
5133   }
5134   int trt_axis = 0;
5135   const auto dim = inputs.at(0).GetTrtDims();
5136   TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dim.nbDims, node_def.name(),
5137                                  params->use_implicit_batch, &trt_axis));
5138   // Check that dimensions match on non-concatenate axis.
5139   TF_RETURN_IF_ERROR(VerifyShapesMatch(
5140       absl::Span<const TRT_TensorOrWeights>(inputs).first(num_inputs), trt_axis,
5141       node_def.name()));
5142   if (params->validation_only) return Status::OK();
5143 
5144   // Gather inputs as tensors
5145   std::vector<nvinfer1::ITensor const*> input_tensors;
5146   input_tensors.reserve(num_inputs);
5147   for (int i = 0; i < num_inputs; i++) {
5148     input_tensors.push_back(inputs.at(i).tensor());
5149   }
5150   nvinfer1::IConcatenationLayer* layer =
5151       params->converter->network()->addConcatenation(
5152           const_cast<nvinfer1::ITensor* const*>(input_tensors.data()),
5153           input_tensors.size());
5154   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5155   params->converter->SetLayerName(layer, node_def);
5156   layer->setAxis(trt_axis);
5157   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
5158   return Status::OK();
5159 }
5160 
5161 Status ConvertFusedBatchNorm(OpConverterParams* params) {
5162   const auto& inputs = params->inputs;
5163   const auto& node_def = params->node_def;
5164   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false},
5165                                                   {"scale", true},
5166                                                   {"offset", true},
5167                                                   {"mean", true},
5168                                                   {"variance", true}}));
5169   TF_RETURN_IF_ERROR(
5170       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
5171   TFAttrs attrs(node_def);
5172   float epsilon = attrs.get<float>("epsilon");
5173   auto data_format = attrs.get<string>("data_format");
5174   if (data_format != "NCHW") {
5175     return errors::Unimplemented(
5176         node_def.op(), " only supports data_format=NCHW, at ", node_def.name());
5177   }
5178   bool is_training = attrs.get<bool>("is_training");
5179   if (is_training) {
5180     // Trying to use batchnorm in training mode is a very common problem.
5181     // Because the error message will only be printed in VLOG(1) by the
5182     // segmenter, we issue a special warning so that users will actually see it.
5183     LOG_WARNING_WITH_PREFIX
5184         << node_def.op() << " only supports is_training=false. If you "
5185         << "are using Keras, please call "
5186         << "keras.backend.set_learning_phase(0) before constructing "
5187         << "your model. At " << node_def.name();
5188     return errors::Unimplemented(node_def.op(),
5189                                  " only supports is_training=false, at ",
5190                                  node_def.name());
5191   }
5192   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
5193   if (!params->use_implicit_batch && tensor->getDimensions().d[1] == -1) {
5194     // This check is to make sure that channel dimension is known during
5195     // conversion.
5196     //
5197     // We check this only in explicit batch mode and reject an op with unknown
5198     // channel dimension during segmentation. In implicit batch mode we have
5199     // known shapes during conversion even though the shapes may not be known
5200     // during segmentation (see the actual argument for input_shapes when
5201     // ConvertGraphDefToEngine is called from TRTEngineOp::BuildEngine).
5202     return errors::InvalidArgument("Channel dimension must be static, at ",
5203                                    node_def.name());
5204   }
5205   //  Check parameter types
5206   auto parameter_type = inputs.at(1).weights().TrtDType();
5207   if ((parameter_type != nvinfer1::DataType::kFLOAT) &&
5208       (parameter_type != nvinfer1::DataType::kHALF)) {
5209     return errors::Unimplemented(
5210         "Only float32 or float16 weight data type is supported, for node ",
5211         node_def.name(), " got ", DebugString(parameter_type));
5212   }
5213   for (int i = 1; i < 5; i++) {
5214     if (inputs.at(i).weights().TrtDType() != parameter_type) {
5215       return errors::Unimplemented(
5216           "Inconsistent parameter type for batchnorm is not supported, at: " +
5217           node_def.name());
5218     }
5219   }
5220 
5221   TRT_ShapedWeights dummy_power_weights(parameter_type);
5222   size_t nweight = 0;
5223   for (int i = 1; i < 5; i++) {
5224     nweight = std::max<size_t>(nweight, inputs.at(i).weights().count());
5225   }
5226   const TRT_ShapedWeights* ptr_shape_weights = nullptr;
5227   for (int i = 1; i < 5; i++) {
5228     if (inputs.at(i).weights().count() == nweight) {
5229       ptr_shape_weights = &(inputs.at(i).weights());
5230     } else if (inputs.at(i).weights().count() != 1) {
5231       return errors::InvalidArgument(
5232           "Inconsistent batchnorm parameter count, at: " + node_def.name());
5233     }
5234   }
5235   if (params->validation_only) return Status::OK();
5236 
5237   //  We could technically have two weights with different shapes, but that
5238   //  would require two addScale ops, which is arguably less performant.
5239   TRT_ShapedWeights combined_scale_weights =
5240       params->weight_store->GetTempWeights(*ptr_shape_weights);
5241   TRT_ShapedWeights combined_offset_weights =
5242       params->weight_store->GetTempWeights(*ptr_shape_weights);
5243 
5244   const Eigen::half* cast_vals_array[4];
5245   const float* vals_array[4];
5246   for (int j = 0; j < 4; j++) {
5247     cast_vals_array[j] =
5248         static_cast<Eigen::half const*>(inputs.at(j + 1).weights().GetValues());
5249     vals_array[j] =
5250         static_cast<float const*>(inputs.at(j + 1).weights().GetValues());
5251   }
5252   Eigen::half* cast_combined_scale_vals =
5253       static_cast<Eigen::half*>(combined_scale_weights.GetValues());
5254   Eigen::half* cast_combined_offset_vals =
5255       static_cast<Eigen::half*>(combined_offset_weights.GetValues());
5256   float* combined_scale_vals =
5257       static_cast<float*>(combined_scale_weights.GetValues());
5258   float* combined_offset_vals =
5259       static_cast<float*>(combined_offset_weights.GetValues());
5260 
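  // Fold the batchnorm into a single scale/shift:
  //   y = scale * (x - mean) / sqrt(variance + epsilon) + offset
  //     = combined_scale * x + combined_offset,
  // where combined_scale = scale / sqrt(variance + epsilon) and
  // combined_offset = offset - mean * combined_scale.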
5261   for (size_t i = 0; i < nweight; ++i) {
5262     float batchnorm_data[4];
5263     for (int j = 0; j < 4; j++) {
5264       if (inputs.at(j + 1).weights().count() != 1) {
5265         if (parameter_type == nvinfer1::DataType::kFLOAT) {
5266           batchnorm_data[j] = vals_array[j][i];
5267         } else if (parameter_type == nvinfer1::DataType::kHALF) {
5268           batchnorm_data[j] =
5269               Eigen::half_impl::half_to_float(cast_vals_array[j][i]);
5270         }
5271       } else {
5272         if (parameter_type == nvinfer1::DataType::kFLOAT) {
5273           batchnorm_data[j] = vals_array[j][0];
5274         } else if (parameter_type == nvinfer1::DataType::kHALF) {
5275           batchnorm_data[j] =
5276               Eigen::half_impl::half_to_float(cast_vals_array[j][0]);
5277         }
5278       }
5279     }
5280     float scale = batchnorm_data[0];
5281     float offset = batchnorm_data[1];
5282     float mean = batchnorm_data[2];
5283     float variance = batchnorm_data[3];
5284     float combined_scale_val = scale / sqrtf(variance + epsilon);
5285     float combined_offset_val = offset - mean * combined_scale_val;
5286     if (parameter_type == nvinfer1::DataType::kFLOAT) {
5287       combined_scale_vals[i] = combined_scale_val;
5288       combined_offset_vals[i] = combined_offset_val;
5289     } else if (parameter_type == nvinfer1::DataType::kHALF) {
5290       cast_combined_scale_vals[i] = Eigen::half(combined_scale_val);
5291       cast_combined_offset_vals[i] = Eigen::half(combined_offset_val);
5292     }
5293   }
5294 
5295   nvinfer1::ScaleMode mode = nweight == 1 ? nvinfer1::ScaleMode::kUNIFORM
5296                                           : nvinfer1::ScaleMode::kCHANNEL;
5297   nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
5298       *tensor, mode, combined_offset_weights.GetTrtWeights(),
5299       combined_scale_weights.GetTrtWeights(),
5300       dummy_power_weights.GetTrtWeights());
5301   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5302   params->converter->SetLayerName(layer, node_def);
5303   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
5304   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
5305   return Status::OK();
5306 }
5307 
5308 Status ConvertGather(OpConverterParams* params) {
5309   const auto& inputs = params->inputs;
5310   const auto& node_def = params->node_def;
5311   // TODO(tmorris): Use CheckInputsWeights by changing bool to enum with an
5312   // option for an input to be either tensor or weight.
5313   TF_RETURN_IF_ERROR(
5314       CheckInputsWeights(*params, {{"params", TrtInputArg::kBoth},
5315                                    {"indices", TrtInputArg::kTensor},
5316                                    {"axis", TrtInputArg::kWeight}}));
5317 
5318   const auto& params_input = inputs.at(0);
5319   const auto& indices_input = inputs.at(1);
5320   const auto& axis_input = inputs.at(2);
5321 
5322   TF_RETURN_IF_ERROR(AllowDataTypes(
5323       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
5324       /*dtype_attr_name=*/"Tparams"));
5325   TF_RETURN_IF_ERROR(AllowDataTypes(*params, {DataType::DT_INT32},
5326                                     /*dtype_attr_name=*/"Tindices"));
5327 
5328   absl::Span<const int> axis = axis_input.weights().GetSpan<int>();
5329   if (axis.size() != 1) {
5330     return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ",
5331                                    node_def.name());
5332   }
5333   int trt_axis = 0;
5334   TF_RETURN_IF_ERROR(ConvertAxis(
5335       axis[0], params_input.GetTrtDims().nbDims, node_def.name(),
5336       params->use_implicit_batch && params_input.is_tensor(), &trt_axis));
5337   if (params->use_implicit_batch && params_input.is_weights() &&
5338       trt_axis != 0) {
5339     return errors::Unimplemented(
5340         "The input axis must be zero when params is a weight.");
5341   }
5342   if (params->use_implicit_batch && params_input.is_tensor() &&
5343       indices_input.batch_size() != 1) {
5344     return errors::Unimplemented(
5345         "Indices must have a batch size of 1 when params is a tensor.");
5346   }
5347   // If both inputs are tensors, the TF gather result will have rank:
5348   // (params.nbDims + 1) + (indices.nbDims + 1) - 1,
5349   // where "+ 1" adds the batch dim. If params is a weight, the TRT rank matches
5350   // the TF rank so we don't have to add + 1.
5351   const int params_tf_rank =
5352       params_input.GetTrtDims().nbDims +
5353       (params->use_implicit_batch && params_input.is_tensor() ? 1 : 0);
5354   const int indices_tf_rank =
5355       indices_input.GetTrtDims().nbDims + (params->use_implicit_batch ? 1 : 0);
5356   const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1;
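  // For example (implicit batch): if params is a tensor with TRT dims [5, 6]
  // (TF rank 3) and indices has TRT dims [3] (TF rank 2), then
  // tf_gather_output_rank = 3 + 2 - 1 = 4.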
5357   if (tf_gather_output_rank >
5358       nvinfer1::Dims::MAX_DIMS + (params->use_implicit_batch ? 1 : 0)) {
5359     return errors::InvalidArgument(
5360         "Result of gather has dimension greater than ",
5361         nvinfer1::Dims::MAX_DIMS + 1);
5362   }
5363   if (params->validation_only) return Status::OK();
5364 
5365   // Convert params to a tensor if it is a weight.
5366   nvinfer1::ITensor* params_tensor = nullptr;
5367   if (params_input.is_weights()) {
5368     params_tensor = params->converter->CreateConstantLayer(
5369         params_input.weights(), params_input.GetTrtDims());
5370   } else {
5371     params_tensor = params_input.tensor();
5372   }
5373 
5374   // Note on how IGatherLayer works: if both the data and indices tensors have
5375   // a batch size dimension of size N, it performs:
5376   // for batchid in xrange(N):
5377   //   output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = (
5378   //       data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn])
5379   nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
5380       *params_tensor, *indices_input.tensor(), trt_axis);
5381   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5382   params->converter->SetLayerName(layer, node_def);
5383 
5384   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
5385   nvinfer1::Dims trt_gather_output_dims = output_tensor->getDimensions();
5386   // Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT,
5387   // and the other is for the output dimension that is squeezed by IGatherLayer
5388   // because of the implicit batch dim in the indices (see the above note).
5389   const int expected_trt_output_rank =
5390       tf_gather_output_rank - (params_input.is_tensor() ? 2 : 1);
5391   if (params->use_implicit_batch &&
5392       trt_gather_output_dims.nbDims != expected_trt_output_rank) {
5393     return errors::Internal(
5394         "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
5395         expected_trt_output_rank,
5396         ", actual nbDims: ", trt_gather_output_dims.nbDims);
5397   }
5398   // Reshape the output so after adding the implicit batch dim it'll match the
5399   // output shape of TF GatherV2.
5400   if (params->use_implicit_batch && params_input.is_tensor()) {
5401     for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
5402       trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
5403     }
5404     trt_gather_output_dims.d[trt_axis] = 1;
5405     ++trt_gather_output_dims.nbDims;
5406 
5407     TF_RETURN_IF_ERROR(PrepareTensorForShape(
5408         params->converter, TRT_TensorOrWeights(output_tensor),
5409         trt_gather_output_dims,
5410         /*validation_only=*/false, &output_tensor, node_def));
5411   }
5412 
5413   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
5414   return Status::OK();
5415 }
5416 
5417 Status ConvertFullyConnectedHelper(OpConverterParams* params,
5418                                    nvinfer1::ITensor* tensor_a,
5419                                    TRT_ShapedWeights weights_b,
5420                                    bool transpose_b, const NodeDef& node_def) {
5421   // Reshape input to 3D - this will be a no-op unless using int8 precision.
5422   auto input_dim = tensor_a->getDimensions();
5423   while (input_dim.nbDims < 3) {
5424     input_dim.d[input_dim.nbDims++] = 1;
5425   }
5426   TF_RETURN_IF_ERROR(PrepareTensorForShape(
5427       params->converter, TRT_TensorOrWeights(tensor_a), input_dim,
5428       /*validation_only=*/false, &tensor_a, node_def, /*op_instance=*/0));
5429 
5430   // FC layer will transpose weights, so we need to pre-transpose.
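  // I.e. with transpose_b=false the weights are laid out as [C, K];
  // ReorderCKtoKC rearranges them to [K, C] so that noutput below picks up K.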
5431   TRT_ShapedWeights weights(weights_b.TrtDType());
5432   if (!transpose_b) {
5433     weights = params->weight_store->GetTempWeights(weights_b);
5434     ReorderCKtoKC(weights_b, &weights);
5435   } else {
5436     weights = weights_b;
5437   }
5438   TRT_ShapedWeights biases(weights.TrtDType());
5439   const int noutput = weights.shape_.d[0];
5440   nvinfer1::IFullyConnectedLayer* layer =
5441       params->converter->network()->addFullyConnected(
5442           *tensor_a, noutput, weights.GetTrtWeights(), biases.GetTrtWeights());
5443 
5444   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5445   params->converter->SetLayerName(layer, node_def);
5446   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
5447 
5448   // Reshape output to 1D - this will be a no-op unless using int8 precision.
5449   auto output_dim = output_tensor->getDimensions();
5450   output_dim.nbDims = 1;
5451   TF_RETURN_IF_ERROR(PrepareTensorForShape(
5452       params->converter, TRT_TensorOrWeights(output_tensor), output_dim,
5453       /*validation_only=*/false, &output_tensor, node_def, /*op_instance=*/1));
5454 
5455   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
5456   return Status::OK();
5457 }
5458 
5459 Status ConvertMatMulHelper(OpConverterParams* params,
5460                            TRT_TensorOrWeights input_a,
5461                            TRT_TensorOrWeights input_b, bool transpose_a,
5462                            bool transpose_b, const NodeDef& node_def) {
5463   // TODO: ReorderCKtoKC is currently not general enough to transpose weights
5464   // that are not 2D.
5465   if ((transpose_a && input_a.is_weights() &&
5466        input_a.GetTrtDims().nbDims != 2) ||
5467       (transpose_b && input_b.is_weights() &&
5468        input_b.GetTrtDims().nbDims != 2)) {
5469     return errors::InvalidArgument(
5470         "Cannot currently transpose constant input if it is not 2 dimensional");
5471   }
5472 
5473   // If A is a tensor, we can only transpose if it is at least 3D in TF,
5474   // or TRT will not do the correct transposition.
5475   if (transpose_a && input_a.is_tensor() && input_a.GetTrtDims().nbDims < 2) {
5476     return errors::InvalidArgument(
5477         "Cannot transpose first input if it is a tensor with fewer than 2 "
5478         "non-batch dimensions.");
5479   }
5480 
5481   // If B is a tensor, then it must be at least 3D in TF,
5482   // or TRT won't be able to handle the multiply correctly.
5483   if (input_b.is_tensor() && input_b.GetTrtDims().nbDims < 2) {
5484     return errors::InvalidArgument(
5485         "Second input must either be a constant, or contain at least 2 "
5486         "non-batch dimensions.");
5487   }
5488   if (params->validation_only) return Status::OK();
5489 
5490   // If an FC layer can be used and would be faster, use that instead.
5491   const bool can_use_fc =
5492       !transpose_a && input_a.is_tensor() && input_b.is_weights();
5493   const bool should_use_fc = can_use_fc && input_a.GetTrtDims().nbDims >= 3 &&
5494                              input_b.GetTrtDims().nbDims == 2;
5495   // If int8 is specified, FC must be used unless it is not compatible, as MM
5496   // does not support int8 at this time.
5497   if (should_use_fc || (can_use_fc && params->converter->precision_mode() ==
5498                                           TrtPrecisionMode::INT8)) {
5499     return ConvertFullyConnectedHelper(
5500         params, input_a.tensor(), input_b.weights(), transpose_b, node_def);
5501   }
5502 
5503   const auto get_matrix_op = [](nvinfer1::ITensor* in,
5504                                 bool transpose) -> nvinfer1::MatrixOperation {
5505     return (in->getDimensions().nbDims < 2) ? nvinfer1::MatrixOperation::kVECTOR
5506            : (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE
5507                          : nvinfer1::MatrixOperation::kNONE;
5508   };
5509 
5510   // If the MatMul operand is a constant, applies transposes at conversion-time
5511   // as necessary. If the operand is a tensor, does nothing. If required
5512   // transposes were applied, sets transpose to false.
5513   const auto prepare_matmul_operand =
5514       [&params](TRT_TensorOrWeights operand,
5515                 bool* transpose) -> nvinfer1::ITensor* {
5516     if (operand.is_tensor()) {
5517       return operand.tensor();
5518     } else {
5519       TRT_ShapedWeights weights(operand.weights().TrtDType());
5520       if (*transpose) {
5521         weights = params->weight_store->GetTempWeights(operand.weights());
5522         ReorderCKtoKC(operand.weights(), &weights);
5523         // Weights have been transposed, can set transpose to false
5524         *transpose = false;
5525       } else {
5526         weights = operand.weights();
5527       }
5528       return params->converter->CreateConstantLayer(weights, weights.shape_);
5529     }
5530   };
5531 
5532   nvinfer1::ITensor* tensor_a = prepare_matmul_operand(input_a, &transpose_a);
5533   nvinfer1::ITensor* tensor_b = prepare_matmul_operand(input_b, &transpose_b);
5534 
5535   nvinfer1::IMatrixMultiplyLayer* layer =
5536       params->converter->network()->addMatrixMultiply(
5537           *tensor_a, get_matrix_op(tensor_a, transpose_a), *tensor_b,
5538           get_matrix_op(tensor_b, transpose_b));
5539 
5540   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5541   params->converter->SetLayerName(layer, node_def);
5542   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
5543   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
5544   return Status::OK();
5545 }
5546 
5547 // inputs are both two dimensional (ops::MatMul)
5548 Status ConvertMatMul(OpConverterParams* params) {
5549   const auto& inputs = params->inputs;
5550   const auto& node_def = params->node_def;
5551   if (inputs.size() != 2) {
5552     return errors::InvalidArgument(node_def.op(), " got ", inputs.size(),
5553                                    " inputs but expected 2, at ",
5554                                    node_def.name());
5555   }
5556   TF_RETURN_IF_ERROR(
5557       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
5558 
5559   TFAttrs attrs(node_def);
5560   bool transpose_a = attrs.get<bool>("transpose_a");
5561   bool transpose_b = attrs.get<bool>("transpose_b");
5562 
5563   return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1), transpose_a,
5564                              transpose_b, node_def);
5565 }
5566 
5567 Status ConvertBatchMatMul(OpConverterParams* params) {
5568   const auto& inputs = params->inputs;
5569   const auto& node_def = params->node_def;
5570   if (inputs.size() != 2) {
5571     return errors::InvalidArgument(node_def.op(), " got ", inputs.size(),
5572                                    " inputs but expected 2, at ",
5573                                    node_def.name());
5574   }
5575   // TODO(tmorris): Enable once false is updated to mean either tensor or weight
5576   // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
5577   // false}}));
5578   TF_RETURN_IF_ERROR(
5579       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
5580   if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
5581     return errors::InvalidArgument(
5582         "All inputs are weights, but Grappler is expected to fold them.");
5583   }
5584   if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor() &&
5585       inputs.at(0).GetTrtDims().nbDims != inputs.at(1).GetTrtDims().nbDims) {
5586     return errors::Unimplemented(
5587         "Inputs must have the same rank if they are both tensors.");
5588   }
5589 
5590   TFAttrs attrs(node_def);
5591   const bool transpose_a = attrs.get<bool>("adj_x");
5592   const bool transpose_b = attrs.get<bool>("adj_y");
5593 
5594   // There is no way to batch constants in TRT. Example:
5595   // Tensor with TF Dims: 12 5 3 -> TRT Dims: 5 3
5596   // Weight with TF Dims: 12 3 6 -> TRT Dims: 12 3 6
5597   // It is not possible to treat the weight input as a batched [3, 6] tensor.
5598   const auto check_weight_is_not_batched =
5599       [](const TRT_TensorOrWeights& input_l,
5600          const TRT_TensorOrWeights& input_r) {
5601         // If input_l is a weight, then input_r must be a tensor because
5602         // otherwise the op would be handled by Grappler.
5603         if (input_l.is_weights() &&
5604             input_l.GetTrtDims().nbDims > input_r.GetTrtDims().nbDims &&
5605             input_l.GetTrtDims().d[0] != 1) {
5606           return errors::Unimplemented(
5607               "TensorRT does not support batched constants.");
5608         }
5609         return Status::OK();
5610       };
5611   TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(0), inputs.at(1)));
5612   TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(1), inputs.at(0)));
5613 
5614   // Broadcast inputs. We don't check feasibility since the dimensions in a
5615   // MatMul don't need to match. For example, consider a valid set of inputs
5616   // which would produce an output of shape [N, T, K]:
5617   // input 0: [N, T, C]
5618   // input 1: [1, C, K]
5619   // Since C != K and T != C, the feasibility check would fail.
5620   nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
5621   TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
5622       inputs.at(0), inputs.at(1), /*check_feasibility=*/false,
5623       params->use_implicit_batch, &broadcasted_dims_l, &broadcasted_dims_r));
5624   nvinfer1::ITensor* tensor_l = nullptr;
5625   nvinfer1::ITensor* tensor_r = nullptr;
5626   TF_RETURN_IF_ERROR(
5627       PrepareTensorForShape(params->converter, inputs.at(0), broadcasted_dims_l,
5628                             params->validation_only, &tensor_l, node_def));
5629   TF_RETURN_IF_ERROR(
5630       PrepareTensorForShape(params->converter, inputs.at(1), broadcasted_dims_r,
5631                             params->validation_only, &tensor_r, node_def));
5632   if (params->validation_only) return Status::OK();
5633 
5634   return ConvertMatMulHelper(params, TRT_TensorOrWeights(tensor_l),
5635                              TRT_TensorOrWeights(tensor_r), transpose_a,
5636                              transpose_b, node_def);
5637 }
5638 
5639 Status ConvertSoftmax(OpConverterParams* params) {
5640   const auto& inputs = params->inputs;
5641   const auto& node_def = params->node_def;
5642   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"logits", false}}));
5643   TF_RETURN_IF_ERROR(
5644       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
5645   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
5646 
5647   const int num_trt_dims = tensor->getDimensions().nbDims;
5648   if (num_trt_dims == 0) {
5649     return errors::InvalidArgument(
5650         "TensorRT Softmax cannot apply on batch dimension, at",
5651         node_def.name());
5652   }
5653   if (params->validation_only) return Status::OK();
5654 
5655   nvinfer1::ISoftMaxLayer* layer =
5656       params->converter->network()->addSoftMax(*tensor);
5657   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5658   params->converter->SetLayerName(layer, node_def);
5659   // TensorFlow SoftMax applies softmax along the last dimension.
5660   layer->setAxes(1 << (num_trt_dims - 1));
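  // E.g. for a CHW input (3 TRT dims) the axes bitmask is 1 << 2, which
  // selects the last (W) dimension.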
5661 
5662   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
5663   // Quantization range for SoftMax is always (0, 1)
5664   params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 1.0f);
5665   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
5666   return Status::OK();
5667 }
5668 
5669 Status ConvertArgMinMax(OpConverterParams* params) {
5670   const auto& inputs = params->inputs;
5671   const auto& node_def = params->node_def;
5672   TF_RETURN_IF_ERROR(
5673       CheckInputsWeights(*params, {{"input", false}, {"dimension", true}}));
5674   TF_RETURN_IF_ERROR(
5675       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
5676   // INT64 outputs are not supported by TRT.
5677   TFAttrs attrs(node_def);
5678   DataType output_dtype = attrs.get<DataType>("output_type");
5679   if (output_dtype != DataType::DT_INT32) {
5680     return errors::Unimplemented("Output type ", DataTypeString(output_dtype),
5681                                  " is not supported, at ", node_def.name());
5682   }
5683   int tf_axis = inputs.at(1).weights().GetSpan<int>()[0];
5684   int trt_axis;
5685   nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
5686   TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
5687                                  params->use_implicit_batch, &trt_axis));
5688   nvinfer1::TopKOperation topk_op;
5689   if (node_def.op() == "ArgMin") {
5690     topk_op = nvinfer1::TopKOperation::kMIN;
5691   } else if (node_def.op() == "ArgMax") {
5692     topk_op = nvinfer1::TopKOperation::kMAX;
5693   } else {
5694     return errors::InvalidArgument("Unsupported ArgMin/Max operation");
5695   }
5696 
5697 #if !IS_TRT_VERSION_GE(7, 0, 0, 11)
5698   const nvinfer1::Dims trt_dims = params->inputs.at(0).GetTrtDims();
5699   if (trt_dims.nbDims >= 4) {
5700     string trt_dim_str = DebugString(trt_dims);
5701 
5702     return errors::Unimplemented(node_def.op(), "op is not able to support",
5703                                  " tensors with 4+ dimensions (excluding batch",
5704                                  " size). Received: ", trt_dim_str);
5705   }
5706 #endif
5707 
5708   if (params->validation_only) return Status::OK();
5709 
5710   // Use TopK with k = 1. Only indices output is needed (output 1).
5711   const uint32_t reduce_axes = 1 << trt_axis;
5712   nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK(
5713       *inputs.at(0).tensor(), topk_op, 1, reduce_axes);
5714   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5715   params->converter->SetLayerName(layer, node_def, "topk");
5716   nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1);
5717 
5718   // Squeeze on axis.
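  // TopK keeps the reduced axis with size 1; marking that axis with 0 in
  // input_dims lets SqueezeTensor drop it so the output rank matches TF
  // ArgMin/ArgMax.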
5719   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
5720   input_dims[trt_axis] = 0;
5721   nvinfer1::ITensor* output_tensor = nullptr;
5722   TF_RETURN_IF_ERROR(params->converter->SqueezeTensor(
5723       output_indices_tensor, &input_dims, params, &output_tensor));
5724   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
5725 
5726   return Status::OK();
5727 }
5728 
5729 Status ConvertTopK(OpConverterParams* params) {
5730   const auto& inputs = params->inputs;
5731   const auto& node_def = params->node_def;
5732   TF_RETURN_IF_ERROR(
5733       CheckInputsWeights(*params, {{"input", false}, {"k", true}}));
5734   TF_RETURN_IF_ERROR(
5735       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
5736   TFAttrs attrs(node_def);
5737   const bool sorted = attrs.get<bool>("sorted");
5738   if (!sorted) {
5739     // TensorRT only supports sorted output. Although the TensorFlow API
5740     // doesn't specify the order of output elements when sorted=false, it is
5741     // safer not to convert, because TensorRT's output might differ from
5742     // TensorFlow's and cause confusion.
5743     return errors::InvalidArgument("Only sorted=True is supported, at ",
5744                                    node_def.name());
5745   }
5746 
5747   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
5748   const int num_dims = tensor->getDimensions().nbDims;
5749   if (num_dims == 0) {
5750     return errors::InvalidArgument(
5751         "TensorRT TopK cannot apply on batch dimension, at", node_def.name());
5752   }
5753 
5754   TRT_ShapedWeights k_w = inputs.at(1).weights();
5755   if (k_w.count() != 1) {
5756     return errors::InvalidArgument("k value of TopK should be a scalar, at",
5757                                    node_def.name());
5758   }
5759   // Note that ITopKLayer always has sorted outputs, so we don't need to handle
5760   // the 'sorted' attribute of the node.
5761   if (params->validation_only) return Status::OK();
5762 
5763   const nvinfer1::TopKOperation op = nvinfer1::TopKOperation::kMAX;
5764   const int k = *(static_cast<int*>(k_w.GetValues()));
5765   const uint32_t reduce_axes = 1 << (num_dims - 1);
5766   nvinfer1::ITopKLayer* layer =
5767       params->converter->network()->addTopK(*tensor, op, k, reduce_axes);
5768   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5769   params->converter->SetLayerName(layer, node_def);
5770 
5771   nvinfer1::ITensor* output_value_tensor = layer->getOutput(0);
5772   nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1);
5773   params->outputs->push_back(TRT_TensorOrWeights(output_value_tensor));
5774   params->outputs->push_back(TRT_TensorOrWeights(output_indices_tensor));
5775   return Status::OK();
5776 }
5777 
5778 Status ConvertDepthSpaceShuffle(OpConverterParams* params) {
5779   const auto& inputs = params->inputs;
5780   const auto& node_def = params->node_def;
5781   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
5782   TF_RETURN_IF_ERROR(AllowDataTypes(
5783       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
5784   TFAttrs attrs(node_def);
5785   const int block_size = attrs.get<int64>("block_size");
5786   if (block_size < 2) {
5787     return errors::InvalidArgument("Block size must be 2 or greater, at ",
5788                                    node_def.name());
5789   }
5790   const string data_format = attrs.get<string>("data_format");
5791   if (data_format != "NCHW" && data_format != "NHWC") {
5792     return errors::Unimplemented("Data format ", data_format,
5793                                  " is not supported, at ", node_def.name());
5794   }
5795   nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
5796   if (dims.nbDims != 3) {
5797     return errors::InvalidArgument("The input to ", node_def.op(),
5798                                    " must be rank 4, at ", node_def.name());
5799   }
5800   const int num_channels = data_format == "NCHW" ? dims.d[0] : dims.d[2];
5801   const int h = data_format == "NCHW" ? dims.d[1] : dims.d[0];
5802   const int w = data_format == "NCHW" ? dims.d[2] : dims.d[1];
5803   // Get shuffle parameters.
5804   nvinfer1::Dims first_shuffle_shape;
5805   nvinfer1::Permutation transpose_perm;
5806   nvinfer1::Dims second_shuffle_shape;
5807   if (node_def.op() == "DepthToSpace") {
5808     if (num_channels % (block_size * block_size) != 0) {
5809       return errors::InvalidArgument(
5810           "Number of channels must be divisible by block_size*block_size, at ",
5811           node_def.name());
5812     }
5813     // First Reshape [C, H, W] -> [r, r, C/(r*r), H, W]
5814     first_shuffle_shape = {
5815         /*nbDims=*/5,
5816         /*d=*/{block_size, block_size, num_channels / (block_size * block_size),
5817                h, w}};
5818     // Transpose [r, r, C/(r*r), H, W] -> [C/(r*r), H, r, W, r]
5819     transpose_perm = {2, 3, 0, 4, 1};
5820     // Second Reshape [C/(r*r), H, r, W, r] -> [C/(r*r), H * r, W * r]
5821     second_shuffle_shape =
5822         nvinfer1::DimsCHW(num_channels / (block_size * block_size),
5823                           h * block_size, w * block_size);
5824   } else if (node_def.op() == "SpaceToDepth") {
5825     if (h % block_size != 0 || w % block_size != 0) {
5826       return errors::InvalidArgument(
5827           "Width and height must be divisible by block_size, at ",
5828           node_def.name());
5829     }
5830     // First Reshape [C, H, W] -> [C, H/r, r, W/r, r]
5831     first_shuffle_shape = {/*nbDims=*/5,
5832                            /*d=*/{num_channels, h / block_size, block_size,
5833                                   w / block_size, block_size}};
5834     // Transpose [C, H/r, r, W/r, r] -> [r, r, C, H/r, W/r]
5835     transpose_perm = {2, 4, 0, 1, 3};
5836     // Second Reshape  [r, r, C, H/r, W/r] -> [C*r*r, H/r, W/r]
5837     second_shuffle_shape = nvinfer1::DimsCHW(
5838         num_channels * block_size * block_size, h / block_size, w / block_size);
5839   }
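  // For example, DepthToSpace on an NCHW input with TRT dims [12, 2, 3] and
  // block_size 2: first reshape to [2, 2, 3, 2, 3], transpose to
  // [3, 2, 2, 3, 2], then reshape to [3, 4, 6].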
5840   if (params->validation_only) return Status::OK();
5841 
5842   nvinfer1::IShuffleLayer* first_shuffle =
5843       params->converter->network()->addShuffle(*inputs.at(0).tensor());
5844   TFTRT_RETURN_ERROR_IF_NULLPTR(first_shuffle, node_def.name());
5845   params->converter->SetLayerName(first_shuffle, node_def, "shuffle",
5846                                   /*op_instance=*/0);
5847   if (data_format == "NHWC") {
5848     first_shuffle->setFirstTranspose({2, 0, 1});
5849   }
5850   first_shuffle->setReshapeDimensions(first_shuffle_shape);
5851   first_shuffle->setSecondTranspose(transpose_perm);
5852 
5853   nvinfer1::IShuffleLayer* second_shuffle =
5854       params->converter->network()->addShuffle(*first_shuffle->getOutput(0));
5855   TFTRT_RETURN_ERROR_IF_NULLPTR(second_shuffle, node_def.name());
5856   params->converter->SetLayerName(second_shuffle, node_def, "shuffle",
5857                                   /*op_instance=*/1);
5858   second_shuffle->setReshapeDimensions(second_shuffle_shape);
5859   if (data_format == "NHWC") {
5860     second_shuffle->setSecondTranspose({1, 2, 0});
5861   }
5862 
5863   params->converter->MarkQuantizationRangesAsInferrable(
5864       inputs.at(0).tensor(), first_shuffle->getOutput(0));
5865   params->converter->MarkQuantizationRangesAsInferrable(
5866       first_shuffle->getOutput(0), second_shuffle->getOutput(0));
5867   params->outputs->push_back(TRT_TensorOrWeights(second_shuffle->getOutput(0)));
5868   return Status::OK();
5869 }
5870 
5871 Status ConvertSquaredDifference(OpConverterParams* params) {
5872   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", false}}));
5873   TF_RETURN_IF_ERROR(
5874       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
5875   const auto& inputs = params->inputs;
5876   const auto& node_def = params->node_def;
5877   // Broadcast inputs.
5878   nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
5879   TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
5880       inputs.at(0), inputs.at(1), /*check_feasibility=*/true,
5881       params->use_implicit_batch, &broadcasted_dims_l, &broadcasted_dims_r));
5882   nvinfer1::ITensor* tensor_l = nullptr;
5883   nvinfer1::ITensor* tensor_r = nullptr;
5884   TF_RETURN_IF_ERROR(
5885       PrepareTensorForShape(params->converter, inputs.at(0), broadcasted_dims_l,
5886                             params->validation_only, &tensor_l, node_def));
5887   TF_RETURN_IF_ERROR(
5888       PrepareTensorForShape(params->converter, inputs.at(1), broadcasted_dims_r,
5889                             params->validation_only, &tensor_r, node_def));
5890   if (params->validation_only) return Status::OK();
5891 
5892   // Subtract x - y.
5893   nvinfer1::IElementWiseLayer* sub =
5894       params->converter->network()->addElementWise(
5895           *tensor_l, *tensor_r, nvinfer1::ElementWiseOperation::kSUB);
5896   TFTRT_RETURN_ERROR_IF_NULLPTR(sub, node_def.name());
5897   params->converter->SetLayerName(sub, node_def, "sub");
5898 
5899   // Multiply (x - y) * (x - y).
5900   nvinfer1::IElementWiseLayer* mul =
5901       params->converter->network()->addElementWise(
5902           *sub->getOutput(0), *sub->getOutput(0),
5903           nvinfer1::ElementWiseOperation::kPROD);
5904   TFTRT_RETURN_ERROR_IF_NULLPTR(mul, node_def.name());
5905   params->converter->SetLayerName(mul, node_def, "mul");
5906 
5907   params->outputs->push_back(TRT_TensorOrWeights(mul->getOutput(0)));
5908   return Status::OK();
5909 }
5910 
5911 #if IS_TRT_VERSION_GE(7, 1, 3, 0)
5912 Status ConvertCombinedNMS(OpConverterParams* params) {
5913   TF_RETURN_IF_ERROR(
5914       CheckInputsWeights(*params, {{"boxes", false},
5915                                    {"scores", false},
5916                                    {"max_output_size_per_class", true},
5917                                    {"max_total_size", true},
5918                                    {"iou_threshold", true},
5919                                    {"score_threshold", true}}));
5920   const auto& inputs = params->inputs;
5921   const auto& node_def = params->node_def;
5922 
5923   nvinfer1::ITensor* boxes_tensor = inputs.at(0).tensor();
5924   nvinfer1::ITensor* scores_tensor = inputs.at(1).tensor();
5925   TRT_ShapedWeights output_size_per_class = inputs.at(2).weights();
5926   TRT_ShapedWeights total_size = inputs.at(3).weights();
5927   TRT_ShapedWeights iou_threshold = inputs.at(4).weights();
5928   TRT_ShapedWeights score_threshold = inputs.at(5).weights();
5929 
5930   // Validate tensors and weights (also set some of the needed plugin fields)
5931   const auto boxes_dims = boxes_tensor->getDimensions();
5932   const auto scores_dims = scores_tensor->getDimensions();
5933   if (!params->use_implicit_batch &&
5934       (!HasStaticShape(boxes_dims) || !HasStaticShape(scores_dims))) {
5935     return errors::Unimplemented(
5936         "TensorRT BatchedNMS Plugin requires input with static shape");
5937   }
5938   const int offset = params->use_implicit_batch ? 0 : 1;
5939   if (boxes_dims.nbDims != 3 + offset) {
5940     return errors::InvalidArgument(
5941         "TensorRT BatchedNMS Plugin input boxes must be 4-D including batch ",
5942         node_def.name());
5943   }
5944   const int class_idx = 1 + offset;
5945   const int num_classes = scores_dims.d[class_idx];
5946   const int num_boxes = boxes_dims.d[0 + offset];
5947   bool box_check =
5948       boxes_dims.d[class_idx] == 1 || boxes_dims.d[class_idx] == num_classes;
5949   if (!box_check) {
5950     return errors::InvalidArgument(
5951         "TensorRT BatchedNMS Plugin third dimension of boxes must be either 1 "
5952         "or num_classes ",
5953         node_def.name());
5954   }
5955   if (output_size_per_class.shape_.nbDims != 1) {
5956     return errors::InvalidArgument(
5957         "TensorRT BatchedNMS Plugin max_output_size_per_class must be 0-D ",
5958         node_def.name());
5959   }
5960   int max_size_per_class =
5961       *(static_cast<int*>(output_size_per_class.GetValues()));
5962   if (max_size_per_class <= 0) {
5963     return errors::InvalidArgument(
5964         "TensorRT BatchedNMS Plugin max_output_size_per_class should be > 0",
5965         node_def.name());
5966   }
5967   if (total_size.shape_.nbDims != 1) {
5968     return errors::InvalidArgument(
5969         "TensorRT BatchedNMS Plugin max_total_size must be 0-D ",
5970         node_def.name());
5971   }
5972   int max_total_size = *(static_cast<int*>(total_size.GetValues()));
5973   if (max_total_size <= 0) {
5974     return errors::InvalidArgument(
5975         "TensorRT BatchedNMS Plugin max_total_size should be > 0",
5976         node_def.name());
5977   }
5978   if (iou_threshold.shape_.nbDims != 1) {
5979     return errors::InvalidArgument(
5980         "TensorRT BatchedNMS Plugin iou_threshold must be 0-D ",
5981         node_def.name());
5982   }
5983   float iou_thresh = *(static_cast<float*>(iou_threshold.GetValues()));
5984   if (iou_thresh < 0.0 || iou_thresh > 1.0) {
5985     return errors::InvalidArgument(
5986         "TensorRT BatchedNMS Plugin iou_threshold must be in [0, 1]",
5987         node_def.name());
5988   }
5989   if (score_threshold.shape_.nbDims != 1) {
5990     return errors::InvalidArgument(
5991         "TensorRT BatchedNMS Plugin score_threshold must be 0-D ",
5992         node_def.name());
5993   }
5994 
5995   if (params->validation_only) return Status::OK();
5996 
5997   // TRT op is_normalized=False treats input coordinates as pixels and
5998   // calculates width/height as (max - min + 1).
5999   //
6000   // TF op CombinedNonMaxSuppression doesn't care about the normalization and
6001   // calculates width/height as (max - min).
6002   //
6003   // We set is_normalized = true to be consistent with the TF IOU calculation.
6004   const bool is_normalized = true;
6005 
6006   TFAttrs attrs(node_def);
6007   bool share_location = (boxes_dims.d[class_idx] == 1);
6008   const bool pad_per_class = attrs.get<bool>("pad_per_class");
6009   const bool clip_boxes = attrs.get<bool>("clip_boxes");
6010   int keep_top_k = 0;
6011   if (pad_per_class) {
6012     keep_top_k = std::min(max_size_per_class * num_classes, max_total_size);
6013   } else {
6014     keep_top_k = max_total_size;
6015   }
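  // Example (hypothetical values): num_classes = 90, max_size_per_class = 2,
  // max_total_size = 300 and pad_per_class = true gives
  // keep_top_k = min(2 * 90, 300) = 180.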
6016   // According to the batchedNMS plugin description we need to set top_k so that
6017   // keep_top_k <= top_k
6018   // https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
6019   const int top_k = std::max(num_boxes, keep_top_k);
6020   float score_thresh = *(static_cast<float*>(score_threshold.GetValues()));
6021   const int background_id = -1;
6022   nvinfer1::PluginField fields[9] = {
6023       nvinfer1::PluginField{"shareLocation", &share_location,
6024                             nvinfer1::PluginFieldType::kINT32, 1},
6025       nvinfer1::PluginField{"backgroundLabelId", &background_id,
6026                             nvinfer1::PluginFieldType::kINT32, 1},
6027       nvinfer1::PluginField{"numClasses", &num_classes,
6028                             nvinfer1::PluginFieldType::kINT32, 1},
6029       nvinfer1::PluginField{"topK", &top_k, nvinfer1::PluginFieldType::kINT32,
6030                             1},
6031       nvinfer1::PluginField{"keepTopK", &keep_top_k,
6032                             nvinfer1::PluginFieldType::kINT32, 1},
6033       nvinfer1::PluginField{"scoreThreshold", &score_thresh,
6034                             nvinfer1::PluginFieldType::kFLOAT32, 1},
6035       nvinfer1::PluginField{"iouThreshold", &iou_thresh,
6036                             nvinfer1::PluginFieldType::kFLOAT32, 1},
6037       nvinfer1::PluginField{"isNormalized", &is_normalized,
6038                             nvinfer1::PluginFieldType::kINT32, 1},
6039       nvinfer1::PluginField{"clipBoxes", &clip_boxes,
6040                             nvinfer1::PluginFieldType::kINT32, 1}};
6041   nvinfer1::PluginFieldCollection fc{9, fields};
6042 
6043   // Get plugin creator
6044   auto creator =
6045       getPluginRegistry()->getPluginCreator("BatchedNMS_TRT", "1", "");
6046   TFTRT_RETURN_ERROR_IF_NULLPTR(creator, node_def.name());
6047 
6048   // Create plugin
6049   TrtUniquePtrType<nvinfer1::IPluginV2> plugin(
6050       creator->createPlugin(node_def.name().c_str(), &fc));
6051   TFTRT_RETURN_ERROR_IF_NULLPTR(plugin, node_def.name());
6052 
6053   // Set plugin inputs
6054   std::vector<nvinfer1::ITensor*> plugin_inputs;
6055   plugin_inputs.push_back(boxes_tensor);
6056   plugin_inputs.push_back(scores_tensor);
6057 
6058   // Add plugin to network
6059   nvinfer1::IPluginV2Layer* layer = params->converter->network()->addPluginV2(
6060       &plugin_inputs[0], static_cast<int>(plugin_inputs.size()), *plugin);
6061   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
6062   params->converter->SetLayerName(layer, node_def, "plugin");
6063 
6064   // Set plugin outputs
6065   nvinfer1::ITensor* output_nmsed_boxes = layer->getOutput(1);
6066 
6067   // TRT6 fixes (removes) the extra last dimension in CombinedNMS outputs
6068   nvinfer1::ITensor* output_num_detections = layer->getOutput(0);
6069   nvinfer1::ITensor* output_nmsed_scores = layer->getOutput(2);
6070   nvinfer1::ITensor* output_nmsed_classes = layer->getOutput(3);
6071 
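  // The BatchedNMS plugin emits (num_detections, nmsed_boxes, nmsed_scores,
  // nmsed_classes); TF CombinedNonMaxSuppression expects (nmsed_boxes,
  // nmsed_scores, nmsed_classes, valid_detections), hence the reordering below.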
6072   params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_boxes));
6073   params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_scores));
6074   params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_classes));
6075   params->outputs->push_back(TRT_TensorOrWeights(output_num_detections));
6076 
6077   return Status::OK();
6078 }
6079 #endif  // IS_TRT_VERSION_GE(7, 1, 3, 0)
6080 
6081 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
6082 Status ConvertResize(OpConverterParams* params) {
6083   const auto& inputs = params->inputs;
6084   const auto& node_def = params->node_def;
6085   TF_RETURN_IF_ERROR(
6086       CheckInputsWeights(*params, {{"input", false}, {"size", true}}));
6087   TF_RETURN_IF_ERROR(AllowDataTypes(
6088       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
6089 
6090   // Get input tensor. It is transposed from NHWC to NCHW further below.
6091   nvinfer1::ITensor* inputs_tensor = inputs.at(0).tensor();
6092 
6093   TFTRT_RETURN_ERROR_IF_NULLPTR(inputs_tensor, params->node_def.name());
6094 
6095   // Get output size. It must contain two values, i.e. [H_out, W_out].
6096   TRT_ShapedWeights weights = inputs.at(1).weights();
6097   if (weights.count() != 2) {
6098     return errors::Unimplemented("Resize to shape=[] is not supported, at ",
6099                                  node_def.name());
6100   }
6101   const int* weights_ptr = static_cast<int*>(weights.GetValues());
6102 
6103   // Verify and consume node attributes.
6104   TFAttrs attrs(node_def);
6105   bool align_corners = attrs.get<bool>("align_corners");
6106   TF_RETURN_IF_ERROR(
6107       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
6108 
6109   // Verify resize mode. Initialize resize mode if supported.
6110   nvinfer1::ResizeMode resize_mode;
6111   if (node_def.op() == "ResizeBilinear") {
6112 #if IS_TRT_VERSION_GE(7, 1, 0, 0)
6113     if (!align_corners) {
6114       return errors::InvalidArgument(
6115           "Cannot Convert Bilinear Resize when align_corners=False");
6116     }
6117 #endif
6118     resize_mode = nvinfer1::ResizeMode::kLINEAR;
6119   } else if (node_def.op() == "ResizeNearestNeighbor") {
6120     resize_mode = nvinfer1::ResizeMode::kNEAREST;
6121   } else {
6122     return errors::Unimplemented(node_def.op(), " is not yet implemented at ",
6123                                  node_def.name());
6124   }
6125 
6126   // Validate inputs_tensor.
6127   // TODO: Allow dynamic shape for input-1 when shape input tensors are handled.
6128   const auto inputs_dims = inputs_tensor->getDimensions();
6129   if (!params->use_implicit_batch && !HasStaticShape(inputs_dims)) {
6130     return errors::Unimplemented(
6131         "TensorRT IResizeLayer requires input with static shape");
6132   }
6133 
6134   // return after validation if only validation is requested.
6135   if (params->validation_only) return Status::OK();
6136 
6137   // Transpose tensor from NHWC to NCHW format.
6138   TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
6139       inputs_tensor, {0, 3, 1, 2}, &inputs_tensor, node_def, "to_NCHW"));
6140 
6141   // Calculate output dimensions.
6142   // Given input dimensions [N, C, H, W] and output size [H_out, W_out],
6143   // the output dimensions are [N, C, H_out, W_out].
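  // For example (hypothetical shapes): input [N, C, H, W] = [8, 3, 32, 32]
  // with size = [64, 64] yields output dimensions [8, 3, 64, 64].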
6144   nvinfer1::Dims output_dimensions;
6145   output_dimensions.nbDims = inputs_tensor->getDimensions().nbDims;
6146   for (int i = 0; i < output_dimensions.nbDims; ++i) {
6147     output_dimensions.d[i] = inputs_tensor->getDimensions().d[i];
6148   }
6149   output_dimensions.d[output_dimensions.nbDims - 2] = weights_ptr[0];
6150   output_dimensions.d[output_dimensions.nbDims - 1] = weights_ptr[1];
6151 
6152   // Add resize layer.
6153   nvinfer1::IResizeLayer* layer =
6154       params->converter->network()->addResize(*inputs_tensor);
6155   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
6156   params->converter->SetLayerName(layer, node_def);
6157 
6158   // Set layer parameters.
6159   layer->setResizeMode(resize_mode);
6160   layer->setOutputDimensions(output_dimensions);
6161   layer->setAlignCorners(align_corners);
6162 
6163   // Get output tensor. Transpose it from NCHW to NHWC.
6164   nvinfer1::ITensor* output = layer->getOutput(0);
6165 
6166   TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
6167       output, {0, 2, 3, 1}, &output, node_def, "to_NHWC"));
6168   params->outputs->push_back(TRT_TensorOrWeights(output));
6169   // Success
6170   return Status::OK();
6171 }  // ConvertResize
6172 #endif  // IS_TRT_VERSION_GE(6, 0, 0, 0)
6173 
6174 Status ConvertAddN(OpConverterParams* params) {
6175   const auto& inputs = params->inputs;
6176   const auto& node_def = params->node_def;
6177   TF_RETURN_IF_ERROR(
6178       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
6179   TFAttrs attrs(node_def);
6180   const int num_inputs = attrs.get<int64>("N");
6181   if (num_inputs < 2) {
6182     return errors::InvalidArgument("AddN requires at least two inputs, at ",
6183                                    node_def.name());
6184   }
6185   if (inputs.size() != num_inputs) {
6186     return errors::InvalidArgument("Got ", inputs.size(),
6187                                    " inputs but expected ", num_inputs, ", at ",
6188                                    node_def.name());
6189   }
6190   for (const auto& input : inputs) {
6191     if (!input.is_tensor() && input.weights().shape_.d[0] != 1) {
6192       return errors::InvalidArgument(
6193           "Weights input to AddN is required to have batch dimension 1.");
6194     }
6195   }
6196   if (params->validation_only) return Status::OK();
6197 
6198   // AddN doesn't support broadcast.
6199   std::vector<nvinfer1::ITensor*> tensor_inputs;
6200   for (const auto& input : inputs) {
6201     if (input.is_tensor()) {
6202       tensor_inputs.push_back(input.tensor());
6203     } else {
6204       auto dims = input.weights().shape_;
6205       TF_RETURN_IF_ERROR(RemoveBatchDimension(&dims));
6206       tensor_inputs.push_back(
6207           params->converter->CreateConstantLayer(input.weights(), dims));
6208     }
6209   }
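  // Fold the inputs pairwise with kSUM layers: AddN(a, b, c) -> ((a + b) + c).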
6210   nvinfer1::ITensor* lhs = tensor_inputs[0];
6211   for (int i = 1; i < num_inputs; ++i) {
6212     nvinfer1::ITensor* rhs = tensor_inputs[i];
6213     nvinfer1::ILayer* layer = params->converter->network()->addElementWise(
6214         *lhs, *rhs, nvinfer1::ElementWiseOperation::kSUM);
6215     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
6216     params->converter->SetLayerName(layer, node_def, std::to_string(i));
6217     lhs = layer->getOutput(0);
6218   }
6219   params->outputs->push_back(TRT_TensorOrWeights(lhs));
6220   return Status::OK();
6221 }
6222 
6223 static void RegisterValidatableOpConverters(
6224     std::unordered_map<string, OpConverter>* registration) {
6225   (*registration)["BiasAdd"] = ConvertBiasAdd;
6226 #if IS_TRT_VERSION_GE(5, 1, 2, 0)
6227   (*registration)["ClipByValue"] = ConvertClipByValue;
6228 #endif
6229 #if IS_TRT_VERSION_GE(7, 1, 3, 0)
6230   (*registration)["CombinedNonMaxSuppression"] = ConvertCombinedNMS;
6231 #endif
6232   (*registration)["AddN"] = ConvertAddN;
6233   (*registration)["Cast"] = ConvertCast;
6234   (*registration)["ConcatV2"] = ConvertConcat;
6235   (*registration)["Const"] = ConvertConst;
6236   (*registration)["Conv2D"] = ConvertConv2D;
6237   (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput;
6238   (*registration)["DepthToSpace"] = ConvertDepthSpaceShuffle;
6239   (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
6240   (*registration)["ExpandDims"] = ConvertExpandDims;
6241   (*registration)["FusedConv2DBiasActivation"] =
6242       ConvertFusedConv2DBiasActivation;
6243   (*registration)["GatherV2"] = ConvertGather;
6244   (*registration)["LeakyRelu"] = ConvertLeakyRelu;
6245   (*registration)["MatMul"] = ConvertMatMul;
6246   (*registration)["Pack"] = ConvertPack;
6247   (*registration)["Pad"] = ConvertPad;
6248   (*registration)["Relu6"] = ConvertRelu6;
6249   (*registration)["Reshape"] = ConvertReshape;
6250 #if IS_TRT_VERSION_GE(6, 0, 0, 0)
6251   (*registration)["Conv3D"] = ConvertConv3D;
6252   (*registration)["Conv3DBackpropInputV2"] = ConvertConv3DBackpropInputV2;
6253   for (auto resize_mode : {"ResizeBilinear", "ResizeNearestNeighbor"}) {
6254     (*registration)[resize_mode] = ConvertResize;
6255   }
6256   for (auto pool_op_type : {"AvgPool3D", "MaxPool3D"}) {
6257     (*registration)[pool_op_type] = ConvertPool3D;
6258   }
6259 #endif
6260   (*registration)["Shape"] = ConvertShape;
6261   (*registration)["Rsqrt"] = ConvertRsqrt;
6262   (*registration)["Slice"] = ConvertSlice;
6263   (*registration)["Softmax"] = ConvertSoftmax;
6264   (*registration)["SpaceToDepth"] = ConvertDepthSpaceShuffle;
6265   (*registration)["Split"] = ConvertSplit;
6266   (*registration)["Square"] = ConvertSquare;
6267   (*registration)["SquaredDifference"] = ConvertSquaredDifference;
6268   (*registration)["Squeeze"] = ConvertSqueeze;
6269   (*registration)["StridedSlice"] = ConvertStridedSlice;
6270   (*registration)["TopKV2"] = ConvertTopK;
6271   (*registration)["Transpose"] = ConvertTranspose;
6272   (*registration)["Unpack"] = ConvertUnpack;
6273 
6274   for (auto quantization_op_type :
6275        {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3",
6276         "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxArgs"}) {
6277     (*registration)[quantization_op_type] = ConvertQuantize;
6278   }
6279   for (const auto& binary_op_pair : *BinaryOperationMap()) {
6280     (*registration)[binary_op_pair.first] = ConvertBinary;
6281   }
6282   for (const auto& activation_op_pair : *ActivationTypeMap()) {
6283     (*registration)[activation_op_pair.first] = ConvertActivation;
6284   }
6285   for (auto pool_op_type : {"AvgPool", "MaxPool"}) {
6286     (*registration)[pool_op_type] = ConvertPool;
6287   }
6288   for (auto normalization_op_type :
6289        {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"}) {
6290     (*registration)[normalization_op_type] = ConvertFusedBatchNorm;
6291   }
6292   for (const auto& unary_op_pair : *UnaryOperationMap()) {
6293     (*registration)[unary_op_pair.first] = ConvertUnary;
6294   }
6295   for (auto reduce_op_type : {"Sum", "Prod", "Max", "Min", "Mean"}) {
6296     (*registration)[reduce_op_type] = ConvertReduce;
6297   }
6298   for (auto arg_minmax_type : {"ArgMin", "ArgMax"}) {
6299     (*registration)[arg_minmax_type] = ConvertArgMinMax;
6300   }
6301   // The following are no-ops during inference and will not be mapped to any TRT
6302   // layer.
6303   for (auto identity_op_type : {"Identity", "Snapshot", "StopGradient"}) {
6304     (*registration)[identity_op_type] = ConvertIdentity;
6305   }
6306   for (auto batch_matmul_type : {"BatchMatMul", "BatchMatMulV2"}) {
6307     (*registration)[batch_matmul_type] = ConvertBatchMatMul;
6308   }
6309 }
6310 
6311 void TrtNodeValidator::RegisterOpValidators() {
6312   RegisterValidatableOpConverters(&op_validators_);
6313 }
6314 
6315 void Converter::RegisterOpConverters() {
6316   RegisterValidatableOpConverters(&op_registry_);
6317 }
6318 
6319 Status ConvertGraphDefToEngine(
6320     const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size,
6321     size_t max_workspace_size_bytes,
6322     const std::vector<PartialTensorShape>& input_shapes,
6323     nvinfer1::ILogger* trt_logger, nvinfer1::IGpuAllocator* allocator,
6324     TRTInt8Calibrator* calibrator,
6325     TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
6326     const bool use_implicit_batch, bool* convert_successfully,
6327     TrtShapeOptimizationProfile* profiles, absl::string_view engine_name) {
6328   engine->reset();
6329   if (convert_successfully) *convert_successfully = false;
6330 
6331   // Creating converter, TensorRT builder and network
6332   auto statusor = Converter::Create(precision_mode, use_calibration, trt_logger,
6333                                     use_implicit_batch, engine_name);
6334   TF_RETURN_IF_ERROR(statusor.status());
6335   auto converter = std::move(statusor.ValueOrDie());
6336 
6337   VLOG(1) << "Starting to convert TensorFlow ops to TensorRT layers";
6338   std::vector<Converter::EngineOutputInfo> output_tensors;
6339   int num_layers = converter->network()->getNbLayers();
6340   absl::flat_hash_set<const char*> layer_names;
6341   // Graph nodes are already topologically sorted during construction
6342   for (const auto& node_def : gdef.node()) {
6343     const string& node_name = node_def.name();
6344     VLOG(2) << "Converting node " << node_name << ", op=" << node_def.op();
6345     if (IsEngineInput(node_name)) {
6346       int32 slot_number = -1;
6347       string type_key;
6348       if (node_def.op() == "Placeholder") {
6349         if (!strings::safe_strto32(  // non-absl ok
6350                 node_name.c_str() + strlen(IONamePrefixes::kInputPHName),
6351                 &slot_number)) {
6352           return errors::InvalidArgument("Failed to parse slot number from ",
6353                                          node_name);
6354         }
6355         type_key = "dtype";
6356       } else if (tensorflow::grappler::IsArg(node_def)) {
6357         // Maybe remove the dependence on grappler and re-implement IsArg,
6358         // which is pretty simple (but could change if new Arg nodes are added)
6359         slot_number = node_def.attr().at("index").i();
6360         type_key = "T";
6361       } else {
6362         return errors::InvalidArgument(
6363             "Node ", node_name,
6364             " is neither Placeholder nor Arg, instead ", node_def.op());
6365       }
6366       nvinfer1::DataType trt_dtype;
6367       nvinfer1::Dims trt_dims;
6368       int batch_size = -1;
6369       auto shape = input_shapes.at(slot_number);
6370       auto status = ValidateTensorProperties(
6371           node_def.op(), node_def.attr().at(type_key).type(), shape,
6372           use_implicit_batch, /*validation_only=*/false, &trt_dtype, &trt_dims,
6373           &batch_size);
6374       if (!status.ok()) {
6375         const string error_message =
6376             StrCat("Validation failed for ", node_name, " and input slot ",
6377                    slot_number, ": ", status.error_message());
6378         LOG_WARNING_WITH_PREFIX << error_message;
6379         return Status(status.code(), error_message);
6380       }
6381       VLOG(2) << "Adding engine input tensor " << node_name << " with shape "
6382               << DebugString(trt_dims);
6383       // TODO(laigd): the conversion should always happen at runtime where all
6384       // the shapes are known, and we can provide a mode to generate the
6385       // engines offline, by calling sess.run() and cache/serialize the engines.
6386       TF_RETURN_IF_ERROR(converter->AddInputTensor(node_name, trt_dtype,
6387                                                    trt_dims, batch_size));
6388     } else if (IsEngineOutput(node_name)) {
6389       int32 slot_number = -1;
6390       if (node_def.op() == "Identity") {
6391         if (!strings::safe_strto32(  // non-absl ok
6392                 node_name.c_str() + strlen(IONamePrefixes::kOutputPHName),
6393                 &slot_number)) {
6394           return errors::InvalidArgument("Failed to parse slot number from ",
6395                                          node_name);
6396         }
6397       } else if (tensorflow::grappler::IsRetval(node_def)) {
6398         slot_number = node_def.attr().at("index").i();
6399       } else {
6400         return errors::InvalidArgument(
6401             "Node with name ", node_name,
6402             " starting with IONamePrefixes::kOutputPHName is "
6403             "neither Identity nor Retval, instead ",
6404             node_def.op());
6405       }
6406       // Get output type that TensorFlow expects
6407       TFAttrs attrs(node_def);
6408       DataType tf_dtype = attrs.get<DataType>("T");
6409       nvinfer1::DataType trt_dtype;
6410       TF_RETURN_IF_ERROR(TfTypeToTrtType(tf_dtype, &trt_dtype));
6411       if (output_tensors.size() <= slot_number) {
6412         output_tensors.resize(slot_number + 1);
6413       }
6414       output_tensors.at(slot_number) = {node_def.input(0), node_name,
6415                                         trt_dtype};
6416     } else {
6417       TF_RETURN_IF_ERROR(converter->ConvertNode(node_def));
6418     }
6419 
6420     // To support TF-TRT profiling, we ensure each ILayer has a non-empty name.
6421     // BuildCudaEngine returns an error if there is any ILayer name collision.
6422     // We want to report the error here before BuildCudaEngine in a more
6423     // meaningful way.
6424     int new_num_layers = converter->network()->getNbLayers();
6425     for (int i = num_layers; i < new_num_layers; i++) {
6426       auto layer = converter->network()->getLayer(i);
6427       if (layer->getName() == nullptr ||
6428           !layer_names.insert(layer->getName()).second) {
6429         std::string error_message =
6430             absl::StrCat("Converting node ", node_name, ", op=", node_def.op(),
6431                          layer->getName() ? " created a layer with a name collision"
6432                                           : " created a layer without a name");
6433         LOG_WARNING_WITH_PREFIX << error_message;
6434         return errors::Internal(error_message);
6435       }
6436     }
6437     num_layers = new_num_layers;
6438   }
6439   TF_RETURN_IF_ERROR(converter->RenameAndMarkOutputTensors(output_tensors));
6440   if (convert_successfully) *convert_successfully = true;
6441 
6442   // Apply user provided quantization ranges to tensors
6443   converter->MaybeApplyQuantizationRanges();
6444 
6445   // Build the engine.
6446   TF_RETURN_IF_ERROR(converter->BuildCudaEngine(
6447       engine, max_batch_size, max_workspace_size_bytes, allocator, calibrator,
6448       profiles));
6449 
6450   VLOG(1) << "Finished conversion";
6451   return Status::OK();
6452 }
6453 
6454 Status ConvertSegmentToGraphDef(
6455     const Graph* graph, const grappler::GraphProperties& graph_properties,
6456     const std::vector<const Node*>& subgraph_nodes,  // In topological order
6457     std::vector<EngineConnection>* connections, GraphDef* segment_def,
6458     string* scope_name) {
6459   std::set<string> marker_nodes;
6460   // Update connection shapes/data types and add corresponding input/output
6461   // nodes in the segment graphdef.
6462   for (size_t i = 0; i < connections->size(); ++i) {
6463     auto& connection = connections->at(i);
6464     if (connection.is_control_edge()) continue;
6465     auto outside_node = graph->FindNodeId(connection.outside_id);
6466     if (!outside_node) {
6467       // This should never happen, unless the original graph is problematic.
6468       return errors::NotFound("Cannot find node with id ",
6469                               connection.outside_id, " in the graph.");
6470     }
6471     // Updates the shape and data types of input/output connections.
6472     DataType dtype;
6473     PartialTensorShape partial_shape;
6474     if (connection.is_input_edge) {
6475       GetOutputProperties(graph_properties,
6476                           graph->FindNodeId(connection.outside_id),
6477                           connection.outside_port, &partial_shape, &dtype);
6478       connection.outside_shape = partial_shape;
6479     } else {
6480       GetInputProperties(graph_properties,
6481                          graph->FindNodeId(connection.outside_id),
6482                          connection.outside_port, &partial_shape, &dtype);
6483       connection.inside_shape = partial_shape;
6484     }
6485     connection.connection_type = dtype;
6486 
6487     // Add dummy input/output nodes to the segment graphdef.
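    // Input edges become _Arg nodes named IONamePrefixes::kInputPHName plus the
    // port number; output edges become _Retval nodes named
    // IONamePrefixes::kOutputPHName plus the port number, matching what
    // ConvertGraphDefToEngine expects when classifying engine inputs/outputs.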
6488     if (connection.is_input_edge) {
6489       const string node_name =
6490           StrCat(IONamePrefixes::kInputPHName, connection.port_number);
6491       if (marker_nodes.count(node_name)) {
6492         VLOG(1) << "Reusing input " << node_name << " for the edge "
6493                 << connection.outside_node_name << ":"
6494                 << connection.outside_port << " -> "
6495                 << connection.inside_node_name << ":" << connection.inside_port;
6496         continue;
6497       }
6498       marker_nodes.insert(node_name);
6499       auto seg_node = segment_def->add_node();
6500       NodeDefBuilder builder(node_name, "_Arg");
6501       auto status = builder.Attr("shape", partial_shape)
6502                         .Attr("T", dtype)
6503                         .Attr("index", connection.port_number)
6504                         .Finalize(seg_node);
6505       VLOG(1) << "Constructing input " << node_name << " for the edge "
6506               << connection.outside_node_name << ":" << connection.outside_port
6507               << " -> " << connection.inside_node_name << ":"
6508               << connection.inside_port;
6509     } else {
6510       const string node_name =
6511           StrCat(IONamePrefixes::kOutputPHName, connection.port_number);
6512       if (marker_nodes.count(node_name)) {
6513         VLOG(1) << "Reusing output " << node_name << " for the edge "
6514                 << connection.inside_node_name << ":" << connection.inside_port
6515                 << " -> " << connection.outside_node_name << ":"
6516                 << connection.outside_port;
6517         continue;
6518       }
6519       marker_nodes.insert(node_name);
6520       auto seg_node = segment_def->add_node();
6521       NodeDefBuilder builder(node_name, "_Retval");
6522       auto status =
6523           builder.Attr("T", dtype)
6524               .Attr("index", connection.port_number)
6525               .Input(connection.inside_node_name, connection.inside_port, dtype)
6526               .Finalize(seg_node);
6527       VLOG(1) << "Constructing output " << node_name << " for the edge "
6528               << connection.inside_node_name << ":" << connection.inside_port
6529               << " -> " << connection.outside_node_name << ":"
6530               << connection.outside_port;
6531     }
6532   }  // for each connection.
6533 
6534   std::unordered_map<int, int> old_to_new_id_map;
6535   // Copy internal nodes to new graphdef
6536   string local_scope = subgraph_nodes.front()->name();
6537   for (const Node* node : subgraph_nodes) {
6538     local_scope = GetCommonNameScope(local_scope, node->name());
6539     old_to_new_id_map[node->id()] = segment_def->node_size();
6540     auto snode = segment_def->add_node();
6541     *snode = node->def();
6542     VLOG(2) << "Copying " << snode->name() << " to subgraph";
6543   }
6544   // Update the inputs of the new input nodes to point to placeholder nodes.
6545   for (int i = 0; i < connections->size(); ++i) {
6546     auto& connection = connections->at(i);
6547     if (connection.is_control_edge() || !connection.is_input_edge) continue;
6548     auto snode =
6549         segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
6550     const string arg_name =
6551         StrCat(IONamePrefixes::kInputPHName, connection.port_number);
6552     VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port
6553             << " from " << snode->input(connection.inside_port) << " to "
6554             << arg_name;
6555     snode->set_input(connection.inside_port, arg_name);
6556   }
6557   std::set<string> subgraph_node_names;
6558   for (const Node* node : subgraph_nodes) {
6559     subgraph_node_names.insert(node->name());
6560   }
6561 
6562   // Remove control inputs that are not inside the segment.
6563   for (int i = 0; i < segment_def->node_size(); ++i) {
6564     auto snode = segment_def->mutable_node(i);
6565     const int input_size = snode->input_size();
6566     int input_idx = 0;
6567     int actual_input_idx = 0;
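    // Two-pointer compaction: input_idx scans every input, actual_input_idx is
    // the write position for inputs that are kept; any trailing slots left over
    // are removed after the scan.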
6568     while (input_idx < input_size) {
6569       TensorId input = ParseTensorName(snode->input(input_idx));
6570       if (!subgraph_node_names.count(
6571               string(input.first.data(), input.first.size())) &&
6572           !IsEngineInput(input.first)) {
6573         if (input.second == Graph::kControlSlot) {
6574           VLOG(1) << "... removing control inputs " << input.first
6575                   << " from subgraph.";
6576           ++input_idx;
6577           continue;
6578         } else {
6579           return errors::InvalidArgument(
6580               "Found non control input outside the segment that is not an "
6581               "engine connection to ",
6582               snode->name(), ": ", input.first);
6583         }
6584       }
6585       if (actual_input_idx != input_idx) {
6586         snode->set_input(actual_input_idx, snode->input(input_idx));
6587       }
6588       ++input_idx;
6589       ++actual_input_idx;
6590     }
6591     for (int remove = input_size - actual_input_idx; remove > 0; --remove) {
6592       snode->mutable_input()->RemoveLast();
6593     }
6594   }
6595   *scope_name = local_scope;
6596   return Status::OK();
6597 }
6598 
6599 bool OutputEdgeValidator::operator()(const Edge* out_edge) const {
6600   if (out_edge->IsControlEdge()) return true;
6601   if (out_edge->src()->type_string() == "Const") {
6602     VLOG(1) << "--> Need to remove output node " << out_edge->src()->name()
6603             << " which is a Const.";
6604     return false;
6605   }
6606   return true;
6607 }
6608 
6609 }  // namespace convert
6610 }  // namespace tensorrt
6611 }  // namespace tensorflow
6612 
6613 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
6614