1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
17 
18 #include <memory>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "absl/strings/match.h"
25 #include "absl/strings/numbers.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/string_view.h"
28 #include "absl/types/span.h"
29 #include "tensorflow/cc/framework/ops.h"
30 #include "tensorflow/cc/framework/scope.h"
31 #include "tensorflow/cc/ops/nn_ops_internal.h"
32 #include "tensorflow/cc/ops/standard_ops.h"
33 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
34 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
35 #include "tensorflow/core/framework/node_def.pb.h"  // NOLINT
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
38 #include "tensorflow/core/framework/tensor_shape.h"
39 #include "tensorflow/core/framework/tensor_testutil.h"
40 #include "tensorflow/core/framework/types.h"
41 #include "tensorflow/core/grappler/costs/graph_properties.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/lib/core/status_test_util.h"
44 #include "tensorflow/core/lib/strings/str_util.h"
45 #include "tensorflow/core/lib/strings/strcat.h"
46 #include "tensorflow/core/platform/protobuf.h"
47 #include "tensorflow/core/platform/test.h"
48 #include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
49 #include "tensorflow/core/public/session.h"
50 
51 #if GOOGLE_CUDA
52 #if GOOGLE_TENSORRT
53 #include "cuda/include/cuda.h"
54 #include "cuda/include/cuda_runtime_api.h"
55 #include "tensorrt/include/NvInfer.h"
56 
57 namespace tensorflow {
58 namespace tensorrt {
59 namespace convert {
60 
61 using absl::StrCat;
62 using ::testing::ElementsAre;
63 using ::testing::ElementsAreArray;
64 using ::testing::NanSensitiveFloatNear;
65 
66 // TODO(laigd): put this into some test utils file.
ExpectStatus(Status status,error::Code code=error::OK,const char * substr=nullptr)67 void ExpectStatus(Status status, error::Code code = error::OK,
68                   const char* substr = nullptr) {
69   EXPECT_EQ(code, status.code())
70       << status << " vs expected error code \"" << error::Code_Name(code)
71       << "\" and message \"" << substr << "\"";
72   if (substr) {
73     EXPECT_THAT(status.error_message(), ::testing::HasSubstr(substr)) << status;
74   }
75 }
76 
GetTestDims(const std::vector<int> & d)77 nvinfer1::Dims GetTestDims(const std::vector<int>& d) {
78   nvinfer1::Dims dims;
79   dims.nbDims = d.size();
80   for (int i = 0; i < d.size(); ++i) {
81     dims.d[i] = d[i];
82   }
83   return dims;
84 }
85 
TfDataTypeToTrt(DataType tf_dtype)86 nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) {
87   switch (tf_dtype) {
88     case DT_FLOAT:
89       return nvinfer1::DataType::kFLOAT;
90     case DT_HALF:
91       return nvinfer1::DataType::kHALF;
92     case DT_INT32:
93       return nvinfer1::DataType::kINT32;
94     default:
95       QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype);
96   }
97 }
98 
TrtDataTypeToTf(nvinfer1::DataType trt_dtype)99 DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) {
100   switch (trt_dtype) {
101     case nvinfer1::DataType::kFLOAT:
102       return DT_FLOAT;
103     case nvinfer1::DataType::kHALF:
104       return DT_HALF;
105     case nvinfer1::DataType::kINT32:
106       return DT_INT32;
107     default:
108       QCHECK(false) << "Unexpected data type " << static_cast<int>(trt_dtype);
109   }
110 }
111 
MakeNodeDef(const string & name,const string & op,const std::vector<string> & inputs,const std::map<string,AttrValue> attrs={})112 NodeDef MakeNodeDef(const string& name, const string& op,
113                     const std::vector<string>& inputs,
114                     const std::map<string, AttrValue> attrs = {}) {
115   NodeDef node_def;
116   node_def.set_name(name);
117   node_def.set_op(op);
118   for (const string& input : inputs) {
119     node_def.add_input(input);
120   }
121   for (const auto& attr : attrs) {
122     (*node_def.mutable_attr())[attr.first] = attr.second;
123   }
124   return node_def;
125 }
126 
127 template <typename T>
MakeConstNodeDef(const string & name,const std::vector<T> & vals,const TensorShape & shape)128 NodeDef MakeConstNodeDef(const string& name, const std::vector<T>& vals,
129                          const TensorShape& shape) {
130   Scope s = Scope::NewRootScope();
131   Tensor t = test::AsTensor<T>(vals, shape);
132   auto const_op = ops::Const(s.WithOpName(name), t);
133   return const_op.node()->def();
134 }
135 
136 template <typename T>
MakeConstNodeDef(const string & name,const std::vector<T> & vals)137 NodeDef MakeConstNodeDef(const string& name, const std::vector<T>& vals) {
138   TensorShape shape;
139   const std::vector<int32> shape_dims = {static_cast<int32>(vals.size())};
140   TF_EXPECT_OK(TensorShapeUtils::MakeShape(shape_dims, &shape));
141   return MakeConstNodeDef(name, vals, shape);
142 }
143 
TrtDimsEquals(const nvinfer1::Dims & lhs,const nvinfer1::Dims & rhs)144 bool TrtDimsEquals(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs) {
145   if (lhs.nbDims != rhs.nbDims) return false;
146   for (int i = 0; i < lhs.nbDims; ++i) {
147     if (lhs.d[i] != rhs.d[i]) return false;
148     // We don't check the types in the tests.
149   }
150   return true;
151 }
152 
TrtDimsEqualsArray(const std::vector<int> & lhs,const nvinfer1::Dims & rhs)153 bool TrtDimsEqualsArray(const std::vector<int>& lhs,
154                         const nvinfer1::Dims& rhs) {
155   return TrtDimsEquals(GetTestDims(lhs), rhs);
156 }
157 
158 // TODO(laigd): define a parameterized matcher that can compare against the
159 // vector.
ExpectTrtDimsEqualsArray(const std::vector<int> & lhs,const nvinfer1::Dims & rhs)160 void ExpectTrtDimsEqualsArray(const std::vector<int>& lhs,
161                               const nvinfer1::Dims& rhs) {
162   EXPECT_TRUE(TrtDimsEqualsArray(lhs, rhs))
163       << "expected: " << DebugString(GetTestDims(lhs)) << "\n"
164       << "  actual: " << DebugString(rhs);
165 }
166 
167 template <typename T>
ExpectArrayNear(const std::vector<T> & lhs,absl::Span<const T> rhs)168 void ExpectArrayNear(const std::vector<T>& lhs, absl::Span<const T> rhs) {
169   ASSERT_EQ(lhs.size(), rhs.size());
170   for (int i = 0; i < lhs.size(); i++) {
171     EXPECT_FLOAT_EQ(lhs[i], rhs[i]);
172   }
173 }
174 
175 // Eigen::half cannot implicitly convert to float which is required for
176 // EXPECT_FLOAT_EQ.
177 template <>
ExpectArrayNear(const std::vector<Eigen::half> & lhs,absl::Span<const Eigen::half> rhs)178 void ExpectArrayNear(const std::vector<Eigen::half>& lhs,
179                      absl::Span<const Eigen::half> rhs) {
180   ASSERT_EQ(lhs.size(), rhs.size());
181   for (int i = 0; i < lhs.size(); i++) {
182     EXPECT_FLOAT_EQ(Eigen::half_impl::half_to_float(lhs[i]),
183                     Eigen::half_impl::half_to_float(rhs[i]));
184   }
185 }
186 
TrtShapedWeightsEquals(const TRT_ShapedWeights & lhs,const TRT_ShapedWeights & rhs)187 bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs,
188                             const TRT_ShapedWeights& rhs) {
189   return TrtDimsEquals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ &&
190          lhs.GetValues() == rhs.GetValues();
191 }
192 
193 template <typename T>
ValidateWeights(const TRT_ShapedWeights & weights,const std::vector<int> & expected_dims,const std::vector<T> & expected_value)194 void ValidateWeights(const TRT_ShapedWeights& weights,
195                      const std::vector<int>& expected_dims,
196                      const std::vector<T>& expected_value) {
197   ExpectTrtDimsEqualsArray(expected_dims, weights.shape_);
198   ASSERT_EQ(expected_value.size(), weights.count()) << weights.DebugString();
199   const T* actual_values = static_cast<const T*>(weights.GetValues());
200   for (int i = 0; i < expected_value.size(); ++i) {
201     EXPECT_EQ(expected_value[i], actual_values[i]);
202   }
203 }
204 
205 // Fake ITensor implementation for testing purposes.
206 class FakeITensor : public nvinfer1::ITensor {
207  public:
FakeITensor()208   FakeITensor() : dynamic_range_(0.0f) {}
209 
FakeITensor(const nvinfer1::Dims & dims)210   FakeITensor(const nvinfer1::Dims& dims) : dims_(dims), dynamic_range_(0.0f) {}
211 
FakeITensor(const std::vector<int> & dims)212   FakeITensor(const std::vector<int>& dims)
213       : dims_(GetTestDims(dims)), dynamic_range_(0.0f) {}
214 
setName(const char * name)215   void setName(const char* name) override { name_ = name; }
216 
getName() const217   const char* getName() const override { return name_.c_str(); }
218 
setDimensions(nvinfer1::Dims dimensions)219   void setDimensions(nvinfer1::Dims dimensions) override { dims_ = dimensions; }
220 
getDimensions() const221   nvinfer1::Dims getDimensions() const override { return dims_; }
222 
setType(nvinfer1::DataType type)223   void setType(nvinfer1::DataType type) override { type_ = type; }
224 
getType() const225   nvinfer1::DataType getType() const override { return type_; }
226 
isNetworkInput() const227   bool isNetworkInput() const override { return false; }
228 
isNetworkOutput() const229   bool isNetworkOutput() const override { return false; }
230 
setBroadcastAcrossBatch(bool broadcastAcrossBatch)231   void setBroadcastAcrossBatch(bool broadcastAcrossBatch) override {}
232 
getBroadcastAcrossBatch() const233   bool getBroadcastAcrossBatch() const override { return false; }
234 
getLocation() const235   nvinfer1::TensorLocation getLocation() const override { return location_; }
236 
setLocation(nvinfer1::TensorLocation location)237   void setLocation(nvinfer1::TensorLocation location) override {
238     location_ = location;
239   }
240 
241 #if IS_TRT_VERSION_GE(5, 0, 0)
setDynamicRange(float min,float max)242   bool setDynamicRange(float min, float max) override {
243     dynamic_range_ = std::max(std::abs(min), std::abs(max));
244     return true;
245   }
246 
getDynamicRange() const247   float getDynamicRange() const override { return dynamic_range_; }
248 #endif
249 
250 #if IS_TRT_VERSION_GE(5, 1, 0)
dynamicRangeIsSet() const251   bool dynamicRangeIsSet() const override { return true; }
252 
resetDynamicRange()253   void resetDynamicRange() override {}
254 
getDynamicRangeMin() const255   float getDynamicRangeMin() const override { return 0.f; }
256 
getDynamicRangeMax() const257   float getDynamicRangeMax() const override { return 0.f; }
258 #endif
259 
260  private:
261   string name_;
262   nvinfer1::Dims dims_;
263   nvinfer1::DataType type_;
264   nvinfer1::TensorLocation location_;
265   float dynamic_range_;
266 };
267 
// Exercises construction and copy construction of TRT_ShapedWeights: weights
// built without dimensions are expected to be empty float weights, while
// weights obtained from a TrtWeightStore expose a buffer that a copy shares.
TEST(TRT_ShapedWeights_Test, Basic) {
  // Expectations shared by the cases that must yield empty float weights.
  const auto expect_empty_float_weights = [](TRT_ShapedWeights* shaped) {
    nvinfer1::Weights raw = shaped->GetTrtWeights();
    EXPECT_EQ(nvinfer1::DataType::kFLOAT, raw.type);
    EXPECT_EQ(nullptr, raw.values);
    EXPECT_EQ(0, raw.count);

    EXPECT_EQ(nullptr, shaped->GetValues());
    EXPECT_EQ(0, shaped->count());
    EXPECT_EQ(0, shaped->size_bytes());
  };

  // Constructor with no arguments.
  {
    TRT_ShapedWeights weights;
    TRT_ShapedWeights copy(weights);
    expect_empty_float_weights(&weights);
    expect_empty_float_weights(&copy);
  }
  // Constructor with DataType argument.
  {
    TRT_ShapedWeights weights(DT_FLOAT);
    TRT_ShapedWeights copy(weights);
    expect_empty_float_weights(&weights);
    expect_empty_float_weights(&copy);
  }
  // Weights with a DataType and nvinfer1::Dims, allocated by a weight store:
  // 2 x 5 floats, i.e. 10 elements in 40 bytes.
  {
    TrtWeightStore store;
    TRT_ShapedWeights weights =
        store.GetTempWeights(DT_FLOAT, GetTestDims({2, 5}));
    TRT_ShapedWeights copy(weights);
    for (TRT_ShapedWeights* shaped : {&weights, &copy}) {
      nvinfer1::Weights raw = shaped->GetTrtWeights();
      EXPECT_EQ(nvinfer1::DataType::kFLOAT, raw.type);
      EXPECT_NE(nullptr, raw.values);
      EXPECT_EQ(10, raw.count);

      EXPECT_EQ(raw.values, shaped->GetValues());
      EXPECT_EQ(10, shaped->count());
      EXPECT_EQ(40, shaped->size_bytes());
    }
    // Copying must not duplicate the underlying buffer.
    EXPECT_EQ(weights.GetValues(), copy.GetValues());
  }
}
319 
TEST(TRT_TensorOrWeights_Test,Basic)320 TEST(TRT_TensorOrWeights_Test, Basic) {
321   // Test constructor with no arguments.
322   {
323     TRT_TensorOrWeights tw;
324     TRT_TensorOrWeights copy(tw);
325     TRT_TensorOrWeights assigned;
326     assigned = tw;
327     for (auto ptr : {&tw, &copy, &assigned}) {
328       EXPECT_EQ(false, ptr->is_tensor());
329       EXPECT_EQ(false, ptr->is_weights());
330       EXPECT_EQ(-1, ptr->batch_size());
331     }
332   }
333 
334   // Test constructor with ITensor and batch size argument.
335   {
336     nvinfer1::Dims dims;
337     dims.nbDims = 1;
338     dims.d[0] = 1;
339     FakeITensor itensor(dims);
340     TRT_TensorOrWeights tw(&itensor);
341     TRT_TensorOrWeights tw1(&itensor, /*batch_size=*/1);
342 
343     for (auto original_ptr : {&tw, &tw1}) {
344       TRT_TensorOrWeights copy(*original_ptr);
345       TRT_TensorOrWeights assigned;
346       assigned = *original_ptr;
347 
348       for (auto ptr : {original_ptr, &copy, &assigned}) {
349         EXPECT_EQ(true, ptr->is_tensor());
350         EXPECT_EQ(false, ptr->is_weights());
351         if (original_ptr == &tw) {
352           EXPECT_EQ(-1, ptr->batch_size());
353         } else {
354           EXPECT_EQ(1, ptr->batch_size());
355         }
356         EXPECT_EQ(&itensor, ptr->tensor());
357         ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims());
358       }
359     }
360   }
361   // Test constructor which creates and owns an ITensor.
362   {
363     nvinfer1::Dims dims;
364     dims.nbDims = 1;
365     dims.d[0] = 1;
366     TRT_TensorOrWeights tw(nvinfer1::DataType::kFLOAT, dims, /*batch_size=*/1);
367     TRT_TensorOrWeights copy(tw);
368     TRT_TensorOrWeights assigned;
369     assigned = tw;
370 
371     for (auto ptr : {&tw, &copy, &assigned}) {
372       EXPECT_EQ(true, ptr->is_tensor());
373       EXPECT_EQ(false, ptr->is_weights());
374       EXPECT_EQ(1, ptr->batch_size());
375       EXPECT_NE(nullptr, ptr->tensor());
376       ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims());
377     }
378   }
379   // Test constructor with TRT_ShapedWeights argument.
380   {
381     TRT_ShapedWeights weights;
382     TRT_TensorOrWeights tw(weights);
383     TRT_TensorOrWeights copy(tw);
384     TRT_TensorOrWeights assigned;
385     assigned = tw;
386     for (auto ptr : {&tw, &copy, &assigned}) {
387       EXPECT_EQ(false, ptr->is_tensor());
388       EXPECT_EQ(true, ptr->is_weights());
389       EXPECT_TRUE(TrtShapedWeightsEquals(weights, ptr->weights()));
390       ExpectTrtDimsEqualsArray({}, ptr->GetTrtDims());
391     }
392   }
393 }
394 
395 class ValidatorTest : public ::testing::Test {
396  public:
op_validators()397   std::unordered_map<string, OpConverter>& op_validators() {
398     return validator_.op_validators_;
399   }
400 
ConvertToTensorOrWeights(const NodeDef & node_def,int output_port,const grappler::GraphProperties & graph_properties,TRT_TensorOrWeights * tensor_or_weights)401   Status ConvertToTensorOrWeights(
402       const NodeDef& node_def, int output_port,
403       const grappler::GraphProperties& graph_properties,
404       TRT_TensorOrWeights* tensor_or_weights) {
405     return validator_.ConvertToTensorOrWeights(
406         node_def, output_port, graph_properties, tensor_or_weights);
407   }
408 
GetQuantizeOps()409   const std::set<string>* GetQuantizeOps() { return validator_.quantize_ops; }
410 
411  protected:
412   TrtNodeValidator validator_;
413 };
414 
// Every op listed in the validator's quantize-op set must have a registered
// validator. This is reported through gtest expectations, so a missing
// registration fails this test and names the op instead of aborting the whole
// test binary.
TEST_F(ValidatorTest, QuantizeOpsAreRegistered) {
  const std::set<string>* quantize_ops = GetQuantizeOps();
  // Guard the dereference below.
  ASSERT_NE(nullptr, quantize_ops);
  for (const string& quantize_op : *quantize_ops) {
    EXPECT_TRUE(op_validators().count(quantize_op) > 0)
        << "No validator registered for quantize op " << quantize_op;
  }
}
420 
TEST_F(ValidatorTest,ConvertToTensorOrWeights)421 TEST_F(ValidatorTest, ConvertToTensorOrWeights) {
422   // Convert Const.
423   {
424     NodeDef node_def = MakeConstNodeDef<float>("my_const", {1.0f, 2.0f});
425     TRT_TensorOrWeights output;
426     grappler::GrapplerItem item;
427     grappler::GraphProperties graph_properties(item);
428     ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0,
429                                           graph_properties, &output));
430     ValidateWeights<float>(output.weights(), {2}, {1.0, 2.0});
431   }
432 
433   // Helper method to run ConvertToTensorOrWeights() with predefined parameters.
434   auto convert_to_tensor_or_weights = [this](const std::vector<int64>& dims,
435                                              TRT_TensorOrWeights* output) {
436     Scope s = Scope::NewRootScope();
437     const auto attrs = ops::Placeholder::Shape(PartialTensorShape{dims});
438     auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, attrs);
439     auto add = ops::Add(s.WithOpName("add"), feed, feed);
440 
441     grappler::GrapplerItem item;
442     TF_EXPECT_OK(s.ToGraphDef(&item.graph));
443     grappler::GraphProperties graph_properties(item);
444     TF_EXPECT_OK(graph_properties.InferStatically(true));
445     const NodeDef& node_def = add.operation.node()->def();
446     return this->ConvertToTensorOrWeights(node_def, /*output_port=*/0,
447                                           graph_properties, output);
448   };
449   // Convert non-Const with #dims > nvinfer1::Dims::MAX_DIMS+1.
450   {
451     TRT_TensorOrWeights output;
452     ExpectStatus(
453         convert_to_tensor_or_weights(
454             std::vector<int64>(nvinfer1::Dims::MAX_DIMS + 2, 1), &output),
455         error::OUT_OF_RANGE, "Input tensor rank is greater than 9");
456   }
457   // Convert non-Const with #dims < 2.
458   {
459     TRT_TensorOrWeights output;
460     ExpectStatus(
461         convert_to_tensor_or_weights({1}, &output), error::INVALID_ARGUMENT,
462         "Input tensor with rank<2 is not supported since the first dimension "
463         "is treated as batch dimension by TRT");
464   }
465   // Convert non-Const. We test the case where the non-batch dimemsion is
466   // unknown as well, to make sure the validator allows that.
467   for (const int32 non_batch_dim : {-1, 2}) {
468     const int32 batch_size = 12;
469     TRT_TensorOrWeights output;
470     ExpectStatus(
471         convert_to_tensor_or_weights({batch_size, non_batch_dim}, &output));
472     EXPECT_EQ(true, output.is_tensor());
473     EXPECT_EQ(batch_size, output.batch_size());
474     EXPECT_NE(nullptr, output.tensor());
475     ExpectTrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims());
476   }
477 }
478 
// Covers ValidateNode() for an unregistered op, a registered validator that
// succeeds, a registered validator that fails, and a quantization op in FP32
// mode.
TEST_F(ValidatorTest, ValidateNode) {
  grappler::GrapplerItem item;
  grappler::GraphProperties graph_properties(item);

  // The fake validator records whether it was ever invoked outside of
  // validation-only mode, and can be switched to return an error.
  bool conversion_started = false;
  bool force_failure = false;
  auto fake_validator = [&conversion_started,
                         &force_failure](OpConverterParams* params) -> Status {
    if (force_failure) return errors::InvalidArgument("");
    if (!params->validation_only) conversion_started = true;
    return Status::OK();
  };
  NodeDef node_def = MakeNodeDef("my_op", "MyOp", {});

  // Validator not registered.
  const Status unregistered_status = validator_.ValidateNode(
      node_def, {}, TrtPrecisionMode::FP32, graph_properties);
  ExpectStatus(unregistered_status, error::UNIMPLEMENTED,
               "Op type MyOp is not supported.");

  // Register validator. Validation must succeed without starting a real
  // conversion.
  op_validators()["MyOp"] = fake_validator;
  TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, TrtPrecisionMode::FP32,
                                       graph_properties));
  EXPECT_EQ(false, conversion_started);

  // Let the converter return error.
  force_failure = true;
  const Status failing_status = validator_.ValidateNode(
      node_def, {}, TrtPrecisionMode::FP32, graph_properties);
  ExpectStatus(failing_status, error::INVALID_ARGUMENT);

  // Test quantization ops, they're only supported in INT8 mode. The success
  // case is tested in OpConverterTest.ConvertQuantize.
  node_def = MakeNodeDef("my_op", "FakeQuantWithMinMaxArgs", {});
  const Status quantize_status = validator_.ValidateNode(
      node_def, {}, TrtPrecisionMode::FP32, graph_properties);
  ExpectStatus(quantize_status, error::UNIMPLEMENTED,
               "Op type FakeQuantWithMinMaxArgs is not supported.");
}
518 
519 class ConverterTest : public ::testing::Test {
520  public:
ConverterTest()521   ConverterTest() {
522     builder_.reset(nvinfer1::createInferBuilder(logger_));
523     network_.reset(builder_->createNetwork());
524     converter_.reset(new Converter(network_.get(), TrtPrecisionMode::FP32,
525                                    /*use_calibration=*/false));
526     weight_store_ = &converter_->weight_store_;
527   }
528 
AddOpConverter(const string & op_name,OpConverter op_converter)529   void AddOpConverter(const string& op_name, OpConverter op_converter) {
530     converter_->op_registry_[op_name] = op_converter;
531   }
532 
533   // Below we expose private methods of Converter for testing.
534 
MaybeUpdateBatchSize(int batch_size)535   Status MaybeUpdateBatchSize(int batch_size) {
536     return converter_->MaybeUpdateBatchSize(batch_size);
537   }
538 
AddTensorOrWeights(const string & name,TRT_TensorOrWeights input)539   Status AddTensorOrWeights(const string& name, TRT_TensorOrWeights input) {
540     return converter_->AddTensorOrWeights(name, input);
541   }
542 
GetTensorOrWeights(const string & name,TRT_TensorOrWeights * output)543   Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output) {
544     return converter_->GetTensorOrWeights(name, output);
545   }
546 
GetInputs(const NodeDef & node_def,std::vector<TRT_TensorOrWeights> * inputs) const547   Status GetInputs(const NodeDef& node_def,
548                    std::vector<TRT_TensorOrWeights>* inputs) const {
549     return converter_->GetInputs(node_def, inputs);
550   }
551 
GetWeightRange(const TRT_ShapedWeights & weights,float * out_min,float * out_max) const552   Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min,
553                         float* out_max) const {
554     return converter_->GetWeightRange(weights, out_min, out_max);
555   }
556 
PropagateQuantizationRanges()557   void PropagateQuantizationRanges() {
558     converter_->PropagateQuantizationRanges();
559   }
560 
batch_size() const561   int batch_size() const { return converter_->batch_size_; }
562 
quantization_ranges()563   std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
564     return converter_->quantization_ranges_;
565   }
566 
567  private:
568   Logger logger_;
569   // These members are ordered in a way such that the destruction order is:
570   // converter_ -> network_ -> builder_
571   TrtUniquePtrType<nvinfer1::IBuilder> builder_;
572   TrtUniquePtrType<nvinfer1::INetworkDefinition> network_;
573 
574  protected:
575   std::unique_ptr<Converter> converter_;
576   TrtWeightStore* weight_store_;
577 };
578 
// ConvertNode() must fail for an op without a registered converter and, once
// a converter is registered, store each of its outputs under "<node>" for
// output 0 and "<node>:<i>" for output i.
TEST_F(ConverterTest, ConvertNode) {
  FakeITensor output_tensors[2];
  // Fake converter producing two outputs: the first dimension of the input is
  // incremented once more for each successive output.
  auto op_converter = [&output_tensors](OpConverterParams* params) -> Status {
    nvinfer1::Dims dims = params->inputs[0].tensor()->getDimensions();
    for (FakeITensor& fake_output : output_tensors) {
      dims.d[0] += 1;
      fake_output.setDimensions(dims);
      params->outputs->push_back(TRT_TensorOrWeights(&fake_output));
    }
    return Status::OK();
  };
  NodeDef node_def = MakeNodeDef("my_op", "MyOp", {"my_input"});
  TF_EXPECT_OK(converter_->AddInputTensor(
      "my_input", nvinfer1::DataType::kFLOAT, GetTestDims({123}), 1));

  // Converter not registered.
  const Status unregistered_status = converter_->ConvertNode(node_def);
  ExpectStatus(unregistered_status, error::UNIMPLEMENTED,
               "No converter registered for op: MyOp");

  // Register the converter and retry.
  AddOpConverter("MyOp", op_converter);
  TF_EXPECT_OK(converter_->ConvertNode(node_def));

  // Output 0: input dimension 123 incremented once.
  TRT_TensorOrWeights first_output;
  TF_EXPECT_OK(GetTensorOrWeights("my_op", &first_output));
  EXPECT_EQ(&output_tensors[0], first_output.tensor());
  EXPECT_EQ(124, first_output.tensor()->getDimensions().d[0]);

  // Output 1: input dimension 123 incremented twice.
  TRT_TensorOrWeights second_output;
  TF_EXPECT_OK(GetTensorOrWeights("my_op:1", &second_output));
  EXPECT_EQ(&output_tensors[1], second_output.tensor());
  EXPECT_EQ(125, second_output.tensor()->getDimensions().d[0]);
}
612 
// Adds three input tensors to the converter and checks that GetInputs()
// resolves the inputs listed on a NodeDef to them.
TEST_F(ConverterTest, AddAndGetInputs) {
  NodeDef node_def;
  for (const char* input_name : {"^control_input", "input", "input:0",
                                 "input:1", "weird_input:2:3:4:0"}) {
    node_def.add_input(input_name);
  }

  TF_EXPECT_OK(converter_->AddInputTensor("input", nvinfer1::DataType::kFLOAT,
                                          GetTestDims({1}), 1));
  TF_EXPECT_OK(converter_->AddInputTensor("input:1", nvinfer1::DataType::kINT32,
                                          GetTestDims({2, 3}), 1));
  TF_EXPECT_OK(converter_->AddInputTensor(
      "weird_input:2:3:4", nvinfer1::DataType::kHALF, GetTestDims({5, 3}), 1));

  std::vector<TRT_TensorOrWeights> inputs;
  TF_EXPECT_OK(GetInputs(node_def, &inputs));

  // Four entries are expected for the five inputs listed above, and "input"
  // and "input:0" are expected to resolve to the same tensor.
  EXPECT_EQ(4, inputs.size());
  EXPECT_EQ(inputs[0].tensor(), inputs[1].tensor());

  // Types of the resolved tensors.
  EXPECT_EQ(nvinfer1::DataType::kFLOAT, inputs[0].tensor()->getType());
  EXPECT_EQ(nvinfer1::DataType::kINT32, inputs[2].tensor()->getType());
  EXPECT_EQ(nvinfer1::DataType::kHALF, inputs[3].tensor()->getType());

  // Dimensions of the resolved tensors.
  ExpectTrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions());
  ExpectTrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions());
  ExpectTrtDimsEqualsArray({5, 3}, inputs[3].tensor()->getDimensions());
}
641 
// Tests that tensors are actually renamed after
// Converter::RenameAndMarkOutputTensors() is called, and that asking to mark
// a weights output is rejected.
TEST_F(ConverterTest, RenameAndMarkOutputTensors) {
  // Register a custom converter which shuffles the input. We use it to build a
  // TRT network whose outputs will be marked later. It emits two shuffled
  // tensors followed by one weights output.
  std::vector<nvinfer1::ITensor*> shuffled_outputs;
  auto op_converter = [&shuffled_outputs](OpConverterParams* params) -> Status {
    nvinfer1::Permutation swap_first_two;
    swap_first_two.order[0] = 1;
    swap_first_two.order[1] = 0;
    nvinfer1::ITensor* input_tensor =
        const_cast<nvinfer1::ITensor*>(params->inputs[0].tensor());
    for (int output_idx = 0; output_idx < 2; ++output_idx) {
      nvinfer1::IShuffleLayer* shuffle =
          params->converter->network()->addShuffle(*input_tensor);
      shuffle->setFirstTranspose(swap_first_two);
      nvinfer1::ITensor* shuffled = shuffle->getOutput(0);
      params->outputs->emplace_back(shuffled);
      shuffled_outputs.push_back(shuffled);
    }
    TRT_ShapedWeights output_weights(DT_FLOAT);
    params->outputs->emplace_back(output_weights);
    return Status::OK();
  };
  AddOpConverter("MyOp", op_converter);

  // Run the conversion.
  NodeDef node_def = MakeNodeDef("my_op", "MyOp", {"my_input"});
  TF_EXPECT_OK(converter_->AddInputTensor(
      "my_input", nvinfer1::DataType::kFLOAT, GetTestDims({1, 2}), 1));
  TF_EXPECT_OK(converter_->ConvertNode(node_def));

  // Mark a weight as output, should fail.
  const Status weights_status =
      converter_->RenameAndMarkOutputTensors({{"my_op:2", "my_output"}});
  ExpectStatus(weights_status, error::INVALID_ARGUMENT,
               "Output my_op:2 is weights not tensor");

  // Mark tensors as output, should pass.
  TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(
      {{"my_op", "my_output"}, {"my_op:1", "my_output_1"}}));
  EXPECT_EQ(2, shuffled_outputs.size());
  // The {1, 2} input comes out transposed from each shuffle layer.
  for (nvinfer1::ITensor* shuffled : shuffled_outputs) {
    ExpectTrtDimsEqualsArray({2, 1}, shuffled->getDimensions());
  }
  EXPECT_EQ("my_output", string(shuffled_outputs[0]->getName()));
  EXPECT_EQ("my_output_1", string(shuffled_outputs[1]->getName()));
}
690 
// Covers Converter::TransposeTensor() for a permutation of the wrong rank, a
// permutation that moves the batch dimension, and a valid permutation.
TEST_F(ConverterTest, TransposeTensor) {
  nvinfer1::ITensor* input_tensor = converter_->network()->addInput(
      "", nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
  const nvinfer1::ITensor* output_tensor = nullptr;

  // Rank doesn't match.
  const Status rank_mismatch_status =
      converter_->TransposeTensor(input_tensor, {0, 1}, &output_tensor);
  ExpectStatus(
      rank_mismatch_status, error::INVALID_ARGUMENT,
      "Rank of perm for transpose does not match with that of the input");

  // Transpose at batch dimension.
  const Status batch_transpose_status =
      converter_->TransposeTensor(input_tensor, {1, 0, 2, 3}, &output_tensor);
  ExpectStatus(batch_transpose_status, error::UNIMPLEMENTED,
               "Transpose at batch dimension is not supported.");

  // OK: the permutation includes the batch dimension at index 0 and keeps it
  // in place.
  TF_EXPECT_OK(
      converter_->TransposeTensor(input_tensor, {0, 3, 1, 2}, &output_tensor));
  ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions());
}
712 
// Covers Converter::PrepareTensorForShape() with a tensor input, both in
// normal and in validation-only mode.
TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
  nvinfer1::ITensor* input_tensor = converter_->network()->addInput(
      "", nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
  TRT_TensorOrWeights tw(input_tensor);
  const nvinfer1::ITensor* output_tensor = nullptr;

  // After a successful call: validation-only mode must leave the output null,
  // otherwise the output must have the requested dimensions.
  const auto expect_output = [&output_tensor](
                                 bool validation_only,
                                 const std::vector<int>& expected_dims) {
    if (validation_only) {
      EXPECT_EQ(nullptr, output_tensor);
    } else {
      ExpectTrtDimsEqualsArray(expected_dims, output_tensor->getDimensions());
    }
  };

  for (bool validation_only : {false, true}) {
    // Shape size doesn't match.
    const Status incompatible_status = converter_->PrepareTensorForShape(
        tw, GetTestDims({2, 3, 6}), validation_only, &output_tensor);
    ExpectStatus(incompatible_status, error::INVALID_ARGUMENT,
                 "Reshape shapes are not compatible");

    // TODO(aaroey): we should check the case where uninferred dimensions are
    // not an exact divisor of input dimensions, e.g. for dims {-1, 7}.

    // Infer shape, ok.
    TF_EXPECT_OK(converter_->PrepareTensorForShape(
        tw, GetTestDims({-1, 2}), validation_only, &output_tensor));
    expect_output(validation_only, {15, 2});

    // Regular shape.
    TF_EXPECT_OK(converter_->PrepareTensorForShape(
        tw, GetTestDims({10, 3}), validation_only, &output_tensor));
    expect_output(validation_only, {10, 3});
  }
}
748 
// Verifies Converter::PrepareTensorForShape() when the operand is a set of
// weights: reshaping {2, 3, 5} weights to {10, 3} succeeds in conversion mode
// (producing a tensor with the target dimensions) and in validation-only mode
// (where the output pointer is expected to be null).
TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
  TRT_ShapedWeights weights =
      weight_store_->GetTempWeights(DT_FLOAT, GetTestDims({2, 3, 5}));
  TRT_TensorOrWeights tw(weights);
  const nvinfer1::ITensor* output_tensor = nullptr;
  for (bool validation_only : {false, true}) {
    TF_EXPECT_OK(converter_->PrepareTensorForShape(
        tw, GetTestDims({10, 3}), validation_only, &output_tensor));
    if (validation_only) {
      EXPECT_EQ(nullptr, output_tensor);
    } else {
      // TF_EXPECT_OK is non-fatal: abort the test rather than dereference a
      // null pointer if the conversion did not produce a tensor.
      ASSERT_NE(nullptr, output_tensor);
      ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions());
    }
  }
}
764 
// Walks MaybeUpdateBatchSize() through a sequence of updates and checks the
// batch size reported by the fixture after each one, then checks that a
// conflicting value is rejected.
TEST_F(ConverterTest, MaybeUpdateBatchSize) {
  // Nothing has been recorded yet.
  EXPECT_EQ(-1, batch_size());

  // Each step applies `update` and states the batch size expected afterwards.
  struct Step {
    int update;
    int expected_batch_size;
  };
  const Step kSteps[] = {
      {-1, -1},    // -1 while still unset: stays unset.
      {123, 123},  // First concrete value is recorded.
      {123, 123},  // Repeating the same value is accepted.
      {-1, 123},   // -1 leaves the recorded value untouched.
  };
  for (const Step& step : kSteps) {
    TF_EXPECT_OK(MaybeUpdateBatchSize(step.update));
    EXPECT_EQ(step.expected_batch_size, batch_size());
  }

  // A different concrete value must be reported as an error.
  ExpectStatus(MaybeUpdateBatchSize(124), error::INVALID_ARGUMENT,
               "Provided batch size does not match converter batch size");
}
783 
// Registers a tensor under a name, reads it back, and checks that registering
// the same name a second time is rejected.
TEST_F(ConverterTest, AddAndGetTensorOrWeights) {
  // A freshly wrapped tensor reports a batch size of -1.
  FakeITensor fake_tensor;
  TRT_TensorOrWeights wrapped(&fake_tensor);
  EXPECT_EQ(-1, wrapped.batch_size());

  // Record a batch size on the converter side, then register the tensor.
  TF_EXPECT_OK(MaybeUpdateBatchSize(123));
  TF_EXPECT_OK(AddTensorOrWeights("my_tensor", wrapped));

  // The entry read back is expected to report the batch size set above.
  TRT_TensorOrWeights retrieved;
  TF_EXPECT_OK(GetTensorOrWeights("my_tensor", &retrieved));
  EXPECT_EQ(123, retrieved.batch_size());

  // Registering the same name twice is an error.
  ExpectStatus(AddTensorOrWeights("my_tensor", wrapped), error::ALREADY_EXISTS,
               "tensor/weights my_tensor already exist");
}
801 
802 template <typename T>
TestGetWeightRange(ConverterTest * test,TrtWeightStore * weight_store)803 void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) {
804   TRT_ShapedWeights weights =
805       weight_store->GetTempWeights(DataTypeToEnum<T>::v(), GetTestDims({2, 3}));
806   const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)};
807   memcpy(const_cast<void*>(weights.GetValues()), values.data(),
808          weights.size_bytes());
809 
810   float out_min = 0.0f;
811   float out_max = 0.0f;
812   TF_EXPECT_OK(test->GetWeightRange(weights, &out_min, &out_max));
813   EXPECT_EQ(1.0f, out_min);
814   EXPECT_EQ(6.0f, out_max);
815 }
816 
// Runs the weight-range check once for each element type covered here:
// float, half precision, and 32-bit integer.
TEST_F(ConverterTest, GetWeightRange) {
  TrtWeightStore* const store = weight_store_;
  TestGetWeightRange<float>(this, store);
  TestGetWeightRange<Eigen::half>(this, store);
  TestGetWeightRange<int32>(this, store);
}
822 
// Feeds several (min, max) pairs to Converter::ProvideQuantizationRange() for
// one tensor and checks the value recorded for it after each call. In every
// case below the expected value is the larger magnitude of the two bounds.
TEST_F(ConverterTest, ProvideQuantizationRange) {
  FakeITensor fake_tensor;

  struct RangeCase {
    float min_range;
    float max_range;
    float expected;
  };
  const RangeCase kCases[] = {
      // Asymmetric ranges.
      {0.0f, 6.0f, 6.0f},
      {1.0f, 6.0f, 6.0f},
      {-8.0f, 6.0f, 8.0f},
      {-8.123f, -6.123f, 8.123f},
      // Symmetric range.
      {-6.123f, 6.123f, 6.123f},
  };
  for (const RangeCase& range_case : kCases) {
    converter_->ProvideQuantizationRange(&fake_tensor, range_case.min_range,
                                         range_case.max_range);
    EXPECT_EQ(range_case.expected, quantization_ranges()[&fake_tensor]);
  }
}
838 
TEST_F(ConverterTest,MaybeApplyQuantizationRanges)839 TEST_F(ConverterTest, MaybeApplyQuantizationRanges) {
840   // input -> infer1 -> infer2 -> infer3
841   FakeITensor input, infer_1, infer_2, infer_3;
842   FakeITensor not_infer;
843   Converter int8_converter(/*trt_network=*/nullptr, TrtPrecisionMode::INT8,
844                            /*use_calibration=*/true);
845   int8_converter.ProvideQuantizationRange(&input, -5.0f, 5.0f);
846   int8_converter.ProvideQuantizationRange(&not_infer, -100.0f, 100.0f);
847   int8_converter.MarkQuantizationRangesAsInferrable(&input, &infer_1);
848   int8_converter.MarkQuantizationRangesAsInferrable(&infer_1, &infer_2);
849   int8_converter.MarkQuantizationRangesAsInferrable(&infer_2, &infer_3);
850 
851   // Input range should be inferred along the chain and applied to tensors.
852   int8_converter.MaybeApplyQuantizationRanges();
853 #if IS_TRT_VERSION_GE(5, 0, 0)
854   EXPECT_EQ(input.getDynamicRange(), 5.0f);
855   EXPECT_EQ(infer_1.getDynamicRange(), 5.0f);
856   EXPECT_EQ(infer_2.getDynamicRange(), 5.0f);
857   EXPECT_EQ(infer_3.getDynamicRange(), 5.0f);
858   EXPECT_EQ(not_infer.getDynamicRange(), 100.0f);
859 #endif
860 }
861 
// Checks that a range provided for a single tensor is propagated to every
// tensor connected to it through inferrable links, and that a tensor outside
// the graph gets no entry.
TEST_F(ConverterTest, PropagateQuantizationRanges) {
  // infer0 <-> infer1 <-> infer2 <-> infer3
  //              |
  //            infer4 <-> infer5
  FakeITensor infer[6];
  FakeITensor not_infer;
  converter_->ProvideQuantizationRange(&infer[4], -5.0f, 5.0f);

  // Links of the graph above, as {first argument, second argument} indices.
  const int kLinks[][2] = {{0, 1}, {1, 2}, {3, 2}, {4, 1}, {4, 5}};
  for (const auto& link : kLinks) {
    converter_->MarkQuantizationRangesAsInferrable(&infer[link[0]],
                                                   &infer[link[1]]);
  }

  // The range given to infer4 should reach all six tensors.
  PropagateQuantizationRanges();
  auto ranges = quantization_ranges();
  for (FakeITensor& tensor : infer) {
    EXPECT_EQ(5.0f, ranges[&tensor]);
  }
  EXPECT_EQ(ranges.count(&not_infer), 0);
}
883 
TEST_F(ConverterTest,GetTrtBroadcastShape)884 TEST_F(ConverterTest, GetTrtBroadcastShape) {
885   const bool kIsTensor = true;
886   const bool kIsNotTensor = false;
887   auto symmetric_test = [this](const std::vector<int>& operand_1_shape,
888                                const std::vector<int>& operand_2_shape,
889                                const bool operand_1_is_tensor,
890                                const bool operand_2_is_tensor,
891                                const std::vector<int>& expected_operand_1_shape,
892                                const std::vector<int>& expected_operand_2_shape,
893                                error::Code expected_code = error::OK,
894                                const char* expected_error_msg_substr = nullptr,
895                                const int operand_1_batch_size = -1,
896                                const int operand_2_batch_size = -1) {
897     auto create_tensor_or_weights = [](const std::vector<int>& shape,
898                                        bool is_tensor, int batch_size = -1) {
899       if (is_tensor) {
900         return TRT_TensorOrWeights{nvinfer1::DataType::kFLOAT,
901                                    GetTestDims(shape), batch_size};
902       }
903       TRT_ShapedWeights weights;
904       weights.shape_ = GetTestDims(shape);
905       return TRT_TensorOrWeights(weights);
906     };
907 
908     nvinfer1::Dims operand_1_new_dims, operand_2_new_dims;
909     TRT_TensorOrWeights operand_1 = create_tensor_or_weights(
910         operand_1_shape, operand_1_is_tensor, operand_1_batch_size);
911     TRT_TensorOrWeights operand_2 = create_tensor_or_weights(
912         operand_2_shape, operand_2_is_tensor, operand_2_batch_size);
913 
914     // operand_1 broadcast operand_2
915     ExpectStatus(
916         this->converter_->GetTrtBroadcastShape(
917             operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims),
918         expected_code, expected_error_msg_substr);
919     if (expected_code == error::OK) {
920       ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
921       ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims);
922     }
923     // operand_2 broadcast operand_1
924     ExpectStatus(
925         this->converter_->GetTrtBroadcastShape(
926             operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims),
927         expected_code, expected_error_msg_substr);
928     if (expected_code == error::OK) {
929       ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
930       ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims);
931     }
932   };
933 
934   // Both inputs are weights.
935   symmetric_test(
936       {1}, {1}, kIsNotTensor, kIsNotTensor, {}, {}, error::INVALID_ARGUMENT,
937       "Broadcasting requires at least one of the operands be tensors");
938 
939   // One tensor and one weights.
940   symmetric_test({1, 1, 1}, {2}, kIsTensor, kIsNotTensor, {1, 1, 1}, {1, 1, 2});
941   symmetric_test({1, 1, 2}, {2}, kIsTensor, kIsNotTensor, {1, 1, 2}, {1, 1, 2});
942   symmetric_test({1, 3, 2}, {1}, kIsTensor, kIsNotTensor, {1, 3, 2}, {1, 1, 1});
943   symmetric_test({1, 1, 1}, {2, 3}, kIsTensor, kIsNotTensor, {1, 1, 1},
944                  {1, 2, 3});
945   symmetric_test({1, 1, 1}, {2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1},
946                  {2, 3, 4});
947   symmetric_test({1, 1, 1}, {1, 2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1},
948                  {2, 3, 4});
949   symmetric_test({1, 3, 4}, {1, 2, 1, 4}, kIsTensor, kIsNotTensor, {1, 3, 4},
950                  {2, 1, 4});
951   symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
952                  error::INVALID_ARGUMENT, "Infeasible broadcast scheme");
953   symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
954                  error::INVALID_ARGUMENT, "Infeasible broadcast scheme",
955                  /*operand_1_batch_size=*/2);
956   symmetric_test({1, 1, 1}, {1, 1, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
957                  error::INVALID_ARGUMENT,
958                  "Broadcasting beyond batch dimension is not supported "
959                  "(tensor #dims 4 vs broadcast #dims 5)");
960 
961   // Both inputs are tensors.
962   symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {},
963                  error::INVALID_ARGUMENT,
964                  "Broadcasting beyond batch dimension is not supported "
965                  "(tensor #dims 3 vs broadcast #dims 4)");
966   symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4},
967                  {2, 1, 4});
968   symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {},
969                  error::INVALID_ARGUMENT,
970                  "Broadcasting beyond batch dimension is not supported "
971                  "(tensor #dims 4 vs broadcast #dims 5)");
972 }
973 
// For float and int32 weights of shape {2, 3, 5}, checks that
// Converter::CreateConstantLayer() returns a tensor whose element type
// matches the weights' type and whose dimensions are the requested {3, 10}.
TEST_F(ConverterTest, CreateConstantLayer) {
  for (const DataType tf_type : {DT_FLOAT, DT_INT32}) {
    const auto expected_trt_type = TfDataTypeToTrt(tf_type);
    TRT_ShapedWeights weights =
        weight_store_->GetTempWeights(tf_type, GetTestDims({2, 3, 5}));

    nvinfer1::ITensor* const constant =
        converter_->CreateConstantLayer(weights, GetTestDims({3, 10}));
    ASSERT_NE(nullptr, constant);

    const auto actual_trt_type = constant->getType();
    EXPECT_EQ(expected_trt_type, actual_trt_type)
        << "Expected " << DebugString(expected_trt_type) << " vs. actual "
        << DebugString(actual_trt_type);
    ExpectTrtDimsEqualsArray({3, 10}, constant->getDimensions());
  }
}
987 
988 class ConvertGraphDefToEngineTest : public ::testing::Test {
989  public:
RunConvertGraphDefToEngine(Scope * s)990   Status RunConvertGraphDefToEngine(Scope* s) {
991     GraphDef gdef;
992     TF_EXPECT_OK(s->ToGraphDef(&gdef));
993     std::vector<PartialTensorShape> input_shapes;
994     int batch_size = -1;
995     for (const NodeDef& node : gdef.node()) {
996       absl::string_view node_name(node.name());
997       if (str_util::ConsumePrefix(&node_name, kInputPHName)) {
998         int port = -1;
999         EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
1000         if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
1001         input_shapes[port] =
1002             PartialTensorShape(node.attr().at("shape").shape());
1003         if (batch_size == -1) {
1004           batch_size = input_shapes[port].dim_size(0);
1005         } else {
1006           EXPECT_EQ(batch_size, input_shapes[port].dim_size(0));
1007         }
1008       }
1009     }
1010     // TODO(laigd): execute the engine and get outputs.
1011     return ConvertGraphDefToEngine(
1012         gdef, TrtPrecisionMode::FP32, /*max_batch_size=*/1,
1013         /*max_workspace_size_bytes=*/64 << 20, input_shapes, &logger_,
1014         /*allocator=*/nullptr, /*calibrator=*/nullptr, &engine_,
1015         /*use_calibration=*/false, /*convert_successfully=*/nullptr);
1016   }
1017 
1018  protected:
1019   TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
1020 
1021  private:
1022   Logger logger_;
1023 };
1024 
// Converts a graph made only of Identity ops sitting between the input and
// output placeholders and expects the conversion to succeed.
TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) {
  Scope root = Scope::NewRootScope();
  // input placeholder -> identity1 -> identity2 -> output-named identity.
  auto placeholder =
      ops::Placeholder(root.WithOpName(StrCat(kInputPHName, 0)), DT_FLOAT,
                       ops::Placeholder::Shape({1, 1}));
  auto first_identity =
      ops::Identity(root.WithOpName("identity1"), placeholder);
  auto second_identity =
      ops::Identity(root.WithOpName("identity2"), first_identity);
  auto engine_output = ops::Identity(
      root.WithOpName(StrCat(kOutputPHName, 0)), second_identity);
  // If the converter marks the input tensor as output tensor, the conversion
  // below will fail with:
  // > TensorRTOutputPH_0 cannot be both input and output
  // > Network must have at least one output
  TF_EXPECT_OK(RunConvertGraphDefToEngine(&root));
}
1038 
1039 // Input/output data format for OpConverterTest::BuildAndRun().
1040 struct InputOutputData {
Buffertensorflow::tensorrt::convert::InputOutputData1041   void* Buffer() const {
1042     return const_cast<char*>(tensor.tensor_data().data());
1043   }
1044 
TotalBytestensorflow::tensorrt::convert::InputOutputData1045   size_t TotalBytes() const { return tensor.TotalBytes(); }
1046 
1047   const char* name;
1048   Tensor tensor;
1049 };
1050 
1051 template <typename T>
ConstructTensor(int data_size,const T & value=T ())1052 Tensor ConstructTensor(int data_size, const T& value = T()) {
1053   std::vector<T> values(data_size, value);
1054   return test::AsTensor<T>(values);
1055 }
1056 
1057 using DataVec = std::vector<InputOutputData>;
1058 
1059 template <typename T>
GetSpanForData(const InputOutputData & data)1060 inline absl::Span<const T> GetSpanForData(const InputOutputData& data) {
1061   const auto& tensor_map = data.tensor.flat<T>();
1062   return absl::Span<const T>(tensor_map.data(), tensor_map.size());
1063 }
1064 
1065 // Class to test various op converters, using both a TrtNodeValidator and
1066 // Converter.
1067 class OpConverterTest : public ::testing::Test {
1068  public:
OpConverterTest()1069   OpConverterTest() : scope_(Scope::NewRootScope()) {
1070     QCHECK_EQ(0, cudaStreamCreate(&stream_));
1071     Reset();
1072   }
1073 
~OpConverterTest()1074   ~OpConverterTest() override { QCHECK_EQ(0, cudaStreamDestroy(stream_)); }
1075 
GetTensorOrWeights(const string & name,TRT_TensorOrWeights * output)1076   Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output) {
1077     return converter_->GetTensorOrWeights(name, output);
1078   }
1079 
Reset()1080   void Reset() {
1081     validator_.reset(nullptr);
1082     converter_.reset(nullptr);
1083 
1084     // Reset the INetworkDefinition.
1085     engine_.reset(nullptr);
1086     network_.reset(nullptr);
1087     builder_.reset(nvinfer1::createInferBuilder(logger_));
1088     network_.reset(builder_->createNetwork());
1089     builder_->setMaxBatchSize(1);
1090     builder_->setMaxWorkspaceSize(1 << 26);
1091 
1092     // Reset the validator and converter.
1093     validator_.reset(new TrtNodeValidator);
1094     converter_.reset(new Converter(network_.get(), precision_mode_to_test_,
1095                                    /*use_calibration=*/false));
1096 
1097     // Reset other related artifacts.
1098     scope_ = Scope::NewRootScope();
1099     validator_inputs_.clear();
1100   }
1101 
CheckDataTypeMatches(const DataVec & datas)1102   void CheckDataTypeMatches(const DataVec& datas) {
1103     for (const auto& data : datas) {
1104       const int input_index = engine_->getBindingIndex(data.name);
1105       ASSERT_NE(-1, input_index);
1106       const nvinfer1::DataType trt_dtype =
1107           engine_->getBindingDataType(input_index);
1108       const DataType tf_dtype = TrtDataTypeToTf(trt_dtype);
1109       ASSERT_EQ(data.tensor.dtype(), tf_dtype)
1110           << DataTypeString(data.tensor.dtype()) << " vs. "
1111           << DataTypeString(tf_dtype);
1112     }
1113   }
1114 
1115   // TODO(laigd): test fp16 and int8 support for more converters.
BuildAndRun(const DataVec & input_data,DataVec * output_data,TrtPrecisionMode precision_mode=TrtPrecisionMode::FP32)1116   void BuildAndRun(const DataVec& input_data, DataVec* output_data,
1117                    TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32) {
1118     // Mark the output tensor as TRT engine output.
1119     std::vector<Converter::EngineOutputInfo> output_info;
1120     for (const auto& data : *output_data) {
1121       output_info.push_back(
1122           {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())});
1123     }
1124     TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));
1125 
1126     // Build the TRT engine.
1127     if (precision_mode == TrtPrecisionMode::FP16) {
1128       builder_->setFp16Mode(true);
1129     } else if (precision_mode == TrtPrecisionMode::INT8) {
1130       // Setting FP16 mode as well allows TRT to also consider FP16 kernels and
1131       // use them in situations where they are faster than INT8 or where INT8 is
1132       // not supported for a given layer.
1133       builder_->setFp16Mode(true);
1134       builder_->setInt8Mode(true);
1135     }
1136     ASSERT_EQ(nullptr, engine_.get());
1137     engine_.reset(builder_->buildCudaEngine(*converter_->network()));
1138     CHECK_NOTNULL(engine_.get());
1139     CheckDataTypeMatches(input_data);
1140     CheckDataTypeMatches(*output_data);
1141 
1142     // Execute the TRT engine.
1143     const int num_bindings = input_data.size() + output_data->size();
1144     std::vector<void*> buffers(num_bindings);
1145 
1146     for (const auto& data : input_data) {
1147       const int input_index = engine_->getBindingIndex(data.name);
1148       ASSERT_EQ(0, cudaMalloc(&buffers[input_index], data.TotalBytes()));
1149       ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], data.Buffer(),
1150                                    data.TotalBytes(), cudaMemcpyHostToDevice,
1151                                    stream_));
1152     }
1153     struct SizeAndIndex {
1154       SizeAndIndex(int in_size, int in_index)
1155           : size(in_size), index(in_index) {}
1156       int size;
1157       int index;
1158     };
1159     std::vector<SizeAndIndex> output_infos;
1160     for (const auto& data : *output_data) {
1161       const int output_index = engine_->getBindingIndex(data.name);
1162       output_infos.emplace_back(data.TotalBytes(), output_index);
1163       ASSERT_EQ(0, cudaMalloc(&buffers[output_index], data.TotalBytes()));
1164     }
1165 
1166     ASSERT_EQ(engine_->getNbBindings(), num_bindings);
1167     TrtUniquePtrType<nvinfer1::IExecutionContext> execution_context(
1168         engine_->createExecutionContext());
1169     execution_context->enqueue(/*batchSize=*/1, buffers.data(), stream_,
1170                                nullptr);
1171 
1172     for (int i = 0; i < output_infos.size(); ++i) {
1173       const auto& output_info = output_infos[i];
1174       ASSERT_EQ(0, cudaMemcpyAsync(output_data->at(i).Buffer(),
1175                                    buffers[output_info.index], output_info.size,
1176                                    cudaMemcpyDeviceToHost, stream_));
1177     }
1178     cudaStreamSynchronize(stream_);
1179 
1180     for (int i = 0; i < num_bindings; ++i) {
1181       ASSERT_EQ(0, cudaFree(buffers[i]));
1182     }
1183   }
1184 
HasStaticShape(const nvinfer1::Dims & dims) const1185   bool HasStaticShape(const nvinfer1::Dims& dims) const {
1186     if (dims.nbDims < 0) return false;
1187     for (int i = 0; i < dims.nbDims; ++i) {
1188       if (dims.d[i] < 0) return false;
1189     }
1190     return true;
1191   }
1192 
1193   // Add ITensor for both validation and conversion.
AddTestTensor(const char * name,const std::vector<int32> & dims,int batch_size=1,nvinfer1::DataType trt_dtype=nvinfer1::DataType::kFLOAT)1194   void AddTestTensor(
1195       const char* name, const std::vector<int32>& dims, int batch_size = 1,
1196       nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) {
1197     DataType tf_dtype = TrtDataTypeToTf(trt_dtype);
1198     ops::Placeholder::Attrs attrs;
1199     TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_));
1200     attrs.shape_.InsertDim(0, batch_size);
1201     auto input = ops::Placeholder(scope_.WithOpName(name), tf_dtype, attrs);
1202     validator_inputs_[name] = input.operation.node()->def();
1203 
1204     // Add a real ITensor for conversion conditionally.
1205     const nvinfer1::Dims trt_dims = GetTestDims(dims);
1206     if (HasStaticShape(trt_dims)) {
1207       TF_EXPECT_OK(
1208           converter_->AddInputTensor(name, trt_dtype, trt_dims, batch_size));
1209       ASSERT_EQ(batch_size, converter_->batch_size_);
1210     }
1211   }
1212 
1213   // Add weights for both validation and conversion.
1214   template <typename T>
AddTestWeights(const char * name,const std::vector<int> & dims,const std::vector<T> & values)1215   void AddTestWeights(const char* name, const std::vector<int>& dims,
1216                       const std::vector<T>& values) {
1217     const DataType dtype = DataTypeToEnum<T>::v();
1218     const nvinfer1::Dims trt_dims = GetTestDims(dims);
1219     const int64_t num_elements = TrtDimsNumElements(trt_dims);
1220     QCHECK_EQ(num_elements, values.size())
1221         << num_elements << " vs " << values.size();
1222     TRT_ShapedWeights weights(dtype);
1223     if (num_elements) {
1224       weights = converter_->weight_store_.GetTempWeights(dtype, trt_dims);
1225       QCHECK_EQ(weights.size_bytes(), sizeof(T) * values.size())
1226           << weights.size_bytes() << " vs " << sizeof(T) * values.size();
1227       memcpy(const_cast<void*>(weights.GetValues()), values.data(),
1228              weights.size_bytes());
1229     }
1230     // Add weights for validation.
1231     TensorShape shape;
1232     TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &shape));
1233     validator_inputs_[name] = MakeConstNodeDef<T>(name, values, shape);
1234     // Add weights for conversion.
1235     TF_EXPECT_OK(
1236         converter_->AddTensorOrWeights(name, TRT_TensorOrWeights{weights}));
1237   }
1238 
1239   // Test validation in validation-only mode.
RunValidation(const NodeDef & node_def,error::Code expected_code=error::OK,const char * expected_msg_substr=nullptr)1240   void RunValidation(const NodeDef& node_def,
1241                      error::Code expected_code = error::OK,
1242                      const char* expected_msg_substr = nullptr) {
1243     std::vector<std::pair<const NodeDef*, int>> input_node_and_ports;
1244     for (const string& input : node_def.input()) {
1245       input_node_and_ports.emplace_back(&validator_inputs_[input], 0);
1246     }
1247     grappler::GrapplerItem item;
1248     TF_EXPECT_OK(scope_.ToGraphDef(&item.graph));
1249     grappler::GraphProperties graph_properties(item);
1250     TF_EXPECT_OK(graph_properties.InferStatically(true));
1251 
1252     ExpectStatus(
1253         validator_->ValidateNode(node_def, input_node_and_ports,
1254                                  precision_mode_to_test_, graph_properties),
1255         expected_code, expected_msg_substr);
1256   }
1257 
RunConversion(const NodeDef & node_def,error::Code expected_code=error::OK,const char * expected_msg_substr=nullptr)1258   void RunConversion(const NodeDef& node_def,
1259                      error::Code expected_code = error::OK,
1260                      const char* expected_msg_substr = nullptr) {
1261     ExpectStatus(converter_->ConvertNode(node_def), expected_code,
1262                  expected_msg_substr);
1263   }
1264 
1265   // Helper method to run both validation and conversion, when the expected
1266   // output are same.
RunValidationAndConversion(const NodeDef & node_def,error::Code expected_code=error::OK,const char * expected_msg_substr=nullptr,bool should_run_conversion=true)1267   void RunValidationAndConversion(const NodeDef& node_def,
1268                                   error::Code expected_code = error::OK,
1269                                   const char* expected_msg_substr = nullptr,
1270                                   bool should_run_conversion = true) {
1271     RunValidation(node_def, expected_code, expected_msg_substr);
1272     if (should_run_conversion) {
1273       RunConversion(node_def, expected_code, expected_msg_substr);
1274     }
1275   }
1276 
1277   // Expose quantization_ranges_ for tests
quantization_ranges()1278   std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
1279     return converter_->quantization_ranges_;
1280   }
1281 
1282   std::unique_ptr<Converter> converter_;
1283   std::unique_ptr<TrtNodeValidator> validator_;
1284 
1285  protected:
1286   // TODO(laigd): parameterize the test and make the precision mode a parameter.
1287   TrtPrecisionMode precision_mode_to_test_ = TrtPrecisionMode::FP32;
1288 
1289  private:
1290   Logger logger_;
1291   TrtUniquePtrType<nvinfer1::IBuilder> builder_;
1292   TrtUniquePtrType<nvinfer1::INetworkDefinition> network_;
1293   TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
1294   cudaStream_t stream_;
1295   // Used to create placeholders with shape and data type information. The
1296   // created placeholders will be used as inputs to the node to be verified,
1297   // thus we need the shape and data type information to get a non-empty
1298   // GraphProperties.
1299   // TODO(laigd): consider use this Scope to create the NodeDef to verify.
1300   Scope scope_;
1301   std::unordered_map<string, NodeDef> validator_inputs_;
1302 };
1303 
1304 template <typename T>
CopyTensorElements(const Tensor & tensor,protobuf::RepeatedField<T> * out)1305 void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) {
1306   out->Clear();
1307   if (tensor.NumElements() == 0) return;
1308 
1309   // TensorProto does not need to have all the elements present and can truncate
1310   // trailing elements with the same value for compressed representation. Such
1311   // elements are derived based on the tensor shape.
1312   const auto flat = tensor.flat<T>();
1313   int64 last_index = 0;
1314   for (int64 i = 0; i < tensor.NumElements(); ++i) {
1315     if (flat(i) != flat(last_index)) {
1316       last_index = i;
1317     }
1318   }
1319 
1320   int num_out_elements = last_index + 1;
1321   out->Reserve(num_out_elements);
1322   out->AddNAlreadyReserved(num_out_elements);
1323   const T* src = flat.data();
1324   T* dst = out->mutable_data();
1325   std::copy(src, src + num_out_elements, dst);
1326 }
1327 
1328 template <DataType dtype, typename InputCType, typename OutputCType>
TestConvertConst(OpConverterTest * test)1329 void TestConvertConst(OpConverterTest* test) {
1330   NodeDef node_def;
1331   node_def.set_name("my_const");
1332   node_def.set_op("Const");
1333 
1334   auto reset_and_test = [&node_def, test](
1335                             const Tensor& tensor, const bool as_tensor_content,
1336                             const std::vector<int>& expected_dims,
1337                             const std::vector<OutputCType>& expected_value) {
1338     test->Reset();
1339 
1340     TensorProto* tensor_attr =
1341         (*node_def.mutable_attr())["value"].mutable_tensor();
1342     tensor_attr->Clear();
1343 
1344     if (as_tensor_content) {
1345       tensor.AsProtoTensorContent(tensor_attr);
1346     } else {
1347       tensor.shape().AsProto(tensor_attr->mutable_tensor_shape());
1348       tensor_attr->set_dtype(tensor.dtype());
1349 
1350       if (tensor.dtype() == DT_FLOAT) {
1351         CopyTensorElements<float>(tensor, tensor_attr->mutable_float_val());
1352       } else if (tensor.dtype() == DT_INT32) {
1353         CopyTensorElements<int32>(tensor, tensor_attr->mutable_int_val());
1354       } else {
1355         tensor.AsProtoField(tensor_attr);
1356       }
1357     }
1358     test->RunValidationAndConversion(node_def);
1359     TRT_TensorOrWeights output;
1360     TF_EXPECT_OK(test->GetTensorOrWeights("my_const", &output));
1361     ValidateWeights(output.weights(), expected_dims, expected_value);
1362   };
1363 
1364   auto& attr = *node_def.mutable_attr();
1365   attr["dtype"].set_type(dtype);
1366   {
1367     // By default empty tensor will pick DT_FLOAT as data type and we fix it
1368     // here.
1369     Tensor t(dtype);  // Empty tensor.
1370     reset_and_test(t, false, {}, {});
1371   }
1372   {
1373     Tensor t = test::AsScalar<InputCType>(12);
1374     reset_and_test(t, false, {1}, {12});
1375     reset_and_test(t, true, {1}, {12});
1376   }
1377   {
1378     Tensor t = test::AsTensor<InputCType>({1, 2});
1379     reset_and_test(t, false, {2}, {1, 2});
1380     reset_and_test(t, true, {2}, {1, 2});
1381   }
1382   {
1383     Tensor t =
1384         test::AsTensor<InputCType>({1, 2, 3, 4, 5, 6}, TensorShape({2, 3}));
1385     reset_and_test(t, false, {2, 3}, {1, 2, 3, 4, 5, 6});
1386     reset_and_test(t, true, {2, 3}, {1, 2, 3, 4, 5, 6});
1387   }
1388   {
1389     // Set all tensor elements to the same value. Such tensors are encoded
1390     // using a single element list in tensor proto.
1391     Tensor t =
1392         test::AsTensor<InputCType>({1, 1, 1, 1, 1, 1}, TensorShape({2, 3}));
1393     reset_and_test(t, false, {2, 3}, {1, 1, 1, 1, 1, 1});
1394     reset_and_test(t, true, {2, 3}, {1, 1, 1, 1, 1, 1});
1395   }
1396   {
1397     // Set trailing tensor elements to the same value. Such tensors are
1398     // encoded by truncating all equal elements except the first one.
1399     Tensor t =
1400         test::AsTensor<InputCType>({2, 2, 1, 1, 1, 1}, TensorShape({2, 3}));
1401     reset_and_test(t, false, {2, 3}, {2, 2, 1, 1, 1, 1});
1402     reset_and_test(t, true, {2, 3}, {2, 2, 1, 1, 1, 1});
1403   }
1404 }
1405 
// Covers the Const converter: the two rejection paths (a Const node with an
// input, and an unsupported data type), followed by the typed success cases.
TEST_F(OpConverterTest, ConvertConst) {
  // A Const node that lists an input is rejected.
  Reset();
  AddTestTensor("input", {1});
  RunValidationAndConversion(
      MakeNodeDef("my_const", "Const", {"input"}), error::INVALID_ARGUMENT,
      "Constant node is expected to have empty input list: my_const");

  // A double constant is rejected.
  Reset();
  RunValidationAndConversion(MakeConstNodeDef<double>("my_const", {}),
                             error::INVALID_ARGUMENT,
                             "Unsupported data type double");

  // Supported types; int8 input is expected to come out as int32 weights.
  TestConvertConst<DT_FLOAT, float, float>(this);
  TestConvertConst<DT_INT8, int8, int32>(this);
  TestConvertConst<DT_INT32, int32, int32>(this);
}
1426 
TEST_F(OpConverterTest,ConvertTranspose)1427 TEST_F(OpConverterTest, ConvertTranspose) {
1428   {
1429     // Input list is empty, should fail.
1430     NodeDef node_def = MakeNodeDef("my_transpose", "Transpose", {});
1431     RunValidationAndConversion(
1432         node_def, error::INVALID_ARGUMENT,
1433         "Transpose got 0 inputs but expected 2, at my_transpose");
1434   }
1435 
1436   // Get the NodeDef for Transpose.
1437   Scope s = Scope::NewRootScope();
1438   auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
1439   auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
1440   auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights);
1441   const NodeDef& node_def = transpose.operation.node()->def();
1442 
1443   {
1444     // Permutation is a tensor, should fail.
1445     Reset();
1446     AddTestTensor("input", {1, 2, 3});
1447     AddTestTensor("weights", {3});
1448     RunValidationAndConversion(
1449         node_def, error::UNIMPLEMENTED,
1450         "The input \"perm\" for Transpose must be a constant, at my_transpose");
1451   }
1452   {
1453     // Transpose at batch dimension, should fail.
1454     Reset();
1455     AddTestTensor("input", {1, 2, 3});
1456     AddTestWeights<int32>("weights", {4}, {1, 0, 2, 3});
1457     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
1458                                "Transpose at batch dimension is not supported");
1459   }
1460   {
1461     // Permutation rank doesn't match, should fail.
1462     Reset();
1463     AddTestTensor("input", {1, 2, 3});
1464     AddTestWeights<int32>("weights", {3}, {0, 1, 2});
1465     RunValidationAndConversion(
1466         node_def, error::INVALID_ARGUMENT,
1467         "Rank of perm for transpose does not match with that of the input.");
1468   }
1469   {
1470     // Ok.
1471     Reset();
1472     AddTestTensor("input", {1, 2, 3});
1473     AddTestWeights<int32>("weights", {4}, {0, 3, 1, 2});
1474     RunValidationAndConversion(node_def);
1475     TRT_TensorOrWeights output;
1476     TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output));
1477     EXPECT_TRUE(output.is_tensor());
1478     ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions());
1479 
1480     const DataVec input_data{
1481         {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
1482     DataVec output_data{{"my_transpose", ConstructTensor<float>(6)}};
1483     BuildAndRun(input_data, &output_data);
1484     EXPECT_THAT(GetSpanForData<float>(output_data[0]),
1485                 ElementsAre(1, 4, 2, 5, 3, 6));
1486   }
1487 }
1488 
TEST_F(OpConverterTest,ConvertReshape)1489 TEST_F(OpConverterTest, ConvertReshape) {
1490   {
1491     // Input list is empty, should fail.
1492     NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {});
1493     RunValidationAndConversion(
1494         node_def, error::INVALID_ARGUMENT,
1495         "Reshape got 0 inputs but expected 2, at my_reshape");
1496   }
1497 
1498   // Get the NodeDef for Reshape.
1499   Scope s = Scope::NewRootScope();
1500   auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
1501   auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
1502   auto reshape = ops::Reshape(s.WithOpName("my_reshape"), input, weights);
1503   const NodeDef& node_def = reshape.operation.node()->def();
1504 
1505   {
1506     // Shape is a tensor, should fail.
1507     Reset();
1508     AddTestTensor("input", {1, 2, 3});
1509     AddTestTensor("weights", {3});
1510     RunValidationAndConversion(
1511         node_def, error::UNIMPLEMENTED,
1512         "The input \"shape\" for Reshape must be a constant, at my_reshape");
1513   }
1514   {
1515     // Reshape to scalar, should fail.
1516     Reset();
1517     AddTestTensor("input", {1, 2, 3});
1518     AddTestWeights<int32>("weights", {0}, {});
1519     RunValidationAndConversion(
1520         node_def, error::UNIMPLEMENTED,
1521         "Reshape to shape=[] is not supported, at my_reshape");
1522   }
1523 
1524   struct TestParams {
1525     int batch_size;
1526     std::vector<int> tensor_dims;
1527     std::vector<int> shape;
1528   };
1529 
1530   // Reshape at batch dimension, should fail.
1531   const int kReshapeBatchDimsCases = 5;
1532   TestParams params[kReshapeBatchDimsCases] = {
1533       TestParams{1, {1, 2, 3}, {3, 1, 1, 2}},
1534       TestParams{1, {1, 2, -1}, {-1, 1, 1, 2}},
1535       TestParams{1, {1, 2, 3}, {-1, 1, 1, 2}},
1536       TestParams{-1, {1, 2, 3}, {1, 1, 1, 2}},
1537       TestParams{-1, {-1, 2, 3}, {1, 1, 1, 6}},  // TODO(laigd): it should pass.
1538   };
1539   for (int i = 0; i < kReshapeBatchDimsCases; ++i) {
1540     Reset();
1541     const std::vector<int>& dims = params[i].tensor_dims;
1542     AddTestTensor("input", dims, params[i].batch_size);
1543     AddTestWeights<int32>("weights", {4}, params[i].shape);
1544     RunValidationAndConversion(
1545         node_def, error::UNIMPLEMENTED,
1546         "Reshape on batch dimension is not supported, at my_reshape",
1547         /*should_run_conversion=*/(dims[0] > 0 && dims[1] > 0 && dims[2] > 0));
1548   }
1549 
1550   // Reshape on non batch dimensions, ok.
1551   const int kReshapeOKCases = 3;
1552   TestParams ok_params[kReshapeOKCases] = {
1553       TestParams{-1, {1, 2, 3}, {-1, 1, 3, 2}},
1554       TestParams{1, {1, 2, 3}, {-1, 1, 3, 2}},
1555       TestParams{1, {1, 2, 3}, {1, 1, 3, 2}},
1556   };
1557   for (int i = 0; i < kReshapeOKCases; ++i) {
1558     Reset();
1559     AddTestTensor("input", ok_params[i].tensor_dims, ok_params[i].batch_size);
1560     AddTestWeights<int32>("weights", {4}, ok_params[i].shape);
1561     RunValidationAndConversion(node_def);
1562     TRT_TensorOrWeights output;
1563     TF_EXPECT_OK(GetTensorOrWeights("my_reshape", &output));
1564     EXPECT_TRUE(output.is_tensor());
1565     ExpectTrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions());
1566 
1567     const DataVec input_data{
1568         {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
1569     DataVec output_data{{"my_reshape", ConstructTensor<float>(6)}};
1570     BuildAndRun(input_data, &output_data);
1571     EXPECT_THAT(GetSpanForData<float>(output_data[0]),
1572                 ElementsAre(1, 2, 3, 4, 5, 6));
1573   }
1574 }
1575 
TEST_F(OpConverterTest,ConvertMatMul)1576 TEST_F(OpConverterTest, ConvertMatMul) {
1577   {
1578     // Input list is empty, should fail.
1579     NodeDef node_def = MakeNodeDef("my_matmul", "MatMul", {});
1580     RunValidationAndConversion(
1581         node_def, error::INVALID_ARGUMENT,
1582         "MatMul got 0 inputs but expected 2, at my_matmul");
1583   }
1584 
1585   // Get the NodeDef for MatMul.
1586   auto get_matmul_nodedef = [](DataType dtype, bool transpose_a,
1587                                bool transpose_b) -> NodeDef {
1588     Scope s = Scope::NewRootScope();
1589     auto input = ops::Placeholder(s.WithOpName("input"), dtype);
1590     auto weights = ops::Placeholder(s.WithOpName("weights"), dtype);
1591     const auto matmul_attrs =
1592         ops::MatMul::TransposeA(transpose_a).TransposeB(transpose_b);
1593     auto matmul =
1594         ops::MatMul(s.WithOpName("my_matmul"), input, weights, matmul_attrs);
1595     return matmul.operation.node()->def();
1596   };
1597 
1598   {
1599     // Unsupported data type.
1600     Reset();
1601     NodeDef node_def = get_matmul_nodedef(DT_INT32, false, false);
1602     AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32);
1603     AddTestWeights<int32>("weights", {2, 1}, {3, 5});
1604     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
1605                                "Data type int32 is not supported for MatMul, "
1606                                "must be one of [float, half], at my_matmul");
1607   }
1608   // transpose_a is set.
1609   for (bool transpose_b : {false, true}) {
1610     Reset();
1611     NodeDef node_def =
1612         get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b);
1613     AddTestTensor("input", {2}, /*batch_size=*/1);
1614     AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
1615     RunValidationAndConversion(
1616         node_def, error::INVALID_ARGUMENT,
1617         "transpose_a is not supported for TensorRT FullyConnected");
1618   }
1619   // OK.
1620   for (bool transpose_b : {false, true}) {
1621     Reset();
1622     NodeDef node_def =
1623         get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b);
1624     AddTestTensor("input", {2}, /*batch_size=*/1);
1625     AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
1626     RunValidationAndConversion(node_def);
1627     TRT_TensorOrWeights output;
1628     TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
1629     EXPECT_TRUE(output.is_tensor());
1630     ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
1631 
1632     const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}};
1633     DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}};
1634     BuildAndRun(input_data, &output_data);
1635     if (transpose_b) {
1636       EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
1637     } else {
1638       EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(2, 3));
1639     }
1640   }
1641 }
1642 
1643 template <DataType dtype>
TestConvertBiasAdd(OpConverterTest * test)1644 void TestConvertBiasAdd(OpConverterTest* test) {
1645   // Get the NodeDef for BiasAdd.
1646   auto get_biasadd_nodedef = [](const string& data_format) -> NodeDef {
1647     Scope s = Scope::NewRootScope();
1648     auto input = ops::Placeholder(s.WithOpName("input"), dtype);
1649     auto weights = ops::Placeholder(s.WithOpName("weights"), dtype);
1650     const auto biasadd_attrs = ops::BiasAdd::DataFormat(data_format);
1651     auto biasadd =
1652         ops::BiasAdd(s.WithOpName("my_biasadd"), input, weights, biasadd_attrs);
1653     return biasadd.operation.node()->def();
1654   };
1655 
1656   typedef typename EnumToDataType<dtype>::Type CType;
1657   for (const string& data_format : {"NHWC", "NCHW"}) {
1658     for (const int trt_input_rank : {1, 2, 3, 4}) {
1659       test->Reset();
1660       NodeDef node_def = get_biasadd_nodedef(data_format);
1661 
1662       // Add input, dims_array will be like {2, 1, ..., 1, 3}
1663       std::vector<int32> dims_array(trt_input_rank, 1);
1664       if (trt_input_rank == 1) {
1665         dims_array[0] = (data_format == "NHWC" ? 3 : 2);
1666       } else {
1667         dims_array[0] = 2;
1668         dims_array[trt_input_rank - 1] = 3;
1669       }
1670       test->AddTestTensor("input", dims_array, /*batch_size=*/1,
1671                           TfDataTypeToTrt(dtype));
1672 
1673       // Add bias weights.
1674       const int channel_size = (data_format == "NHWC" ? 3 : 2);
1675       std::vector<CType> bias(channel_size);
1676       for (int i = 0; i < channel_size; ++i) {
1677         bias[i] = CType(i + 1);  // bias will be {1, 2, 3, ...}
1678       }
1679       test->AddTestWeights<CType>("weights", {channel_size}, bias);
1680 
1681       // Run the conversion.
1682       test->RunValidationAndConversion(node_def);
1683       TRT_TensorOrWeights output;
1684       TF_EXPECT_OK(test->GetTensorOrWeights("my_biasadd", &output));
1685       EXPECT_TRUE(output.is_tensor());
1686       ExpectTrtDimsEqualsArray(dims_array, output.tensor()->getDimensions());
1687 
1688       // Build and run the engine.
1689       const int num_input = TrtDimsNumElements(GetTestDims(dims_array));
1690       ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 3 : 2),
1691                 num_input);
1692 
1693       const DataVec input_data{
1694           {"input", ConstructTensor<CType>(num_input, CType(0))}};
1695       DataVec output_data{{"my_biasadd", ConstructTensor<CType>(num_input)}};
1696       test->BuildAndRun(input_data, &output_data);
1697       if (trt_input_rank == 1) {
1698         if (data_format == "NHWC") {
1699           EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1700                       ElementsAre(CType(1), CType(2), CType(3)));
1701         } else {
1702           EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1703                       ElementsAre(CType(1), CType(2)));
1704         }
1705       } else {
1706         if (data_format == "NHWC") {
1707           EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1708                       ElementsAre(CType(1), CType(2), CType(3), CType(1),
1709                                   CType(2), CType(3)));
1710         } else {
1711           EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1712                       ElementsAre(CType(1), CType(1), CType(1), CType(2),
1713                                   CType(2), CType(2)));
1714         }
1715       }
1716     }
1717   }
1718 }
1719 
// Checks that BiasAdd rejects an empty input list, then runs the typed
// BiasAdd conversion helper for float and half.
TEST_F(OpConverterTest, ConvertBiasAdd) {
  {
    // Input list is empty, should fail.
    const NodeDef empty_def = MakeNodeDef("my_biasadd", "BiasAdd", {});
    RunValidationAndConversion(
        empty_def, error::INVALID_ARGUMENT,
        "BiasAdd got 0 inputs but expected 2, at my_biasadd");
  }

  // OK. DT_INT32 is deliberately left out: kINT32 is not supported by
  // IScaleLayer.
  TestConvertBiasAdd<DT_FLOAT>(this);
  TestConvertBiasAdd<DT_HALF>(this);
}
1734 
1735 template <typename OpType>
GetBinaryOpNodeDef(const string & input_name_l,const string & input_name_r,DataType dtype)1736 NodeDef GetBinaryOpNodeDef(const string& input_name_l,
1737                            const string& input_name_r, DataType dtype) {
1738   Scope s = Scope::NewRootScope();
1739   auto input_l = ops::Placeholder(s.WithOpName(input_name_l), dtype);
1740   auto input_r = ops::Placeholder(s.WithOpName(input_name_r), dtype);
1741   auto op = OpType(s.WithOpName("my_binary"), input_l, input_r);
1742   return op.operation.node()->def();
1743 }
1744 
CheckAddedLayers(OpConverterTest * test,bool expect_scale_layer)1745 void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) {
1746   bool element_wise_layer_found = false;
1747   bool scale_layer_found = false;
1748   for (int i = 0; i < test->converter_->network()->getNbLayers(); i++) {
1749     nvinfer1::ILayer* layer = test->converter_->network()->getLayer(i);
1750     if (dynamic_cast<nvinfer1::IScaleLayer*>(layer)) {
1751       scale_layer_found = true;
1752     } else if (dynamic_cast<nvinfer1::IElementWiseLayer*>(layer)) {
1753       element_wise_layer_found = true;
1754     }
1755   }
1756   EXPECT_EQ(expect_scale_layer, scale_layer_found);
1757   EXPECT_NE(expect_scale_layer, element_wise_layer_found);
1758 }
1759 
1760 template <typename OpType, DataType dtype>
TestBinaryTensorOpWeightNoBroadcast(OpConverterTest * test)1761 void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) {
1762   typedef typename EnumToDataType<dtype>::Type CType;
1763   for (auto swap_inputs : {false, true}) {
1764     test->Reset();
1765     NodeDef node_def;
1766     if (swap_inputs) {
1767       node_def = GetBinaryOpNodeDef<OpType>("weights", "input", dtype);
1768     } else {
1769       node_def = GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
1770     }
1771 
1772     const std::vector<CType> operand1{CType(3), CType(7.5)};
1773     const std::vector<CType> operand2{CType(2), CType(3)};
1774 
1775     // It requires the dims to be at least of rank 3 to apply an IScaleLayer.
1776     test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1,
1777                         TfDataTypeToTrt(dtype));
1778     test->AddTestWeights<CType>("weights", /*dims=*/{1, 1, 2},
1779                                 /*values=*/swap_inputs ? operand1 : operand2);
1780     test->RunValidationAndConversion(node_def);
1781 
1782     // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
1783     CheckAddedLayers(test, /*expect_scale_layer=*/true);
1784 
1785     // Check the dims of the output ITensor.
1786     TRT_TensorOrWeights output;
1787     TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
1788     EXPECT_TRUE(output.is_tensor());
1789     ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions());
1790 
1791     const DataVec input_data{
1792         {"input", test::AsTensor<CType>(swap_inputs ? operand2 : operand1)}};
1793     DataVec output_data{{"my_binary", ConstructTensor<CType>(2)}};
1794     test->BuildAndRun(
1795         input_data, &output_data,
1796         dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
1797     if (node_def.op() == "Add") {
1798       EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1799                   ElementsAre(CType(5), CType(10.5)));
1800     } else if (node_def.op() == "Sub") {
1801       EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1802                   ElementsAre(CType(1), CType(4.5)));
1803     } else if (node_def.op() == "Mul") {
1804       EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1805                   ElementsAre(CType(6), CType(22.5)));
1806     } else if (node_def.op() == "Div") {
1807       EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1808                   ElementsAre(CType(1.5), CType(2.5)));
1809     } else if (node_def.op() == "RealDiv") {
1810       EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1811                   ElementsAre(CType(1.5), CType(2.5)));
1812     } else {
1813       ASSERT_TRUE(false);
1814     }
1815   }
1816 }
1817 
1818 template <DataType dtype>
TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest * test)1819 void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) {
1820   typedef typename EnumToDataType<dtype>::Type CType;
1821   const NodeDef node_def =
1822       GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
1823   const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
1824   const std::vector<CType> weights{CType(10), CType(20)};
1825   // There are two types of valid dim pairs which requires channel-wise
1826   // broadcasting:
1827   // - input dims (X Y Z) vs weights dims (X 1 1)
1828   // - input dims (X Y Z) vs weights dims (Z)
1829   // Here X=Z=2 and Y=1.
1830   for (auto weights_dims : std::vector<std::vector<int>>{{2, 1, 1}, {2}}) {
1831     test->Reset();
1832     test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
1833                         TfDataTypeToTrt(dtype));
1834     test->AddTestWeights<CType>("weights", weights_dims, weights);
1835     test->RunValidationAndConversion(node_def);
1836 
1837     // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
1838     CheckAddedLayers(test, /*expect_scale_layer=*/true);
1839 
1840     // Check the dims of the output ITensor.
1841     TRT_TensorOrWeights output;
1842     TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
1843     EXPECT_TRUE(output.is_tensor());
1844     ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
1845 
1846     const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
1847     DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
1848     test->BuildAndRun(input_data, &output_data);
1849     if (weights_dims.size() == 1) {
1850       EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1851                   ElementsAre(CType(11), CType(22), CType(13), CType(24)));
1852     } else {
1853       EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1854                   ElementsAre(CType(11), CType(12), CType(23), CType(24)));
1855     }
1856   }
1857 }
1858 
1859 template <DataType dtype>
TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest * test)1860 void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) {
1861   typedef typename EnumToDataType<dtype>::Type CType;
1862   const NodeDef node_def =
1863       GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
1864   const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
1865   const std::vector<CType> weights{CType(10)};
1866   test->Reset();
1867   test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
1868                       TfDataTypeToTrt(dtype));
1869   test->AddTestWeights<CType>("weights", {1, 1, 1, 1}, weights);
1870   test->RunValidationAndConversion(node_def);
1871 
1872   // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
1873   CheckAddedLayers(test, /*expect_scale_layer=*/true);
1874 
1875   // Check the dims of the output ITensor.
1876   TRT_TensorOrWeights output;
1877   TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
1878   EXPECT_TRUE(output.is_tensor());
1879   ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
1880 
1881   const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
1882   DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
1883   test->BuildAndRun(input_data, &output_data);
1884   EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1885               ElementsAre(CType(11), CType(12), CType(13), CType(14)));
1886 }
1887 
1888 template <typename OpType>
TestBinaryTensorOpWeightFallback(OpConverterTest * test,const std::vector<int32> & input_dims,const std::vector<int> & weights_dims,error::Code code=error::OK,const char * error_msg_substr=nullptr,const int input_batch_size=1)1889 void TestBinaryTensorOpWeightFallback(OpConverterTest* test,
1890                                       const std::vector<int32>& input_dims,
1891                                       const std::vector<int>& weights_dims,
1892                                       error::Code code = error::OK,
1893                                       const char* error_msg_substr = nullptr,
1894                                       const int input_batch_size = 1) {
1895   const DataType dtype = DT_FLOAT;
1896   typedef typename EnumToDataType<dtype>::Type CType;
1897   const size_t num_inputs = TrtDimsNumElements(GetTestDims(input_dims));
1898   const size_t num_weights = TrtDimsNumElements(GetTestDims(weights_dims));
1899 
1900   test->Reset();
1901   const NodeDef node_def =
1902       GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
1903   test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size,
1904                       TfDataTypeToTrt(dtype));
1905   test->AddTestWeights<CType>(
1906       "weights", /*dims=*/weights_dims,
1907       /*values=*/std::vector<CType>(num_weights, CType(1)));
1908   test->RunValidationAndConversion(node_def, code, error_msg_substr);
1909   if (code != error::OK) return;
1910 
1911   // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
1912   CheckAddedLayers(test, /*expect_scale_layer=*/false);
1913 
1914   TRT_TensorOrWeights output;
1915   TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
1916   EXPECT_TRUE(output.is_tensor());
1917 
1918   // Check the dims of the output ITensor.
1919   std::vector<int> expected_output_dims = input_dims;
1920   for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1;
1921        i >= 0 && j >= 0; --i, --j) {
1922     if (expected_output_dims[i] == 1) {
1923       expected_output_dims[i] = weights_dims[j];
1924     }
1925   }
1926   ExpectTrtDimsEqualsArray(expected_output_dims,
1927                            output.tensor()->getDimensions());
1928 
1929   // Check the result of running the engine.
1930   const int expected_num_outputs =
1931       TrtDimsNumElements(GetTestDims(expected_output_dims));
1932   const DataVec input_data{
1933       {"input", ConstructTensor<CType>(num_inputs, CType(2))}};
1934   DataVec output_data{
1935       {"my_binary", ConstructTensor<CType>(expected_num_outputs)}};
1936   test->BuildAndRun(input_data, &output_data);
1937   if (node_def.op() == "Add") {
1938     EXPECT_THAT(
1939         GetSpanForData<CType>(output_data[0]),
1940         ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(3))));
1941   } else if (node_def.op() == "Minimum") {
1942     EXPECT_THAT(
1943         GetSpanForData<CType>(output_data[0]),
1944         ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(1))));
1945   } else {
1946     ASSERT_TRUE(false);
1947   }
1948 }
1949 
1950 template <typename OpType, DataType dtype>
TestBinaryTensorOpTensor(OpConverterTest * test)1951 void TestBinaryTensorOpTensor(OpConverterTest* test) {
1952   typedef typename EnumToDataType<dtype>::Type CType;
1953   test->Reset();
1954   const NodeDef node_def =
1955       GetBinaryOpNodeDef<OpType>("input1", "input2", dtype);
1956   test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1,
1957                       TfDataTypeToTrt(dtype));
1958   test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1,
1959                       TfDataTypeToTrt(dtype));
1960   test->RunValidationAndConversion(node_def);
1961 
1962   // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
1963   CheckAddedLayers(test, /*expect_scale_layer=*/false);
1964 
1965   // Check output dims.
1966   TRT_TensorOrWeights output;
1967   TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
1968   EXPECT_TRUE(output.is_tensor());
1969   ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions());
1970 
1971   const DataVec input_data{
1972       {"input1", test::AsTensor<CType>({CType(3), CType(6)})},
1973       {"input2", test::AsTensor<CType>({CType(2), CType(3)})}};
1974   DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
1975   // After broadcasting first input becomes {3, 6, 3, 6} and second input
1976   // becomes {2, 3, 2, 3}.
1977   test->BuildAndRun(
1978       input_data, &output_data,
1979       dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
1980   if (node_def.op() == "Add") {
1981     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1982                 ElementsAre(CType(5), CType(8), CType(6), CType(9)));
1983   } else if (node_def.op() == "Sub") {
1984     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1985                 ElementsAre(CType(1), CType(4), CType(0), CType(3)));
1986   } else if (node_def.op() == "Mul") {
1987     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1988                 ElementsAre(CType(6), CType(12), CType(9), CType(18)));
1989   } else if (node_def.op() == "Div") {
1990     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1991                 ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
1992   } else if (node_def.op() == "RealDiv") {
1993     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1994                 ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
1995   } else if (node_def.op() == "Minimum") {
1996     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1997                 ElementsAre(CType(2), CType(2), CType(3), CType(3)));
1998   } else if (node_def.op() == "Maximum") {
1999     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
2000                 ElementsAre(CType(3), CType(6), CType(3), CType(6)));
2001   } else if (node_def.op() == "Pow") {
2002     ExpectArrayNear(
2003         std::vector<CType>{CType(9), CType(36), CType(27), CType(216)},
2004         GetSpanForData<CType>(output_data[0]));
2005   } else {
2006     ASSERT_TRUE(false);
2007   }
2008 }
2009 
TEST_F(OpConverterTest,ConvertBinary)2010 TEST_F(OpConverterTest, ConvertBinary) {
2011   AttrValue dtype;
2012   dtype.set_type(DT_FLOAT);
2013   // Input size doesn't match, should fail.
2014   for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) {
2015     Reset();
2016     NodeDef node_def =
2017         MakeNodeDef("my_add", "Add", {num_inputs, "input"}, {{"T", dtype}});
2018     AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT);
2019     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
2020                                StrCat("Add got ", std::to_string(num_inputs),
2021                                       " inputs but expected 2, at my_add")
2022                                    .c_str());
2023   }
2024   {
2025     // Both inputs are weights.
2026     Reset();
2027     NodeDef node_def =
2028         MakeNodeDef("my_add", "Add", {"weights1", "weights2"}, {{"T", dtype}});
2029     AddTestWeights<float>("weights1", {1}, {1});
2030     AddTestWeights<float>("weights2", {1}, {1});
2031     RunValidationAndConversion(
2032         node_def, error::UNIMPLEMENTED,
2033         "Constant folding is falled back to TensorFlow, binary op received "
2034         "both input as constant at: my_add");
2035   }
2036 
2037   // Test BinaryTensorOpWeight() without broadcasting.
2038   TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_FLOAT>(this);
2039   TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_FLOAT>(this);
2040   TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_FLOAT>(this);
2041   TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_FLOAT>(this);
2042   TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_FLOAT>(this);
2043 
2044   TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_HALF>(this);
2045   TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_HALF>(this);
2046   TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_HALF>(this);
2047   TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_HALF>(this);
2048   TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_HALF>(this);
2049 
2050   // Test BinaryTensorOpWeight() with channel-wise broadcasting.
2051   TestBinaryTensorOpWeightWithChannelWiseBroadcast<DT_FLOAT>(this);
2052 
2053   // Test BinaryTensorOpWeight() with uniformly broadcasting.
2054   TestBinaryTensorOpWeightWithUniformlyBroadcast<DT_FLOAT>(this);
2055 
2056   // Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor().
2057   // Unsupported op.
2058   TestBinaryTensorOpWeightFallback<ops::Minimum>(this, {1, 1, 1}, {1});
2059   // Rank of input tensor dimension <3.
2060   TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1}, {1});
2061   // Broadcast on batch dimension, should fail.
2062   TestBinaryTensorOpWeightFallback<ops::Add>(
2063       this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT,
2064       "Unsupported binary op broadcast scheme for op my_binary",
2065       /*input_batch_size=*/2);
2066   // Incompatible dims with per-channel mode.
2067   TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1, 1}, {1, 2, 1});
2068   // Incompatible dims.
2069   TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 2, 1}, {2});
2070 
2071   // Test BinaryTensorOpTensor() with broadcasting.
2072   TestBinaryTensorOpTensor<ops::Add, DT_FLOAT>(this);
2073   TestBinaryTensorOpTensor<ops::Sub, DT_FLOAT>(this);
2074   TestBinaryTensorOpTensor<ops::Mul, DT_FLOAT>(this);
2075   TestBinaryTensorOpTensor<ops::Div, DT_FLOAT>(this);
2076   TestBinaryTensorOpTensor<ops::RealDiv, DT_FLOAT>(this);
2077   TestBinaryTensorOpTensor<ops::Minimum, DT_FLOAT>(this);
2078   TestBinaryTensorOpTensor<ops::Maximum, DT_FLOAT>(this);
2079   TestBinaryTensorOpTensor<ops::Pow, DT_FLOAT>(this);
2080 
2081   TestBinaryTensorOpTensor<ops::Add, DT_HALF>(this);
2082   TestBinaryTensorOpTensor<ops::Sub, DT_HALF>(this);
2083   TestBinaryTensorOpTensor<ops::Mul, DT_HALF>(this);
2084   TestBinaryTensorOpTensor<ops::Div, DT_HALF>(this);
2085   TestBinaryTensorOpTensor<ops::RealDiv, DT_HALF>(this);
2086   TestBinaryTensorOpTensor<ops::Minimum, DT_HALF>(this);
2087   TestBinaryTensorOpTensor<ops::Maximum, DT_HALF>(this);
2088   TestBinaryTensorOpTensor<ops::Pow, DT_HALF>(this);
2089 }
2090 
TEST_F(OpConverterTest,ConvertQuantize)2091 TEST_F(OpConverterTest, ConvertQuantize) {
2092   precision_mode_to_test_ = TrtPrecisionMode::INT8;
2093   const std::pair<string, int> op_with_num_inputs[4] = {
2094       {"FakeQuantWithMinMaxArgs", 1},
2095       {"FakeQuantWithMinMaxVars", 3},
2096       {"QuantizeAndDequantizeV2", 3},
2097       {"QuantizeAndDequantizeV3", 4}};
2098   for (const auto& pair : op_with_num_inputs) {
2099     // Input list is empty, should fail.
2100     NodeDef node_def = MakeNodeDef("my_quantize", pair.first, {});
2101     RunValidationAndConversion(
2102         node_def, error::INVALID_ARGUMENT,
2103         StrCat(pair.first, " got 0 inputs but expected ",
2104                std::to_string(pair.second), ", at my_quantize")
2105             .c_str());
2106   }
2107   {
2108     // FakeQuantWithMinMaxArgs attributes are empty, should fail.
2109     NodeDef node_def =
2110         MakeNodeDef("my_quantize", "FakeQuantWithMinMaxArgs", {"input"});
2111     AddTestTensor("input", {1, 2, 3});
2112     RunValidationAndConversion(
2113         node_def, error::INVALID_ARGUMENT,
2114         "Min or max attribute not found for FakeQuantWithMinMaxArgs "
2115         "at my_quantize");
2116   }
2117   {
2118     // FakeQuantWithMinMaxArgs ranges set via attributes, ok.
2119     Reset();
2120     Scope s = Scope::NewRootScope();
2121     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2122     auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f);
2123     auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("my_quantize"),
2124                                                  input, quantize_attrs);
2125     const NodeDef& node_def = quantize.operation.node()->def();
2126     AddTestTensor("input", {1, 2, 3});
2127     RunValidationAndConversion(node_def);
2128     TRT_TensorOrWeights output;
2129     TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
2130     EXPECT_TRUE(output.is_tensor());
2131     auto ranges = quantization_ranges();
2132     EXPECT_EQ(1, ranges.count(output.tensor()));
2133     EXPECT_EQ(6.0f, ranges[output.tensor()]);
2134   }
2135   {
2136     // FakeQuantWithMinMaxVars ranges set via inputs, ok.
2137     Reset();
2138     Scope s = Scope::NewRootScope();
2139     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2140     auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
2141     auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
2142     auto quantize = ops::FakeQuantWithMinMaxVars(
2143         s.WithOpName("my_quantize"), input, weights_min, weights_max);
2144     const NodeDef& node_def = quantize.operation.node()->def();
2145     AddTestTensor("input", {1, 2, 3});
2146     AddTestWeights<float>("weights_min", {1}, {-6.0f});
2147     AddTestWeights<float>("weights_max", {1}, {6.0f});
2148     RunValidationAndConversion(node_def);
2149     TRT_TensorOrWeights output;
2150     TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
2151     EXPECT_TRUE(output.is_tensor());
2152     auto ranges = quantization_ranges();
2153     EXPECT_EQ(1, ranges.count(output.tensor()));
2154     EXPECT_EQ(6.0f, ranges[output.tensor()]);
2155   }
2156   {
2157     // QuantizeAndDequantizeV2 ranges set via inputs, ok.
2158     Reset();
2159     Scope s = Scope::NewRootScope();
2160     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2161     auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
2162     auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
2163     auto quantize = ops::QuantizeAndDequantizeV2(
2164         s.WithOpName("my_quantize"), input, weights_min, weights_max);
2165     const NodeDef& node_def = quantize.operation.node()->def();
2166     AddTestTensor("input", {1, 2, 3});
2167     AddTestWeights<float>("weights_min", {1}, {-6.0f});
2168     AddTestWeights<float>("weights_max", {1}, {6.0f});
2169     RunValidationAndConversion(node_def);
2170     TRT_TensorOrWeights output;
2171     TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
2172     EXPECT_TRUE(output.is_tensor());
2173     auto ranges = quantization_ranges();
2174     EXPECT_EQ(1, ranges.count(output.tensor()));
2175     EXPECT_EQ(6.0f, ranges[output.tensor()]);
2176   }
2177   {
2178     // QuantizeAndDequantizeV2 Range inputs are tensors, should fail.
2179     Reset();
2180     Scope s = Scope::NewRootScope();
2181     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2182     auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
2183     auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
2184     auto quantize = ops::QuantizeAndDequantizeV2(
2185         s.WithOpName("my_quantize"), input, weights_min, weights_max);
2186     const NodeDef& node_def = quantize.operation.node()->def();
2187     AddTestTensor("input", {1, 2, 3});
2188     AddTestTensor("weights_min", {1});
2189     AddTestTensor("weights_max", {1});
2190     RunValidationAndConversion(
2191         node_def, error::UNIMPLEMENTED,
2192         "The input \"input_min\" for QuantizeAndDequantizeV2 must be a constant"
2193         ", at my_quantize");
2194   }
2195   {
2196     // QuantizeAndDequantizeV3 ranges set via inputs, ok.
2197     Reset();
2198     Scope s = Scope::NewRootScope();
2199     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2200     auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
2201     auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
2202     auto num_bits = ops::Placeholder(s.WithOpName("num_bits"), DT_INT32);
2203     auto quantize = ops::QuantizeAndDequantizeV3(
2204         s.WithOpName("my_quantize"), input, weights_min, weights_max, num_bits);
2205     const NodeDef& node_def = quantize.operation.node()->def();
2206     AddTestTensor("input", {1, 2, 3});
2207     AddTestWeights<float>("weights_min", {1}, {-6.0f});
2208     AddTestWeights<float>("weights_max", {1}, {6.0f});
2209     AddTestWeights<int>("num_bits", {1}, {8});
2210     RunValidationAndConversion(node_def);
2211     TRT_TensorOrWeights output;
2212     TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
2213     EXPECT_TRUE(output.is_tensor());
2214     auto ranges = quantization_ranges();
2215     EXPECT_EQ(1, ranges.count(output.tensor()));
2216     EXPECT_EQ(6.0f, ranges[output.tensor()]);
2217   }
2218 }
2219 
2220 template <DataType dtype>
TestConvertSquare(OpConverterTest * test)2221 void TestConvertSquare(OpConverterTest* test) {
2222   test->Reset();
2223   typedef typename EnumToDataType<dtype>::Type CType;
2224 
2225   Scope s = Scope::NewRootScope();
2226   auto input = ops::Placeholder(s.WithOpName("input"), dtype);
2227   auto square = ops::Square(s.WithOpName("my_square"), input);
2228   NodeDef node_def = square.operation.node()->def();
2229 
2230   test->AddTestTensor("input", {1, 20}, /*batch_size=*/1,
2231                       TfDataTypeToTrt(dtype));
2232   test->RunValidationAndConversion(node_def);
2233   TRT_TensorOrWeights output;
2234   TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output));
2235   EXPECT_TRUE(output.is_tensor());
2236   ExpectTrtDimsEqualsArray({1, 20}, output.tensor()->getDimensions());
2237 
2238   const int num_inputs = 20;
2239   std::vector<CType> inputs(num_inputs);
2240   std::vector<CType> expected_outputs(num_inputs);
2241   for (int i = 0; i < num_inputs; ++i) {
2242     const CType value = CType(i - 9);
2243     inputs[i] = value;
2244     expected_outputs[i] = value * value;
2245   }
2246   const DataVec input_data{{"input", test::AsTensor<CType>(inputs)}};
2247   // Engine outputs are converted to FP16 automatically if we set FP16 mode in
2248   // the builder.
2249   DataVec output_data{{"my_square", ConstructTensor<CType>(num_inputs)}};
2250   test->BuildAndRun(
2251       input_data, &output_data,
2252       dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
2253   ExpectArrayNear(expected_outputs, GetSpanForData<CType>(output_data[0]));
2254 }
2255 
// Square conversion: rejects a node with no inputs and a node whose input is
// a weight, then runs the end-to-end check for float and half inputs.
TEST_F(OpConverterTest, ConvertSquare) {
  {
    // A Square node without any input must be rejected.
    NodeDef node_def = MakeNodeDef("my_square", "Square", {});
    RunValidationAndConversion(
        node_def, error::INVALID_ARGUMENT,
        "Square got 0 inputs but expected 1, at my_square");
  }
  {
    // Supplying the input as weights instead of a tensor must be rejected.
    Reset();
    Scope root = Scope::NewRootScope();
    auto placeholder = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
    auto square_op = ops::Square(root.WithOpName("my_square"), placeholder);
    NodeDef node_def = square_op.operation.node()->def();
    AddTestWeights<float>("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6});
    RunValidationAndConversion(
        node_def, error::UNIMPLEMENTED,
        "The input \"x\" for Square must be a tensor, at my_square");
  }

  // OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't
  // test DT_INT32 type here.
  TestConvertSquare<DT_FLOAT>(this);
  TestConvertSquare<DT_HALF>(this);
}
2282 
TEST_F(OpConverterTest,ConvertActivation)2283 TEST_F(OpConverterTest, ConvertActivation) {
2284   {
2285     // Input list is empty, should fail.
2286     NodeDef node_def = MakeNodeDef("my_act", "Relu", {});
2287     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
2288                                "Relu got 0 inputs but expected 1, at my_act");
2289   }
2290   {
2291     // Input is weights, should fail.
2292     Reset();
2293     Scope s = Scope::NewRootScope();
2294     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2295     auto relu = ops::Relu(s.WithOpName("my_act"), input);
2296     const NodeDef& node_def = relu.operation.node()->def();
2297     AddTestWeights<int32>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
2298     RunValidationAndConversion(
2299         node_def, error::UNIMPLEMENTED,
2300         "The input \"input\" for Relu must be a tensor, at my_act");
2301   }
2302 
2303   constexpr float kAlpha = 0.2f;
2304 
2305   // Get nodedef for activation layer.
2306   auto get_act_nodedef = [](string op_name) -> NodeDef {
2307     Scope s = Scope::NewRootScope();
2308     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2309     if (op_name == "LeakyRelu") {
2310       auto act =
2311           ops::internal::LeakyRelu(s.WithOpName("my_act"), input,
2312                                    ops::internal::LeakyRelu::Alpha(kAlpha));
2313       return act.operation.node()->def();
2314     } else if (op_name == "Relu") {
2315       auto act = ops::Relu(s.WithOpName("my_act"), input);
2316       return act.operation.node()->def();
2317     } else if (op_name == "Relu6") {
2318       auto act = ops::Relu6(s.WithOpName("my_act"), input);
2319       return act.operation.node()->def();
2320     } else if (op_name == "Sigmoid") {
2321       auto act = ops::Sigmoid(s.WithOpName("my_act"), input);
2322       return act.operation.node()->def();
2323     } else if (op_name == "Tanh") {
2324       auto act = ops::Tanh(s.WithOpName("my_act"), input);
2325       return act.operation.node()->def();
2326     }
2327     EXPECT_TRUE(false);
2328     return NodeDef();
2329   };
2330   // Get expected output for activation layer.
2331   auto get_act_output = [](string op_name, float input) -> float {
2332     if (op_name == "LeakyRelu") {
2333       return (input > 0.0f) ? input : input * kAlpha;
2334     } else if (op_name == "Relu") {
2335       return (input > 0.0f) ? input : 0.0f;
2336     } else if (op_name == "Relu6") {
2337       return std::min(std::max(input, 0.0f), 6.0f);
2338     } else if (op_name == "Sigmoid") {
2339       return 1.0f / (1.0f + std::exp(-input));
2340     } else if (op_name == "Tanh") {
2341       return std::tanh(input);
2342     }
2343     EXPECT_TRUE(false);
2344     return 0;
2345   };
2346 
2347   // Ok.
2348   for (const string& op_name :
2349        {"LeakyRelu", "Relu", "Relu6", "Sigmoid", "Tanh"}) {
2350     Reset();
2351     NodeDef node_def = get_act_nodedef(op_name);
2352     AddTestTensor("input", {1, 2, 3});
2353     RunValidationAndConversion(node_def);
2354     TRT_TensorOrWeights output;
2355     TF_EXPECT_OK(GetTensorOrWeights("my_act", &output));
2356     EXPECT_TRUE(output.is_tensor());
2357     ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions());
2358     if (op_name == "Relu6") {
2359       // Relu6 should set quantization range automatically.
2360       auto ranges = quantization_ranges();
2361       EXPECT_EQ(ranges[output.tensor()], 6.0f);
2362     }
2363 
2364     const std::vector<float> input = {-100, -2, -1, 0, 1, 100};
2365     const DataVec input_data{{"input", test::AsTensor<float>(input)}};
2366     DataVec output_data{{"my_act", ConstructTensor<float>(6)}};
2367     BuildAndRun(input_data, &output_data);
2368     for (int i = 0; i < input.size(); i++) {
2369       const float expected_output = get_act_output(op_name, input[i]);
2370       EXPECT_FLOAT_EQ(GetSpanForData<float>(output_data[0])[i],
2371                       expected_output);
2372     }
2373   }
2374 }
2375 
TEST_F(OpConverterTest,ConvertExpandDims)2376 TEST_F(OpConverterTest, ConvertExpandDims) {
2377   {
2378     // Input list is empty, should fail.
2379     NodeDef node_def = MakeNodeDef("my_expanddims", "ExpandDims", {});
2380     RunValidationAndConversion(
2381         node_def, error::INVALID_ARGUMENT,
2382         "ExpandDims got 0 inputs but expected 2, at my_expanddims");
2383   }
2384 
2385   // Get the NodeDef for ExpandDims.
2386   Scope s = Scope::NewRootScope();
2387   auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2388   auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
2389   auto expanddims =
2390       ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights);
2391   const NodeDef& node_def = expanddims.operation.node()->def();
2392   {
2393     // Input is weights, should fail.
2394     Reset();
2395     AddTestWeights<int32>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
2396     AddTestWeights<int32>("weights", {1}, {1});
2397     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2398                                "The input \"input\" for ExpandDims must be a "
2399                                "tensor, at my_expanddims");
2400   }
2401   {
2402     // Axis is a tensor, should fail.
2403     Reset();
2404     AddTestTensor("input", {1, 2, 3});
2405     AddTestTensor("weights", {3});
2406     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2407                                "The input \"axis\" for ExpandDims must be a "
2408                                "constant, at my_expanddims");
2409   }
2410   {
2411     // Add dim at batch dimension, should fail.
2412     Reset();
2413     AddTestTensor("input", {1, 2, 3});
2414     AddTestWeights<int32>("weights", {1}, {0});
2415     RunValidationAndConversion(
2416         node_def, error::UNIMPLEMENTED,
2417         "Modifying batch dimension is not supported for ExpandDims, at "
2418         "my_expanddims");
2419   }
2420   {
2421     // Add dim at batch dimension via negative axis, should fail.
2422     Reset();
2423     AddTestTensor("input", {1, 2, 3});
2424     // Input is rank 4 (batch dim included)
2425     AddTestWeights<int32>("weights", {1}, {-5});
2426     RunValidationAndConversion(
2427         node_def, error::UNIMPLEMENTED,
2428         "Modifying batch dimension is not supported for ExpandDims, at "
2429         "my_expanddims");
2430   }
2431   {
2432     // Axis > rank(input), should fail.
2433     Reset();
2434     AddTestTensor("input", {1, 2, 3});
2435     // Input is rank 4 (batch dim included)
2436     AddTestWeights<int32>("weights", {1}, {5});
2437     RunValidationAndConversion(
2438         node_def, error::INVALID_ARGUMENT,
2439         "Axis for ExpandDims is invalid, must be in the range "
2440         "[-rank(input) - 1, rank(input)], at my_expanddims");
2441   }
2442   {
2443     // Axis < -rank(input)-1, should fail.
2444     Reset();
2445     AddTestTensor("input", {1, 2, 3});
2446     // Input is rank 4 (batch dim included)
2447     AddTestWeights<int32>("weights", {1}, {-6});
2448     RunValidationAndConversion(
2449         node_def, error::INVALID_ARGUMENT,
2450         "Axis for ExpandDims is invalid, must be in the range "
2451         "[-rank(input) - 1, rank(input)], at my_expanddims");
2452   }
2453 
2454   struct TestParams {
2455     std::vector<int> input_dims;
2456     int axis;
2457     std::vector<int> expected_output_dims;
2458   };
2459 
2460   // Ok.
2461   const int kExpandDimsOKCases = 8;
2462   TestParams ok_params[kExpandDimsOKCases] = {
2463       TestParams{{2, 3}, 1, {1, 2, 3}}, TestParams{{2, 3}, -3, {1, 2, 3}},
2464       TestParams{{2, 3}, 3, {2, 3, 1}}, TestParams{{2, 3}, -1, {2, 3, 1}},
2465       TestParams{{2, 3}, 2, {2, 1, 3}}, TestParams{{2, 3}, -2, {2, 1, 3}},
2466       TestParams{{6}, 1, {1, 6}},       TestParams{{6}, -1, {6, 1}},
2467   };
2468   for (int i = 0; i < kExpandDimsOKCases; ++i) {
2469     Reset();
2470     AddTestTensor("input", ok_params[i].input_dims);
2471     AddTestWeights<int32>("weights", {1}, {ok_params[i].axis});
2472     RunValidationAndConversion(node_def);
2473     TRT_TensorOrWeights output;
2474     TF_EXPECT_OK(GetTensorOrWeights("my_expanddims", &output));
2475     EXPECT_TRUE(output.is_tensor());
2476     ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
2477                              output.tensor()->getDimensions());
2478 
2479     const DataVec input_data{
2480         {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
2481     DataVec output_data{{"my_expanddims", ConstructTensor<float>(6)}};
2482     BuildAndRun(input_data, &output_data);
2483     EXPECT_THAT(GetSpanForData<float>(output_data[0]),
2484                 ElementsAre(1, 2, 3, 4, 5, 6));
2485   }
2486 }
2487 
TEST_F(OpConverterTest,ConvertSqueeze)2488 TEST_F(OpConverterTest, ConvertSqueeze) {
2489   {
2490     // Input list is empty, should fail.
2491     NodeDef node_def = MakeNodeDef("my_squeeze", "Squeeze", {});
2492     RunValidationAndConversion(
2493         node_def, error::INVALID_ARGUMENT,
2494         "Squeeze got 0 inputs but expected 1, at my_squeeze");
2495   }
2496   {
2497     // No attrs, should fail.
2498     Reset();
2499     Scope s = Scope::NewRootScope();
2500     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2501     auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input);
2502     const NodeDef& node_def = squeeze.operation.node()->def();
2503     AddTestTensor("input", {1, 2, 3});
2504     RunValidationAndConversion(
2505         node_def, error::UNIMPLEMENTED,
2506         "Squeeze is only implemented for explicit dims, at my_squeeze");
2507   }
2508 
2509   // Get the NodeDef for Squeeze.
2510   auto get_squeeze_nodedef = [](std::vector<int> axis) -> NodeDef {
2511     Scope s = Scope::NewRootScope();
2512     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2513     ops::Squeeze::Attrs squeeze_attrs;
2514     squeeze_attrs.axis_ = gtl::ArraySlice<int>(axis);  // non-absl ok
2515     auto squeeze =
2516         ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs);
2517     return squeeze.operation.node()->def();
2518   };
2519 
2520   {
2521     // Input is weights, should fail.
2522     Reset();
2523     NodeDef node_def = get_squeeze_nodedef({0});
2524     AddTestWeights<float>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
2525     RunValidationAndConversion(
2526         node_def, error::UNIMPLEMENTED,
2527         "The input \"input\" for Squeeze must be a tensor, at my_squeeze");
2528   }
2529   {
2530     // Squeeze batch dim, should fail.
2531     Reset();
2532     NodeDef node_def = get_squeeze_nodedef({0});
2533     AddTestTensor("input", {1, 2, 3});
2534     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2535                                "Cannot squeeze batch dimension, at my_squeeze");
2536   }
2537   {
2538     // Squeeze batch dim via negative axis, should fail.
2539     Reset();
2540     NodeDef node_def = get_squeeze_nodedef({-4});
2541     AddTestTensor("input", {1, 2, 3});
2542     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2543                                "Cannot squeeze batch dimension, at my_squeeze");
2544   }
2545   {
2546     // Squeeze >= rank(input), should fail.
2547     Reset();
2548     NodeDef node_def = get_squeeze_nodedef({4});
2549     AddTestTensor("input", {1, 2, 3});
2550     RunValidationAndConversion(
2551         node_def, error::INVALID_ARGUMENT,
2552         "Axis for Squeeze is invalid, must be in the range "
2553         "[-rank(input), rank(input)), at my_squeeze");
2554   }
2555   {
2556     // Squeeze < -rank(input), should fail.
2557     Reset();
2558     NodeDef node_def = get_squeeze_nodedef({-5});
2559     AddTestTensor("input", {1, 2, 3});
2560     RunValidationAndConversion(
2561         node_def, error::INVALID_ARGUMENT,
2562         "Axis for Squeeze is invalid, must be in the range "
2563         "[-rank(input), rank(input)), at my_squeeze");
2564   }
2565 
2566   struct TestParams {
2567     std::vector<int> input_dims;
2568     std::vector<int> axis;
2569     std::vector<int> expected_output_dims;
2570   };
2571 
2572   // Ok.
2573   const int kSqueezeOKCases = 10;
2574   TestParams ok_params[kSqueezeOKCases] = {
2575       TestParams{{1, 2, 3}, {1}, {2, 3}},
2576       TestParams{{1, 2, 3}, {-3}, {2, 3}},
2577       TestParams{{2, 3, 1}, {3}, {2, 3}},
2578       TestParams{{2, 3, 1}, {-1}, {2, 3}},
2579       TestParams{{1, 2, 1, 3, 1}, {1, 3, 5}, {2, 3}},
2580       TestParams{{1, 2, 1, 3, 1}, {3, 1, 5}, {2, 3}},
2581       TestParams{{1, 2, 1, 3, 1}, {-1, -3, -5}, {2, 3}},
2582       TestParams{{1, 2, 1, 3, 1}, {1, -3, 5}, {2, 3}},
2583       TestParams{{1, 6}, {1}, {6}},
2584       TestParams{{6, 1}, {2}, {6}},
2585   };
2586   for (int i = 0; i < kSqueezeOKCases; ++i) {
2587     Reset();
2588     NodeDef node_def = get_squeeze_nodedef(ok_params[i].axis);
2589     AddTestTensor("input", ok_params[i].input_dims);
2590     RunValidationAndConversion(node_def);
2591     TRT_TensorOrWeights output;
2592     TF_EXPECT_OK(GetTensorOrWeights("my_squeeze", &output));
2593     EXPECT_TRUE(output.is_tensor());
2594     ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
2595                              output.tensor()->getDimensions());
2596 
2597     const DataVec input_data{
2598         {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
2599     DataVec output_data{{"my_squeeze", ConstructTensor<float>(6)}};
2600     BuildAndRun(input_data, &output_data);
2601     EXPECT_THAT(GetSpanForData<float>(output_data[0]),
2602                 ElementsAre(1, 2, 3, 4, 5, 6));
2603   }
2604 }
2605 
TEST_F(OpConverterTest,ConvertStridedSlice)2606 TEST_F(OpConverterTest, ConvertStridedSlice) {
2607   {
2608     // Input list is empty, should fail.
2609     NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {});
2610     RunValidationAndConversion(
2611         node_def, error::INVALID_ARGUMENT,
2612         "StridedSlice got 0 inputs but expected 4, at my_strided_slice");
2613   }
2614 
2615   // Get nodedef for StridedSlice layer.
2616   auto get_strided_slice_nodedef =
2617       [](int64 begin_mask = 0, int64 end_mask = 0, int64 ellipsis_mask = 0,
2618          int64 new_axis_mask = 0, int64 shrink_axis_mask = 0) -> NodeDef {
2619     Scope s = Scope::NewRootScope();
2620     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2621     auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32);
2622     auto end = ops::Placeholder(s.WithOpName("end"), DT_INT32);
2623     auto strides = ops::Placeholder(s.WithOpName("strides"), DT_INT32);
2624     ops::StridedSlice::Attrs attrs = ops::StridedSlice::Attrs()
2625                                          .BeginMask(begin_mask)
2626                                          .EndMask(end_mask)
2627                                          .EllipsisMask(ellipsis_mask)
2628                                          .NewAxisMask(new_axis_mask)
2629                                          .ShrinkAxisMask(shrink_axis_mask);
2630     auto strided_slice = ops::StridedSlice(s.WithOpName("my_strided_slice"),
2631                                            input, begin, end, strides, attrs);
2632     return strided_slice.operation.node()->def();
2633   };
2634 
2635   {
2636     // Input is weights, should fail.
2637     Reset();
2638     NodeDef node_def = get_strided_slice_nodedef();
2639     AddTestWeights<int32>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
2640     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2641     AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
2642     AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2643     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2644                                "The input \"input\" for StridedSlice must be a "
2645                                "tensor, at my_strided_slice");
2646   }
2647   {
2648     // Begin, end, strides are tensors, should fail.
2649     Reset();
2650     NodeDef node_def = get_strided_slice_nodedef();
2651     AddTestTensor("input", {1, 2, 3});
2652     AddTestTensor("begin", {4});
2653     AddTestTensor("end", {4});
2654     AddTestTensor("strides", {4});
2655     RunValidationAndConversion(
2656         node_def, error::UNIMPLEMENTED,
2657         "The input \"begin\" for StridedSlice must be a constant, at "
2658         "my_strided_slice");
2659   }
2660   {
2661     // Non-zero ellipsis_mask, should fail.
2662     Reset();
2663     NodeDef node_def = get_strided_slice_nodedef(
2664         /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/2,
2665         /*new_axis_mask=*/0, /*shrink_axis_mask=*/0);
2666     AddTestTensor("input", {1, 2, 3});
2667     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2668     AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
2669     AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2670     RunValidationAndConversion(
2671         node_def, error::UNIMPLEMENTED,
2672         "ellipsis_mask is not supported for StridedSlice, at "
2673         "my_strided_slice");
2674   }
2675   {
2676     // Modify batch dim, should fail.
2677     Reset();
2678     NodeDef node_def = get_strided_slice_nodedef();
2679     AddTestTensor("input", {1, 2, 3});
2680     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2681     AddTestWeights<int32>("end", {4}, {0, 1, 2, 3});
2682     AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2683     RunValidationAndConversion(
2684         node_def, error::UNIMPLEMENTED,
2685         "TensorRT does not allow modifications to the batch dimension, at "
2686         "my_strided_slice");
2687   }
2688   {
2689     // Dynamic batch size without end_mask, should fail.
2690     Reset();
2691     NodeDef node_def = get_strided_slice_nodedef();
2692     AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1);
2693     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2694     AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
2695     AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2696     RunValidationAndConversion(
2697         node_def, error::UNIMPLEMENTED,
2698         "TensorRT does not allow modifications to the batch dimension, at "
2699         "my_strided_slice");
2700   }
2701   {
2702     // Dynamic batch size but using end_mask, ok.
2703     Reset();
2704     NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0,
2705                                                  /*end_mask=*/1);
2706     AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1);
2707     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2708     AddTestWeights<int32>("end", {4}, {0, 1, 2, 2});
2709     AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2710     RunValidationAndConversion(node_def);
2711   }
2712 // TRT 5.1+ supports strides
2713 #if IS_TRT_VERSION_GE(5, 1, 0)
2714   {
2715     // Negative strides, should fail.
2716     Reset();
2717     NodeDef node_def = get_strided_slice_nodedef();
2718     AddTestTensor("input", {1, 2, 3});
2719     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2720     AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
2721     AddTestWeights<int32>("strides", {4}, {1, 1, 1, -1});
2722     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2723                                "Negative or zero stride values are not "
2724                                "supported for StridedSlice, at "
2725                                "my_strided_slice");
2726   }
2727 #else
2728   {
2729     // Stride is not 1, should fail.
2730     Reset();
2731     NodeDef node_def = get_strided_slice_nodedef();
2732     AddTestTensor("input", {1, 2, 3});
2733     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2734     AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
2735     AddTestWeights<int32>("strides", {4}, {1, 2, 1, 3});
2736     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2737                                "Strides other than 1 are not supported with "
2738                                "this version of TRT, at my_strided_slice");
2739   }
2740 #endif
2741   {
2742     // Size of sliced dim is negative, should fail.
2743     Reset();
2744     NodeDef node_def = get_strided_slice_nodedef();
2745     AddTestTensor("input", {1, 2, 3});
2746     AddTestWeights<int32>("begin", {4}, {0, 0, 2, 0});
2747     AddTestWeights<int32>("end", {4}, {1, 1, 0, 3});
2748     AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2749     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
2750                                "\"size\" cannot be negative or zero for "
2751                                "StridedSlice, at my_strided_slice");
2752   }
2753 
2754   struct TestParams {
2755     std::vector<int> input_dims;
2756     std::vector<int> begin;
2757     std::vector<int> end;
2758     std::vector<int> strides;
2759     int begin_mask;
2760     int end_mask;
2761     std::vector<int> expected_output_dims;
2762     std::vector<float> expected_output;
2763   };
2764 
2765   auto get_mask = [](const std::vector<int>& mask) {
2766     int result = 0;
2767     for (int i = 0; i < mask.size(); i++) {
2768       if (mask[i]) result += (1 << i);
2769     }
2770     return result;
2771   };
2772 
2773   // Same input is used for all tests.
2774   const std::vector<float> ok_input = {1, 2, 3, 4, 5, 6};
2775 
2776 #if IS_TRT_VERSION_GE(5, 1, 0)
2777   const int kStridedSliceOKCases = 23;
2778 #else
2779   const int kStridedSliceOKCases = 19;
2780 #endif
2781   // Ok.
2782   TestParams ok_params[kStridedSliceOKCases] = {
2783     // 2D Crop.
2784     TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, 0, 0},
2785                /*end=*/{0, 0, 1, 2}, /*strides=*/{1, 1, 1, 1},
2786                /*begin_mask=*/get_mask({0, 0, 0, 0}),
2787                /*end_mask=*/get_mask({1, 1, 0, 0}),
2788                /*expected_output_dims=*/{1, 1, 2}, /*expected_output=*/{1, 2}},
2789     TestParams{
2790         /*input_dims=*/{1, 2, 3},
2791         /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1},
2792         /*begin_mask=*/get_mask({0, 0, 0, 0}),
2793         /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 2},
2794         /*expected_output=*/{5, 6}},
2795     TestParams{
2796         /*input_dims=*/{1, 2, 3},
2797         /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, /*strides=*/{1, 1, 1, 1},
2798         /*begin_mask=*/get_mask({0, 0, 0, 0}),
2799         /*end_mask=*/get_mask({1, 1, 0, 0}), /*expected_output_dims=*/{1, 1, 2},
2800         /*expected_output=*/{5, 6}},
2801     // 2D Crop, with transpose.
2802     TestParams{
2803         /*input_dims=*/{2, 3, 1},
2804         /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, /*strides=*/{1, 1, 1, 1},
2805         /*begin_mask=*/get_mask({0, 0, 0, 0}),
2806         /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1},
2807         /*expected_output=*/{1, 2}},
2808     TestParams{
2809         /*input_dims=*/{2, 3, 1},
2810         /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, /*strides=*/{1, 1, 1, 1},
2811         /*begin_mask=*/get_mask({0, 0, 0, 0}),
2812         /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1},
2813         /*expected_output=*/{5, 6}},
2814     TestParams{
2815         /*input_dims=*/{2, 1, 3},
2816         /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, /*strides=*/{1, 1, 1, 1},
2817         /*begin_mask=*/get_mask({0, 0, 0, 0}),
2818         /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2},
2819         /*expected_output=*/{1, 2}},
2820     TestParams{
2821         /*input_dims=*/{2, 1, 3},
2822         /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, /*strides=*/{1, 1, 1, 1},
2823         /*begin_mask=*/get_mask({0, 0, 0, 0}),
2824         /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2},
2825         /*expected_output=*/{5, 6}},
2826     // 2D Crop, with reshape.
2827     TestParams{/*input_dims=*/{2, 3},
2828                /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, /*strides=*/{1, 1, 1},
2829                /*begin_mask=*/get_mask({0, 0, 0}),
2830                /*end_mask=*/get_mask({1, 0, 0}),
2831                /*expected_output_dims=*/{1, 2},
2832                /*expected_output=*/{1, 2}},
2833     TestParams{/*input_dims=*/{2, 3},
2834                /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, /*strides=*/{1, 1, 1},
2835                /*begin_mask=*/get_mask({0, 0, 0}),
2836                /*end_mask=*/get_mask({1, 1, 1}),
2837                /*expected_output_dims=*/{1, 2},
2838                /*expected_output=*/{5, 6}},
2839     // 1D Crop.
2840     TestParams{
2841         /*input_dims=*/{1, 2, 3},
2842         /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, /*strides=*/{1, 1, 1, 1},
2843         /*begin_mask=*/get_mask({0, 0, 0, 0}),
2844         /*end_mask=*/get_mask({1, 1, 1, 0}), /*expected_output_dims=*/{1, 2, 2},
2845         /*expected_output=*/{1, 2, 4, 5}},
2846     TestParams{
2847         /*input_dims=*/{1, 2, 3},
2848         /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1},
2849         /*begin_mask=*/get_mask({0, 0, 0, 0}),
2850         /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 3},
2851         /*expected_output=*/{4, 5, 6}},
2852     // 1D Crop, with transpose.
2853     TestParams{
2854         /*input_dims=*/{2, 3, 1},
2855         /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, /*strides=*/{1, 1, 1, 1},
2856         /*begin_mask=*/get_mask({0, 0, 0, 0}),
2857         /*end_mask=*/get_mask({1, 0, 1, 1}), /*expected_output_dims=*/{1, 3, 1},
2858         /*expected_output=*/{1, 2, 3}},
2859     TestParams{
2860         /*input_dims=*/{2, 3, 1},
2861         /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1},
2862         /*begin_mask=*/get_mask({0, 0, 0, 0}),
2863         /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 3, 1},
2864         /*expected_output=*/{4, 5, 6}},
2865     // 1D Crop, with reshape.
2866     TestParams{/*input_dims=*/{6},
2867                /*begin=*/{0, 0}, /*end=*/{0, 3}, /*strides=*/{1, 1},
2868                /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}),
2869                /*expected_output_dims=*/{3},
2870                /*expected_output=*/{1, 2, 3}},
2871     TestParams{/*input_dims=*/{1, 6},
2872                /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, /*strides=*/{1, 1, 1},
2873                /*begin_mask=*/get_mask({0, 0, 0}),
2874                /*end_mask=*/get_mask({1, 1, 0}),
2875                /*expected_output_dims=*/{1, 3},
2876                /*expected_output=*/{3, 4, 5}},
2877     TestParams{/*input_dims=*/{6, 1},
2878                /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, /*strides=*/{1, 1, 1},
2879                /*begin_mask=*/get_mask({0, 0, 0}),
2880                /*end_mask=*/get_mask({1, 0, 1}),
2881                /*expected_output_dims=*/{3, 1},
2882                /*expected_output=*/{3, 4, 5}},
2883     // Negative axis.
2884     TestParams{/*input_dims=*/{6, 1},
2885                /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, /*strides=*/{1, 1, 1},
2886                /*begin_mask=*/get_mask({0, 0, 0}),
2887                /*end_mask=*/get_mask({1, 0, 1}),
2888                /*expected_output_dims=*/{3, 1},
2889                /*expected_output=*/{1, 2, 3}},
2890     TestParams{/*input_dims=*/{6, 1},
2891                /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, /*strides=*/{1, 1, 1},
2892                /*begin_mask=*/get_mask({0, 0, 0}),
2893                /*end_mask=*/get_mask({1, 0, 1}),
2894                /*expected_output_dims=*/{5, 1},
2895                /*expected_output=*/{1, 2, 3, 4, 5}},
2896     // Clamp out of bounds begin and end.
2897     TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, -9999, -9},
2898                /*end=*/{0, 1, 1000, 4}, /*strides=*/{1, 1, 1, 1},
2899                /*begin_mask=*/get_mask({0, 0, 0, 0}),
2900                /*end_mask=*/get_mask({1, 0, 0, 0}),
2901                /*expected_output_dims=*/{1, 2, 3},
2902                /*expected_output=*/{1, 2, 3, 4, 5, 6}},
2903 #if IS_TRT_VERSION_GE(5, 1, 0)
2904     // Strides
2905     TestParams{/*input_dims=*/{6},
2906                /*begin=*/{0, 0}, /*end=*/{0, 5}, /*strides=*/{1, 2},
2907                /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}),
2908                /*expected_output_dims=*/{3},
2909                /*expected_output=*/{1, 3, 5}},
2910     TestParams{/*input_dims=*/{6},
2911                /*begin=*/{0, 0}, /*end=*/{0, 6}, /*strides=*/{1, 2},
2912                /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}),
2913                /*expected_output_dims=*/{3},
2914                /*expected_output=*/{1, 3, 5}},
2915     TestParams{/*input_dims=*/{6},
2916                /*begin=*/{0, 1}, /*end=*/{0, 6}, /*strides=*/{1, 2},
2917                /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}),
2918                /*expected_output_dims=*/{3},
2919                /*expected_output=*/{2, 4, 6}},
2920     TestParams{/*input_dims=*/{6},
2921                /*begin=*/{0, 2}, /*end=*/{0, 6}, /*strides=*/{1, 3},
2922                /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}),
2923                /*expected_output_dims=*/{2},
2924                /*expected_output=*/{3, 6}},
2925 #endif
2926   };
2927 
2928   for (int i = 0; i < kStridedSliceOKCases; i++) {
2929     Reset();
2930     NodeDef node_def = get_strided_slice_nodedef(ok_params[i].begin_mask,
2931                                                  ok_params[i].end_mask);
2932     AddTestTensor("input", ok_params[i].input_dims);
2933     AddTestWeights<int32>("begin",
2934                           {static_cast<int>(ok_params[i].begin.size())},
2935                           ok_params[i].begin);
2936     AddTestWeights<int32>("end", {static_cast<int>(ok_params[i].end.size())},
2937                           ok_params[i].end);
2938     AddTestWeights<int32>("strides",
2939                           {static_cast<int>(ok_params[i].strides.size())},
2940                           ok_params[i].strides);
2941     RunValidationAndConversion(node_def);
2942 
2943     TRT_TensorOrWeights output;
2944     TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output));
2945     EXPECT_TRUE(output.is_tensor());
2946     ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
2947                              output.tensor()->getDimensions());
2948 
2949     const DataVec input_data{{"input", test::AsTensor<float>(ok_input)}};
2950     DataVec output_data{
2951         {"my_strided_slice",
2952          ConstructTensor<float>(ok_params[i].expected_output.size())}};
2953     BuildAndRun(input_data, &output_data);
2954     EXPECT_THAT(GetSpanForData<float>(output_data[0]),
2955                 ElementsAreArray(ok_params[i].expected_output));
2956   }
2957 }
2958 
TEST_F(OpConverterTest,ConvertSlice)2959 TEST_F(OpConverterTest, ConvertSlice) {
2960   // Get nodedef for Slice layer.
2961   auto get_slice_nodedef = []() -> NodeDef {
2962     Scope s = Scope::NewRootScope();
2963     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2964     auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32);
2965     auto size = ops::Placeholder(s.WithOpName("size"), DT_INT32);
2966     auto slice = ops::Slice(s.WithOpName("my_slice"), input, begin, size);
2967     return slice.operation.node()->def();
2968   };
2969 
2970   {
2971     // Begin is below bounds, should fail.
2972     Reset();
2973     NodeDef node_def = get_slice_nodedef();
2974     AddTestTensor("input", {1, 2, 3});
2975     AddTestWeights<int32>("begin", {4}, {0, 0, -1, 0});
2976     AddTestWeights<int32>("size", {4}, {1, 1, 2, 3});
2977     RunValidationAndConversion(
2978         node_def, error::INVALID_ARGUMENT,
2979         "\"begin\" for dimension 2 in Slice is out of range, at my_slice");
2980   }
2981   {
2982     // Begin is above bounds, should fail.
2983     Reset();
2984     NodeDef node_def = get_slice_nodedef();
2985     AddTestTensor("input", {1, 2, 3});
2986     AddTestWeights<int32>("begin", {4}, {0, 0, 3, 0});
2987     AddTestWeights<int32>("size", {4}, {1, 1, 2, 3});
2988     RunValidationAndConversion(
2989         node_def, error::INVALID_ARGUMENT,
2990         "\"begin\" for dimension 2 in Slice is out of range, at my_slice");
2991   }
2992   {
2993     // Size is below bounds, should fail.
2994     Reset();
2995     NodeDef node_def = get_slice_nodedef();
2996     AddTestTensor("input", {1, 2, 3});
2997     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2998     AddTestWeights<int32>("size", {4}, {1, 1, 2, -2});
2999     RunValidationAndConversion(
3000         node_def, error::INVALID_ARGUMENT,
3001         "\"begin\" + \"size\" for dimension 3 in Slice is out of range, at "
3002         "my_slice");
3003   }
3004   {
3005     // Size is above bounds, should fail.
3006     Reset();
3007     NodeDef node_def = get_slice_nodedef();
3008     AddTestTensor("input", {1, 2, 3});
3009     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
3010     AddTestWeights<int32>("size", {4}, {1, 1, 3, 3});
3011     RunValidationAndConversion(
3012         node_def, error::INVALID_ARGUMENT,
3013         "\"begin\" + \"size\" for dimension 2 in Slice is out of range, at "
3014         "my_slice");
3015   }
3016   {
3017     // Modify batch dim, should fail.
3018     Reset();
3019     NodeDef node_def = get_slice_nodedef();
3020     AddTestTensor("input", {1, 2, 3});
3021     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
3022     AddTestWeights<int32>("size", {4}, {0, 1, 2, 3});
3023     RunValidationAndConversion(
3024         node_def, error::UNIMPLEMENTED,
3025         "TensorRT does not allow modifications to the batch dimension, at "
3026         "my_slice");
3027   }
3028   {
3029     // Dynamic batch size with size[0] not -1, should fail.
3030     Reset();
3031     NodeDef node_def = get_slice_nodedef();
3032     AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1);
3033     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
3034     AddTestWeights<int32>("size", {4}, {1, 1, 2, 3});
3035     RunValidationAndConversion(
3036         node_def, error::UNIMPLEMENTED,
3037         "TensorRT does not allow modifications to the batch dimension, at "
3038         "my_slice");
3039   }
3040   {
3041     // Dynamic batch size but using size[0] of -1, ok.
3042     Reset();
3043     NodeDef node_def = get_slice_nodedef();
3044     AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1);
3045     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
3046     AddTestWeights<int32>("size", {4}, {-1, 1, 2, 2});
3047     RunValidationAndConversion(node_def);
3048   }
3049 
3050   struct TestParams {
3051     std::vector<int> input_dims;
3052     std::vector<int> begin;
3053     std::vector<int> size;
3054     std::vector<int> expected_output_dims;
3055     std::vector<int> expected_output;
3056   };
3057 
3058   // Ok.
3059   const int kSliceOKCases = 5;
3060   TestParams ok_params[kSliceOKCases] = {
3061       TestParams{{1, 2, 3},
3062                  {0, 0, 0, 0},
3063                  {-1, -1, -1, -1},
3064                  {1, 2, 3},
3065                  {1, 2, 3, 4, 5, 6}},
3066       TestParams{
3067           {1, 2, 3}, {0, 0, 0, 0}, {1, 1, 2, 3}, {1, 2, 3}, {1, 2, 3, 4, 5, 6}},
3068       TestParams{
3069           {1, 2, 3}, {0, 0, 0, 0}, {1, -1, 2, 2}, {1, 2, 2}, {1, 2, 4, 5}},
3070       TestParams{{6}, {0, 1}, {1, 5}, {5}, {2, 3, 4, 5, 6}},
3071       TestParams{{6}, {0, 1}, {-1, 3}, {3}, {2, 3, 4}},
3072   };
3073 
3074   for (int i = 0; i < kSliceOKCases; i++) {
3075     Reset();
3076     NodeDef node_def = get_slice_nodedef();
3077     AddTestTensor("input", ok_params[i].input_dims);
3078     AddTestWeights<int32>("begin",
3079                           {static_cast<int>(ok_params[i].begin.size())},
3080                           ok_params[i].begin);
3081     AddTestWeights<int32>("size", {static_cast<int>(ok_params[i].size.size())},
3082                           ok_params[i].size);
3083     RunValidationAndConversion(node_def);
3084 
3085     TRT_TensorOrWeights output;
3086     TF_EXPECT_OK(GetTensorOrWeights("my_slice", &output));
3087     EXPECT_TRUE(output.is_tensor());
3088     ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
3089                              output.tensor()->getDimensions());
3090 
3091     const DataVec input_data{
3092         {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
3093     DataVec output_data{{"my_slice", ConstructTensor<float>(
3094                                          ok_params[i].expected_output.size())}};
3095     BuildAndRun(input_data, &output_data);
3096     EXPECT_THAT(GetSpanForData<float>(output_data[0]),
3097                 ElementsAreArray(ok_params[i].expected_output));
3098   }
3099 }
3100 
TEST_F(OpConverterTest,ConvertConv2D)3101 TEST_F(OpConverterTest, ConvertConv2D) {
3102   {
3103     // Input list is empty, should fail.
3104     NodeDef node_def = MakeNodeDef("my_conv2d", "Conv2D", {});
3105     RunValidationAndConversion(
3106         node_def, error::INVALID_ARGUMENT,
3107         "Conv2D got 0 inputs but expected 2, at my_conv2d");
3108   }
3109 
3110   // Get nodedef for Conv2D layer.
3111   auto get_conv2d_nodedef =
3112       [](std::vector<int> strides = {1, 1, 1, 1}, string padding = "SAME",
3113          string data_format = "NCHW", std::vector<int> dilations = {1, 1, 1, 1},
3114          bool is_conv2d_backprop_input = false) -> NodeDef {
3115     Scope s = Scope::NewRootScope();
3116     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
3117     auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT);
3118     if (is_conv2d_backprop_input) {
3119       auto input_sizes =
3120           ops::Placeholder(s.WithOpName("input_sizes"), DT_INT32);
3121       ops::Conv2DBackpropInput::Attrs attrs = ops::Conv2DBackpropInput::Attrs()
3122                                                   .DataFormat(data_format)
3123                                                   .Dilations(dilations);
3124       auto conv2d =
3125           ops::Conv2DBackpropInput(s.WithOpName("my_conv2d"), input_sizes,
3126                                    filter, input, strides, padding, attrs);
3127       return conv2d.operation.node()->def();
3128     } else {
3129       ops::Conv2D::Attrs attrs =
3130           ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations);
3131       auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter,
3132                                 strides, padding, attrs);
3133       return conv2d.operation.node()->def();
3134     }
3135   };
3136 
3137   {
3138     // Input is weights, should fail.
3139     Reset();
3140     NodeDef node_def = get_conv2d_nodedef();
3141     AddTestWeights<float>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
3142     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3143     RunValidationAndConversion(
3144         node_def, error::UNIMPLEMENTED,
3145         "The input \"input\" for Conv2D must be a tensor, at my_conv2d");
3146   }
3147   {
3148     // Filter is tensor, should fail.
3149     Reset();
3150     NodeDef node_def = get_conv2d_nodedef();
3151     AddTestTensor("input", {1, 2, 3});
3152     AddTestTensor("weights", {3, 3, 1, 1});
3153     RunValidationAndConversion(
3154         node_def, error::UNIMPLEMENTED,
3155         "The input \"filter\" for Conv2D must be a constant, at my_conv2d");
3156   }
3157   {
3158     // Filter is not 4D, should fail.
3159     Reset();
3160     NodeDef node_def = get_conv2d_nodedef();
3161     AddTestTensor("input", {1, 2, 3});
3162     AddTestWeights<float>("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3163     RunValidationAndConversion(
3164         node_def, error::INVALID_ARGUMENT,
3165         "Conv2D expects kernel of dimension 4, at my_conv2d");
3166   }
3167   {
3168     // Dilations is not 4D, should fail.
3169     Reset();
3170     NodeDef node_def =
3171         get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1});
3172     AddTestTensor("input", {1, 2, 3});
3173     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3174     RunValidationAndConversion(
3175         node_def, error::INVALID_ARGUMENT,
3176         "Convolution dilations field must specify 4 dimensions, at my_conv2d");
3177   }
3178   {
3179     // Dilation value is not 1 for channel, should fail.
3180     Reset();
3181     NodeDef node_def =
3182         get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1});
3183     AddTestTensor("input", {1, 2, 3});
3184     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3185     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
3186                                "Dilation rate must be 1 for batch and channel "
3187                                "dimensions, at my_conv2d");
3188   }
3189   {
3190     // Dilation value is not 1 for channel (NHWC), should fail.
3191     Reset();
3192     NodeDef node_def =
3193         get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2});
3194     AddTestTensor("input", {2, 3, 1});
3195     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3196     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
3197                                "Dilation rate must be 1 for batch and channel "
3198                                "dimensions, at my_conv2d");
3199   }
3200   {
3201     // Dilation + Conv2DBackpropInput, should fail.
3202     Reset();
3203     NodeDef node_def =
3204         get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 2, 1}, true);
3205     AddTestTensor("input", {2, 3, 1});
3206     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3207     AddTestWeights<int>("input_sizes", {4}, {1, 2, 3, 1});
3208     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
3209                                "Dilation with Conv2DBackpropInput "
3210                                "(conv2d_transpose) is not supported, "
3211                                "at my_conv2d");
3212   }
3213   {
3214     // Strides is not 4D, should fail.
3215     Reset();
3216     NodeDef node_def =
3217         get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
3218     AddTestTensor("input", {1, 2, 3});
3219     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3220     RunValidationAndConversion(
3221         node_def, error::INVALID_ARGUMENT,
3222         "Convolution strides field must specify 4 dimensions, at my_conv2d");
3223   }
3224   {
3225     // Stride value is not 1 for channel, should fail.
3226     Reset();
3227     NodeDef node_def =
3228         get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
3229     AddTestTensor("input", {1, 2, 3});
3230     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3231     RunValidationAndConversion(
3232         node_def, error::UNIMPLEMENTED,
3233         "Stride must be 1 for batch and channel dimensions, at my_conv2d");
3234   }
3235 
3236   struct TestParams {
3237     std::vector<int> input_dims;
3238     std::vector<float> input;
3239     std::vector<int> filter_dims;
3240     std::vector<float> filter;
3241     std::vector<int> strides;
3242     string padding;
3243     string data_format;
3244     std::vector<int> dilations;
3245     bool is_conv2d_backprop_input;
3246     std::vector<int> expected_output_dims;
3247     std::vector<float> expected_output;
3248   };
3249 
3250   // Ok.
3251   const int kConv2DOKCases = 7;
3252   TestParams ok_params[kConv2DOKCases] = {
3253       // Basic
3254       TestParams{/*input_dims=*/{1, 2, 3},
3255                  /*input=*/{0, 1, 2, 3, 3, 4},
3256                  /*filter_dims=*/{1, 2, 1, 1},
3257                  /*filter=*/{-1, 1},
3258                  /*strides=*/{1, 1, 1, 1},
3259                  /*padding=*/"VALID",
3260                  /*data_format=*/"NCHW",
3261                  /*dilations=*/{1, 1, 1, 1},
3262                  /*is_conv2d_backprop_input=*/false,
3263                  /*expected_output_dims=*/{1, 2, 2},
3264                  /*expected_output=*/{1, 1, 0, 1}},
3265       // SAME padding (Asymmetric)
3266       TestParams{/*input_dims=*/{1, 2, 3},
3267                  /*input=*/{0, 1, 2, 3, 3, 4},
3268                  /*filter_dims=*/{1, 2, 1, 1},
3269                  /*filter=*/{-1, 1},
3270                  /*strides=*/{1, 1, 1, 1},
3271                  /*padding=*/"SAME",
3272                  /*data_format=*/"NCHW",
3273                  /*dilations=*/{1, 1, 1, 1},
3274                  /*is_conv2d_backprop_input=*/false,
3275                  /*expected_output_dims=*/{1, 2, 3},
3276                  /*expected_output=*/{1, 1, -2, 0, 1, -4}},
3277       // SAME padding (Symmetric)
3278       TestParams{/*input_dims=*/{1, 2, 3},
3279                  /*input=*/{0, 1, 2, 3, 3, 4},
3280                  /*filter_dims=*/{1, 3, 1, 1},
3281                  /*filter=*/{-1, 0, 1},
3282                  /*strides=*/{1, 1, 1, 1},
3283                  /*padding=*/"SAME",
3284                  /*data_format=*/"NCHW",
3285                  /*dilations=*/{1, 1, 1, 1},
3286                  /*is_conv2d_backprop_input=*/false,
3287                  /*expected_output_dims=*/{1, 2, 3},
3288                  /*expected_output=*/{1, 2, -1, 3, 1, -3}},
3289       // NHWC
3290       TestParams{/*input_dims=*/{2, 3, 1},
3291                  /*input=*/{0, 1, 2, 3, 3, 4},
3292                  /*filter_dims=*/{1, 2, 1, 1},
3293                  /*filter=*/{-1, 1},
3294                  /*strides=*/{1, 1, 1, 1},
3295                  /*padding=*/"VALID",
3296                  /*data_format=*/"NHWC",
3297                  /*dilations=*/{1, 1, 1, 1},
3298                  /*is_conv2d_backprop_input=*/false,
3299                  /*expected_output_dims=*/{2, 2, 1},
3300                  /*expected_output=*/{1, 1, 0, 1}},
3301       // Dilated
3302       TestParams{/*input_dims=*/{1, 2, 3},
3303                  /*input=*/{0, 1, 2, 3, 3, 4},
3304                  /*filter_dims=*/{1, 2, 1, 1},
3305                  /*filter=*/{-1, 1},
3306                  /*strides=*/{1, 1, 1, 1},
3307                  /*padding=*/"VALID",
3308                  /*data_format=*/"NCHW",
3309                  /*dilations=*/{1, 1, 1, 2},
3310                  /*is_conv2d_backprop_input=*/false,
3311                  /*expected_output_dims=*/{1, 2, 1},
3312                  /*expected_output=*/{2, 1}},
3313       // Strided
3314       TestParams{/*input_dims=*/{1, 2, 4},
3315                  /*input=*/{0, 1, 2, 2, 3, 4, 4, 7},
3316                  /*filter_dims=*/{1, 2, 1, 1},
3317                  /*filter=*/{-1, 1},
3318                  /*strides=*/{1, 1, 1, 2},
3319                  /*padding=*/"VALID",
3320                  /*data_format=*/"NCHW",
3321                  /*dilations=*/{1, 1, 1, 1},
3322                  /*is_conv2d_backprop_input=*/false,
3323                  /*expected_output_dims=*/{1, 2, 2},
3324                  /*expected_output=*/{1, 0, 1, 3}},
3325       // Transpose Strided
3326       TestParams{/*input_dims=*/{1, 2, 2},
3327                  /*input=*/{0, 1, 2, 3},
3328                  /*filter_dims=*/{1, 2, 1, 1},
3329                  /*filter=*/{-1, 1},
3330                  /*strides=*/{1, 1, 1, 2},
3331                  /*padding=*/"SAME",
3332                  /*data_format=*/"NCHW",
3333                  /*dilations=*/{1, 1, 1, 1},
3334                  /*is_conv2d_backprop_input=*/true,
3335                  /*expected_output_dims=*/{1, 2, 4},
3336                  /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}},
3337   };
3338 
3339   for (int i = 0; i < kConv2DOKCases; i++) {
3340     Reset();
3341     NodeDef node_def = get_conv2d_nodedef(
3342         ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format,
3343         ok_params[i].dilations, ok_params[i].is_conv2d_backprop_input);
3344     AddTestTensor("input", ok_params[i].input_dims);
3345     AddTestWeights<float>("weights", ok_params[i].filter_dims,
3346                           ok_params[i].filter);
3347     if (ok_params[i].is_conv2d_backprop_input) {
3348       AddTestWeights<float>(
3349           "input_sizes",
3350           {static_cast<int>(ok_params[i].expected_output.size())},
3351           ok_params[i].expected_output);
3352     }
3353     RunValidationAndConversion(node_def);
3354     TRT_TensorOrWeights output;
3355     TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output));
3356     EXPECT_TRUE(output.is_tensor());
3357     ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
3358                              output.tensor()->getDimensions());
3359 
3360     const DataVec input_data{
3361         {"input", test::AsTensor<float>(ok_params[i].input)}};
3362     DataVec output_data{
3363         {"my_conv2d",
3364          ConstructTensor<float>(ok_params[i].expected_output.size())}};
3365     BuildAndRun(input_data, &output_data);
3366     EXPECT_THAT(GetSpanForData<float>(output_data[0]),
3367                 ElementsAreArray(ok_params[i].expected_output));
3368   }
3369 }
3370 
TEST_F(OpConverterTest, ConvertTopK) {
  {
    // A TopKV2 node without any input must be rejected.
    NodeDef empty_node_def = MakeNodeDef("my_topk", "TopKV2", {});
    RunValidationAndConversion(
        empty_node_def, error::INVALID_ARGUMENT,
        "TopKV2 got 0 inputs but expected 2, at my_topk");
  }

  for (const auto dtype : {DT_FLOAT, DT_INT32}) {
    // NodeDef of a TopK op named "my_topk" reading placeholder "input" (of
    // the dtype under test) and int32 placeholder "weights".
    Scope root = Scope::NewRootScope();
    auto input_op = ops::Placeholder(root.WithOpName("input"), dtype);
    auto k_op = ops::Placeholder(root.WithOpName("weights"), DT_INT32);
    auto topk_op = ops::TopK(root.WithOpName("my_topk"), input_op, k_op);
    const NodeDef& node_def = topk_op.operation.node()->def();

    {
      // Supplying "weights" as a tensor instead of a constant must fail.
      Reset();
      AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1,
                    /*trt_dtype=*/TfDataTypeToTrt(dtype));
      AddTestTensor("weights", {2});
      RunValidationAndConversion(
          node_def, error::UNIMPLEMENTED,
          "The input \"k\" for TopKV2 must be a constant, at my_topk");
    }
    {
      // Successful conversion with "weights" given as the constant 2.
      // NOTE(review): unlike the failure case above, the input tensor is
      // added without passing the loop's dtype, and the data below is always
      // float -- confirm whether that is intended for the DT_INT32 iteration.
      Reset();
      AddTestTensor("input", {1, 2, 5});
      AddTestWeights<int32>("weights", {1}, {2});
      RunValidationAndConversion(node_def);

      // Both outputs ("my_topk" and "my_topk:1") must be tensors of dims
      // {1, 2, 2}.
      TRT_TensorOrWeights outputs[2];
      TF_EXPECT_OK(GetTensorOrWeights("my_topk", &outputs[0]));
      TF_EXPECT_OK(GetTensorOrWeights("my_topk:1", &outputs[1]));
      for (int idx = 0; idx < 2; ++idx) {
        EXPECT_TRUE(outputs[idx].is_tensor());
        ExpectTrtDimsEqualsArray({1, 2, 2},
                                 outputs[idx].tensor()->getDimensions());
      }

      // Run the network and check the first output as floats and the second
      // as int32 values.
      const DataVec input_data{
          {"input", test::AsTensor<float>({-9, 3, 5, 1, 6, -5, 7, 1, 0, -1})}};
      DataVec output_data{{"my_topk", ConstructTensor<float>(4)},
                          {"my_topk:1", ConstructTensor<int32>(4)}};
      BuildAndRun(input_data, &output_data);
      EXPECT_THAT(GetSpanForData<float>(output_data[0]),
                  ElementsAre(6, 5, 7, 1));
      EXPECT_THAT(GetSpanForData<int32>(output_data[1]),
                  ElementsAre(4, 2, 1, 2));
    }
  }
}
3423 
3424 template <DataType dtype>
TestConvertGather(OpConverterTest * test)3425 void TestConvertGather(OpConverterTest* test) {
3426   typedef typename EnumToDataType<dtype>::Type CType;
3427 
3428   // Get the NodeDef for GatherV2.
3429   Scope s = Scope::NewRootScope();
3430   auto params = ops::Placeholder(s.WithOpName("params"), dtype);
3431   auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
3432   auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
3433   auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis);
3434   const NodeDef& node_def = gather.operation.node()->def();
3435 
3436   struct TestParams {
3437     std::vector<int> params_dims;
3438     std::vector<int> indices_dims;
3439     std::vector<int> indices;
3440     int axis;
3441     std::vector<int> expected_output_dims;
3442     std::vector<int> expected_output;
3443   };
3444 
3445   // Input is the same {1, 2, 3, 4, 5, 6} for all cases.
3446   const int kGatherOKCases = 5;
3447   const std::vector<CType> params_input = {CType(1), CType(2), CType(3),
3448                                            CType(4), CType(5), CType(6)};
3449   TestParams ok_params[kGatherOKCases] = {
3450       // Indices are always of rank>1, and output rank is
3451       // rank(params) + rank(indices) - 1.
3452       // TODO(laigd): do we support 0-rank ITensor as indices?
3453       TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1, 1}, {1, 4}},
3454       TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1, 1}, {2, 5}},
3455       TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1, 1}, {3, 6}},
3456       TestParams{
3457           {1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 1, 3}, {3, 1, 2, 6, 4, 5}},
3458       TestParams{{3, 2},
3459                  {2, 2},
3460                  {0, 0, 1, 0},
3461                  2,
3462                  {3, 1, 2, 2},
3463                  {1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}},
3464   };
3465 
3466   // Ok.
3467   for (int i = 0; i < kGatherOKCases; i++) {
3468     test->Reset();
3469     test->AddTestTensor("params", ok_params[i].params_dims, 1,
3470                         TfDataTypeToTrt(dtype));
3471     test->AddTestTensor("indices", ok_params[i].indices_dims, 1,
3472                         nvinfer1::DataType::kINT32);
3473     test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
3474     test->RunValidationAndConversion(node_def);
3475     TRT_TensorOrWeights output;
3476     TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output));
3477     EXPECT_TRUE(output.is_tensor());
3478     ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
3479                              output.tensor()->getDimensions());
3480 
3481     // Create input in CType and convert expected output to CType.
3482     std::vector<CType> converted_expected_output(
3483         ok_params[i].expected_output.begin(),
3484         ok_params[i].expected_output.end());
3485 
3486     const DataVec input_data{
3487         {"params", test::AsTensor<CType>(params_input)},
3488         {"indices", test::AsTensor<int32>(ok_params[i].indices)}};
3489     DataVec output_data{
3490         {"my_gather",
3491          ConstructTensor<CType>(ok_params[i].expected_output.size())}};
3492     test->BuildAndRun(input_data, &output_data);
3493     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
3494                 ElementsAreArray(converted_expected_output));
3495   }
3496 }
3497 
// Exercises the GatherV2 converter: first the inputs that must be rejected
// (no inputs, non-constant axis, out-of-range axis, batch-dimension axis),
// then the successful conversions for each params dtype covered by
// TestConvertGather.
TEST_F(OpConverterTest, ConvertGather) {
  // A GatherV2 node without any inputs is rejected.
  {
    const NodeDef empty_node_def = MakeNodeDef("my_gather", "GatherV2", {});
    RunValidationAndConversion(
        empty_node_def, error::INVALID_ARGUMENT,
        "GatherV2 got 0 inputs but expected 3, at my_gather");
  }

  // Build the GatherV2 NodeDef shared by the remaining cases.
  Scope root = Scope::NewRootScope();
  auto params_ph = ops::Placeholder(root.WithOpName("params"), DT_FLOAT);
  auto indices_ph = ops::Placeholder(root.WithOpName("indices"), DT_INT32);
  auto axis_ph = ops::Placeholder(root.WithOpName("axis"), DT_INT32);
  auto gather_op = ops::GatherV2(root.WithOpName("my_gather"), params_ph,
                                 indices_ph, axis_ph);
  const NodeDef& gather_def = gather_op.operation.node()->def();

  // Resets the converter and registers the "params" and "indices" tensors
  // that every failure case below starts from.
  const auto reset_with_params_and_indices = [this]() {
    Reset();
    AddTestTensor("params", {1, 2, 3});
    AddTestTensor("indices", {2});
  };

  // Failure: the axis is supplied as a tensor instead of a constant.
  {
    reset_with_params_and_indices();
    AddTestTensor("axis", {1});
    RunValidationAndConversion(
        gather_def, error::UNIMPLEMENTED,
        "The input \"axis\" for GatherV2 must be a constant, at my_gather");
  }

  // Failure: the axis value lies outside the accepted range.
  {
    reset_with_params_and_indices();
    AddTestWeights<int32>("axis", {1}, {4});
    RunValidationAndConversion(gather_def, error::INVALID_ARGUMENT,
                               "Axis value of 4 is out of bounds, must be in "
                               "range [-4, 4), at my_gather");
  }

  // Failure: the axis refers to the batch dimension.
  {
    reset_with_params_and_indices();
    AddTestWeights<int32>("axis", {1}, {0});
    RunValidationAndConversion(gather_def, error::UNIMPLEMENTED,
                               "TensorRT does not allow manipulation of the "
                               "batch dimension, at my_gather");
  }

  // Success cases, run once per params dtype.
  Reset();
  TestConvertGather<DT_FLOAT>(this);
  TestConvertGather<DT_HALF>(this);
  TestConvertGather<DT_INT32>(this);
}
3550 
TEST_F(OpConverterTest,ConvertUnary)3551 TEST_F(OpConverterTest, ConvertUnary) {
3552   {
3553     // Input list is empty, should fail.
3554     NodeDef node_def = MakeNodeDef("my_unary", "Neg", {});
3555     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
3556                                "Neg got 0 inputs but expected 1, at my_unary");
3557   }
3558   {
3559     // Input is weights, should fail.
3560     Reset();
3561     Scope s = Scope::NewRootScope();
3562     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
3563     auto neg = ops::Neg(s.WithOpName("my_unary"), input);
3564     const NodeDef& node_def = neg.operation.node()->def();
3565     AddTestWeights<float>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
3566     RunValidationAndConversion(
3567         node_def, error::UNIMPLEMENTED,
3568         "The input \"x\" for Neg must be a tensor, at my_unary");
3569   }
3570 
3571   // Get nodedef for unary layer.
3572   auto get_unary_nodedef = [](string op_name) -> NodeDef {
3573     Scope s = Scope::NewRootScope();
3574     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
3575     if (op_name == "Abs") {
3576       auto unary = ops::Abs(s.WithOpName("my_unary"), input);
3577       return unary.operation.node()->def();
3578     } else if (op_name == "Acos") {
3579       auto unary = ops::Acos(s.WithOpName("my_unary"), input);
3580       return unary.operation.node()->def();
3581     } else if (op_name == "Acosh") {
3582       auto unary = ops::Acosh(s.WithOpName("my_unary"), input);
3583       return unary.operation.node()->def();
3584     } else if (op_name == "Asin") {
3585       auto unary = ops::Asin(s.WithOpName("my_unary"), input);
3586       return unary.operation.node()->def();
3587     } else if (op_name == "Asinh") {
3588       auto unary = ops::Asinh(s.WithOpName("my_unary"), input);
3589       return unary.operation.node()->def();
3590     } else if (op_name == "Atan") {
3591       auto unary = ops::Atan(s.WithOpName("my_unary"), input);
3592       return unary.operation.node()->def();
3593     } else if (op_name == "Atanh") {
3594       auto unary = ops::Atanh(s.WithOpName("my_unary"), input);
3595       return unary.operation.node()->def();
3596     } else if (op_name == "Ceil") {
3597       auto unary = ops::Ceil(s.WithOpName("my_unary"), input);
3598       return unary.operation.node()->def();
3599     } else if (op_name == "Cos") {
3600       auto unary = ops::Cos(s.WithOpName("my_unary"), input);
3601       return unary.operation.node()->def();
3602     } else if (op_name == "Cosh") {
3603       auto unary = ops::Cosh(s.WithOpName("my_unary"), input);
3604       return unary.operation.node()->def();
3605     } else if (op_name == "Exp") {
3606       auto unary = ops::Exp(s.WithOpName("my_unary"), input);
3607       return unary.operation.node()->def();
3608     } else if (op_name == "Floor") {
3609       auto unary = ops::Floor(s.WithOpName("my_unary"), input);
3610       return unary.operation.node()->def();
3611     } else if (op_name == "Log") {
3612       auto unary = ops::Log(s.WithOpName("my_unary"), input);
3613       return unary.operation.node()->def();
3614     } else if (op_name == "Neg") {
3615       auto unary = ops::Neg(s.WithOpName("my_unary"), input);
3616       return unary.operation.node()->def();
3617     } else if (op_name == "Reciprocal") {
3618       auto unary = ops::Reciprocal(s.WithOpName("my_unary"), input);
3619       return unary.operation.node()->def();
3620     } else if (op_name == "Rsqrt") {
3621       auto unary = ops::Rsqrt(s.WithOpName("my_unary"), input);
3622       return unary.operation.node()->def();
3623     } else if (op_name == "Sin") {
3624       auto unary = ops::Sin(s.WithOpName("my_unary"), input);
3625       return unary.operation.node()->def();
3626     } else if (op_name == "Sinh") {
3627       auto unary = ops::Sinh(s.WithOpName("my_unary"), input);
3628       return unary.operation.node()->def();
3629     } else if (op_name == "Sqrt") {
3630       auto unary = ops::Sqrt(s.WithOpName("my_unary"), input);
3631       return unary.operation.node()->def();
3632     } else if (op_name == "Tan") {
3633       auto unary = ops::Tan(s.WithOpName("my_unary"), input);
3634       return unary.operation.node()->def();
3635     }
3636     EXPECT_TRUE(false);
3637     return NodeDef();
3638   };
3639   // Get expected output for unary layer.
3640   auto get_unary_output = [](string op_name, float input) -> float {
3641     if (op_name == "Abs") {
3642       return std::abs(input);
3643     } else if (op_name == "Acos") {
3644       return std::acos(input);
3645     } else if (op_name == "Acosh") {
3646       return std::acosh(input);
3647     } else if (op_name == "Asin") {
3648       return std::asin(input);
3649     } else if (op_name == "Asinh") {
3650       return std::asinh(input);
3651     } else if (op_name == "Atan") {
3652       return std::atan(input);
3653     } else if (op_name == "Atanh") {
3654       return std::atanh(input);
3655     } else if (op_name == "Ceil") {
3656       return std::ceil(input);
3657     } else if (op_name == "Cos") {
3658       return std::cos(input);
3659     } else if (op_name == "Cosh") {
3660       return std::cosh(input);
3661     } else if (op_name == "Exp") {
3662       return std::exp(input);
3663     } else if (op_name == "Floor") {
3664       return std::floor(input);
3665     } else if (op_name == "Log") {
3666       return std::log(input);
3667     } else if (op_name == "Neg") {
3668       return -input;
3669     } else if (op_name == "Reciprocal") {
3670       return 1.0 / input;
3671     } else if (op_name == "Rsqrt") {
3672       return 1.0 / std::sqrt(input);
3673     } else if (op_name == "Sin") {
3674       return std::sin(input);
3675     } else if (op_name == "Sinh") {
3676       return std::sinh(input);
3677     } else if (op_name == "Sqrt") {
3678       return std::sqrt(input);
3679     } else if (op_name == "Tan") {
3680       return std::tan(input);
3681     }
3682     EXPECT_TRUE(false);
3683     return 0;
3684   };
3685 
3686   // Get list of ops to test.
3687   std::vector<string> ops_to_test;
3688   // Add all ops supported by ConvertUnary.
3689   auto* map = UnaryOperationMap();
3690   ops_to_test.reserve(map->size());
3691   for (auto& pair : *map) {
3692     ops_to_test.push_back(pair.first);
3693   }
3694   // Add other unary ops to test.
3695   ops_to_test.push_back("Rsqrt");
3696   // Ok.
3697   for (string op_name : ops_to_test) {
3698     Reset();
3699     NodeDef node_def = get_unary_nodedef(op_name);
3700     AddTestTensor("input", {1, 2, 3});
3701     RunValidationAndConversion(node_def);
3702     TRT_TensorOrWeights output;
3703     TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output));
3704     EXPECT_TRUE(output.is_tensor());
3705     ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions());
3706 
3707     const std::vector<float> input = {-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f};
3708     const DataVec input_data{{"input", test::AsTensor<float>(input)}};
3709     DataVec output_data{{"my_unary", ConstructTensor<float>(6)}};
3710     BuildAndRun(input_data, &output_data);
3711     for (int i = 0; i < input.size(); ++i) {
3712       const float expected_output = get_unary_output(op_name, input[i]);
3713       EXPECT_THAT(GetSpanForData<float>(output_data[0])[i],
3714                   NanSensitiveFloatNear(expected_output, 0.0001));
3715     }
3716   }
3717 }
3718 
3719 }  // namespace convert
3720 }  // namespace tensorrt
3721 }  // namespace tensorflow
3722 
3723 #endif  // GOOGLE_TENSORRT
3724 #endif  // GOOGLE_CUDA
3725