/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"

#include <memory>
#include <unordered_map>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/core/framework/node_def.pb.h"  // NOLINT
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
#include "tensorflow/core/public/session.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda.h"
#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {

using absl::StrCat;
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::NanSensitiveFloatNear;

// TODO(laigd): put this into some test utils file.
void ExpectStatus(Status status, error::Code code = error::OK,
                  const char* substr = nullptr) {
  EXPECT_EQ(code, status.code())
      << status << " vs expected error code \"" << error::Code_Name(code)
      << "\" and message \"" << substr << "\"";
  if (substr) {
    EXPECT_THAT(status.error_message(), ::testing::HasSubstr(substr)) << status;
  }
}

nvinfer1::Dims GetTestDims(const std::vector<int>& d) {
  nvinfer1::Dims dims;
  dims.nbDims = d.size();
  for (int i = 0; i < d.size(); ++i) {
    dims.d[i] = d[i];
  }
  return dims;
}

nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) {
  switch (tf_dtype) {
    case DT_FLOAT:
      return nvinfer1::DataType::kFLOAT;
    case DT_HALF:
      return nvinfer1::DataType::kHALF;
    case DT_INT32:
      return nvinfer1::DataType::kINT32;
    default:
      QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype);
  }
}

DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) {
  switch (trt_dtype) {
    case nvinfer1::DataType::kFLOAT:
      return DT_FLOAT;
    case nvinfer1::DataType::kHALF:
      return DT_HALF;
    case nvinfer1::DataType::kINT32:
      return DT_INT32;
    default:
      QCHECK(false) << "Unexpected data type " << static_cast<int>(trt_dtype);
  }
}

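// Builds a NodeDef with the given name, op type, inputs and attributes, for
// feeding the validator and converter in the tests below.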
NodeDef MakeNodeDef(const string& name, const string& op,
                    const std::vector<string>& inputs,
                    const std::map<string, AttrValue> attrs = {}) {
  NodeDef node_def;
  node_def.set_name(name);
  node_def.set_op(op);
  for (const string& input : inputs) {
    node_def.add_input(input);
  }
  for (const auto& attr : attrs) {
    (*node_def.mutable_attr())[attr.first] = attr.second;
  }
  return node_def;
}

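// Builds a Const NodeDef whose value is a tensor of `vals` with the given
// shape (the overload below derives a 1-D shape from the value count).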
template <typename T>
NodeDef MakeConstNodeDef(const string& name, const std::vector<T>& vals,
                         const TensorShape& shape) {
  Scope s = Scope::NewRootScope();
  Tensor t = test::AsTensor<T>(vals, shape);
  auto const_op = ops::Const(s.WithOpName(name), t);
  return const_op.node()->def();
}

template <typename T>
NodeDef MakeConstNodeDef(const string& name, const std::vector<T>& vals) {
  TensorShape shape;
  const std::vector<int32> shape_dims = {static_cast<int32>(vals.size())};
  TF_EXPECT_OK(TensorShapeUtils::MakeShape(shape_dims, &shape));
  return MakeConstNodeDef(name, vals, shape);
}

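// Returns true if the two TRT dims have the same rank and dimension sizes.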
bool TrtDimsEquals(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs) {
  if (lhs.nbDims != rhs.nbDims) return false;
  for (int i = 0; i < lhs.nbDims; ++i) {
    if (lhs.d[i] != rhs.d[i]) return false;
    // We don't check the types in the tests.
  }
  return true;
}

bool TrtDimsEqualsArray(const std::vector<int>& lhs,
                        const nvinfer1::Dims& rhs) {
  return TrtDimsEquals(GetTestDims(lhs), rhs);
}

// TODO(laigd): define a parameterized matcher that can compare against the
// vector.
void ExpectTrtDimsEqualsArray(const std::vector<int>& lhs,
                              const nvinfer1::Dims& rhs) {
  EXPECT_TRUE(TrtDimsEqualsArray(lhs, rhs))
      << "expected: " << DebugString(GetTestDims(lhs)) << "\n"
      << "  actual: " << DebugString(rhs);
}

template <typename T>
void ExpectArrayNear(const std::vector<T>& lhs, absl::Span<const T> rhs) {
  ASSERT_EQ(lhs.size(), rhs.size());
  for (int i = 0; i < lhs.size(); i++) {
    EXPECT_FLOAT_EQ(lhs[i], rhs[i]);
  }
}

// Eigen::half cannot implicitly convert to float which is required for
// EXPECT_FLOAT_EQ.
template <>
void ExpectArrayNear(const std::vector<Eigen::half>& lhs,
                     absl::Span<const Eigen::half> rhs) {
  ASSERT_EQ(lhs.size(), rhs.size());
  for (int i = 0; i < lhs.size(); i++) {
    EXPECT_FLOAT_EQ(Eigen::half_impl::half_to_float(lhs[i]),
                    Eigen::half_impl::half_to_float(rhs[i]));
  }
}

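// Returns true if the two weights have the same shape and type and point to
// the same underlying value buffer.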
bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs,
                            const TRT_ShapedWeights& rhs) {
  return TrtDimsEquals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ &&
         lhs.GetValues() == rhs.GetValues();
}

template <typename T>
void ValidateWeights(const TRT_ShapedWeights& weights,
                     const std::vector<int>& expected_dims,
                     const std::vector<T>& expected_value) {
  ExpectTrtDimsEqualsArray(expected_dims, weights.shape_);
  ASSERT_EQ(expected_value.size(), weights.count()) << weights.DebugString();
  const T* actual_values = static_cast<const T*>(weights.GetValues());
  for (int i = 0; i < expected_value.size(); ++i) {
    EXPECT_EQ(expected_value[i], actual_values[i]);
  }
}

// Fake ITensor implementation for testing purposes.
class FakeITensor : public nvinfer1::ITensor {
 public:
  FakeITensor() : dynamic_range_(0.0f) {}

  FakeITensor(const nvinfer1::Dims& dims) : dims_(dims), dynamic_range_(0.0f) {}

  FakeITensor(const std::vector<int>& dims)
      : dims_(GetTestDims(dims)), dynamic_range_(0.0f) {}

  void setName(const char* name) override { name_ = name; }

  const char* getName() const override { return name_.c_str(); }

  void setDimensions(nvinfer1::Dims dimensions) override { dims_ = dimensions; }

  nvinfer1::Dims getDimensions() const override { return dims_; }

  void setType(nvinfer1::DataType type) override { type_ = type; }

  nvinfer1::DataType getType() const override { return type_; }

  bool isNetworkInput() const override { return false; }

  bool isNetworkOutput() const override { return false; }

  void setBroadcastAcrossBatch(bool broadcastAcrossBatch) override {}

  bool getBroadcastAcrossBatch() const override { return false; }

  nvinfer1::TensorLocation getLocation() const override { return location_; }

  void setLocation(nvinfer1::TensorLocation location) override {
    location_ = location;
  }

#if IS_TRT_VERSION_GE(5, 0, 0)
  bool setDynamicRange(float min, float max) override {
    dynamic_range_ = std::max(std::abs(min), std::abs(max));
    return true;
  }

  float getDynamicRange() const override { return dynamic_range_; }
#endif

#if IS_TRT_VERSION_GE(5, 1, 0)
  bool dynamicRangeIsSet() const override { return true; }

  void resetDynamicRange() override {}

  float getDynamicRangeMin() const override { return 0.f; }

  float getDynamicRangeMax() const override { return 0.f; }
#endif

 private:
  string name_;
  nvinfer1::Dims dims_;
  nvinfer1::DataType type_;
  nvinfer1::TensorLocation location_;
  float dynamic_range_;
};

TEST(TRT_ShapedWeights_Test, Basic) {
  // Test constructor with no arguments.
  {
    TRT_ShapedWeights weights;
    TRT_ShapedWeights copy(weights);
    for (auto ptr : {&weights, &copy}) {
      nvinfer1::Weights trt_weights = ptr->GetTrtWeights();
      EXPECT_EQ(nvinfer1::DataType::kFLOAT, trt_weights.type);
      EXPECT_EQ(nullptr, trt_weights.values);
      EXPECT_EQ(0, trt_weights.count);

      EXPECT_EQ(nullptr, ptr->GetValues());
      EXPECT_EQ(0, ptr->count());
      EXPECT_EQ(0, ptr->size_bytes());
    }
  }
  // Test constructor with DataType argument.
  {
    TRT_ShapedWeights weights(DT_FLOAT);
    TRT_ShapedWeights copy(weights);
    for (auto ptr : {&weights, &copy}) {
      nvinfer1::Weights trt_weights = ptr->GetTrtWeights();
      EXPECT_EQ(nvinfer1::DataType::kFLOAT, trt_weights.type);
      EXPECT_EQ(nullptr, trt_weights.values);
      EXPECT_EQ(0, trt_weights.count);

      EXPECT_EQ(nullptr, ptr->GetValues());
      EXPECT_EQ(0, ptr->count());
      EXPECT_EQ(0, ptr->size_bytes());
    }
  }
  // Test constructor with DataType and nvinfer1::Dims arguments.
  {
    TrtWeightStore store;
    TRT_ShapedWeights weights =
        store.GetTempWeights(DT_FLOAT, GetTestDims({2, 5}));
    TRT_ShapedWeights copy(weights);
    for (auto ptr : {&weights, &copy}) {
      nvinfer1::Weights trt_weights = ptr->GetTrtWeights();
      EXPECT_EQ(nvinfer1::DataType::kFLOAT, trt_weights.type);
      EXPECT_NE(nullptr, trt_weights.values);
      EXPECT_EQ(10, trt_weights.count);

      EXPECT_EQ(trt_weights.values, ptr->GetValues());
      EXPECT_EQ(10, ptr->count());
      EXPECT_EQ(40, ptr->size_bytes());
    }
    // Test that it doesn't copy the underlying buffer.
    EXPECT_EQ(weights.GetValues(), copy.GetValues());
  }
}

TEST(TRT_TensorOrWeights_Test, Basic) {
  // Test constructor with no arguments.
  {
    TRT_TensorOrWeights tw;
    TRT_TensorOrWeights copy(tw);
    TRT_TensorOrWeights assigned;
    assigned = tw;
    for (auto ptr : {&tw, &copy, &assigned}) {
      EXPECT_EQ(false, ptr->is_tensor());
      EXPECT_EQ(false, ptr->is_weights());
      EXPECT_EQ(-1, ptr->batch_size());
    }
  }

  // Test constructor with ITensor and batch size argument.
  {
    nvinfer1::Dims dims;
    dims.nbDims = 1;
    dims.d[0] = 1;
    FakeITensor itensor(dims);
    TRT_TensorOrWeights tw(&itensor);
    TRT_TensorOrWeights tw1(&itensor, /*batch_size=*/1);

    for (auto original_ptr : {&tw, &tw1}) {
      TRT_TensorOrWeights copy(*original_ptr);
      TRT_TensorOrWeights assigned;
      assigned = *original_ptr;

      for (auto ptr : {original_ptr, &copy, &assigned}) {
        EXPECT_EQ(true, ptr->is_tensor());
        EXPECT_EQ(false, ptr->is_weights());
        if (original_ptr == &tw) {
          EXPECT_EQ(-1, ptr->batch_size());
        } else {
          EXPECT_EQ(1, ptr->batch_size());
        }
        EXPECT_EQ(&itensor, ptr->tensor());
        ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims());
      }
    }
  }
  // Test constructor which creates and owns an ITensor.
  {
    nvinfer1::Dims dims;
    dims.nbDims = 1;
    dims.d[0] = 1;
    TRT_TensorOrWeights tw(nvinfer1::DataType::kFLOAT, dims, /*batch_size=*/1);
    TRT_TensorOrWeights copy(tw);
    TRT_TensorOrWeights assigned;
    assigned = tw;

    for (auto ptr : {&tw, &copy, &assigned}) {
      EXPECT_EQ(true, ptr->is_tensor());
      EXPECT_EQ(false, ptr->is_weights());
      EXPECT_EQ(1, ptr->batch_size());
      EXPECT_NE(nullptr, ptr->tensor());
      ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims());
    }
  }
  // Test constructor with TRT_ShapedWeights argument.
  {
    TRT_ShapedWeights weights;
    TRT_TensorOrWeights tw(weights);
    TRT_TensorOrWeights copy(tw);
    TRT_TensorOrWeights assigned;
    assigned = tw;
    for (auto ptr : {&tw, &copy, &assigned}) {
      EXPECT_EQ(false, ptr->is_tensor());
      EXPECT_EQ(true, ptr->is_weights());
      EXPECT_TRUE(TrtShapedWeightsEquals(weights, ptr->weights()));
      ExpectTrtDimsEqualsArray({}, ptr->GetTrtDims());
    }
  }
}

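// Test fixture that exposes private members of TrtNodeValidator for the tests
// below.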
class ValidatorTest : public ::testing::Test {
 public:
  std::unordered_map<string, OpConverter>& op_validators() {
    return validator_.op_validators_;
  }

  Status ConvertToTensorOrWeights(
      const NodeDef& node_def, int output_port,
      const grappler::GraphProperties& graph_properties,
      TRT_TensorOrWeights* tensor_or_weights) {
    return validator_.ConvertToTensorOrWeights(
        node_def, output_port, graph_properties, tensor_or_weights);
  }

  const std::set<string>* GetQuantizeOps() { return validator_.quantize_ops; }

 protected:
  TrtNodeValidator validator_;
};

TEST_F(ValidatorTest, QuantizeOpsAreRegistered) {
  for (const string& quantize_op : *GetQuantizeOps()) {
    QCHECK(op_validators().count(quantize_op));
  }
}

TEST_F(ValidatorTest, ConvertToTensorOrWeights) {
  // Convert Const.
  {
    NodeDef node_def = MakeConstNodeDef<float>("my_const", {1.0f, 2.0f});
    TRT_TensorOrWeights output;
    grappler::GrapplerItem item;
    grappler::GraphProperties graph_properties(item);
    ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0,
                                          graph_properties, &output));
    ValidateWeights<float>(output.weights(), {2}, {1.0, 2.0});
  }

  // Helper method to run ConvertToTensorOrWeights() with predefined parameters.
  auto convert_to_tensor_or_weights = [this](const std::vector<int64>& dims,
                                             TRT_TensorOrWeights* output) {
    Scope s = Scope::NewRootScope();
    const auto attrs = ops::Placeholder::Shape(PartialTensorShape{dims});
    auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, attrs);
    auto add = ops::Add(s.WithOpName("add"), feed, feed);

    grappler::GrapplerItem item;
    TF_EXPECT_OK(s.ToGraphDef(&item.graph));
    grappler::GraphProperties graph_properties(item);
    TF_EXPECT_OK(graph_properties.InferStatically(true));
    const NodeDef& node_def = add.operation.node()->def();
    return this->ConvertToTensorOrWeights(node_def, /*output_port=*/0,
                                          graph_properties, output);
  };
  // Convert non-Const with #dims > nvinfer1::Dims::MAX_DIMS+1.
  {
    TRT_TensorOrWeights output;
    ExpectStatus(
        convert_to_tensor_or_weights(
            std::vector<int64>(nvinfer1::Dims::MAX_DIMS + 2, 1), &output),
        error::OUT_OF_RANGE, "Input tensor rank is greater than 9");
  }
  // Convert non-Const with #dims < 2.
  {
    TRT_TensorOrWeights output;
    ExpectStatus(
        convert_to_tensor_or_weights({1}, &output), error::INVALID_ARGUMENT,
        "Input tensor with rank<2 is not supported since the first dimension "
        "is treated as batch dimension by TRT");
  }
  // Convert non-Const. We test the case where the non-batch dimension is
  // unknown as well, to make sure the validator allows that.
  for (const int32 non_batch_dim : {-1, 2}) {
    const int32 batch_size = 12;
    TRT_TensorOrWeights output;
    ExpectStatus(
        convert_to_tensor_or_weights({batch_size, non_batch_dim}, &output));
    EXPECT_EQ(true, output.is_tensor());
    EXPECT_EQ(batch_size, output.batch_size());
    EXPECT_NE(nullptr, output.tensor());
    ExpectTrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims());
  }
}

TEST_F(ValidatorTest, ValidateNode) {
  grappler::GrapplerItem item;
  grappler::GraphProperties graph_properties(item);

  bool start_conversion = false;
  bool should_fail = false;
  auto op_converter = [&start_conversion,
                       &should_fail](OpConverterParams* params) -> Status {
    if (should_fail) return errors::InvalidArgument("");
    if (!params->validation_only) start_conversion = true;
    return Status::OK();
  };
  NodeDef node_def = MakeNodeDef("my_op", "MyOp", {});

  // Validator not registered.
  ExpectStatus(validator_.ValidateNode(node_def, {}, TrtPrecisionMode::FP32,
                                       graph_properties),
               error::UNIMPLEMENTED, "Op type MyOp is not supported.");

  // Register validator.
  op_validators()["MyOp"] = op_converter;
  TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, TrtPrecisionMode::FP32,
                                       graph_properties));
  EXPECT_EQ(false, start_conversion);

  // Let the converter return error.
  should_fail = true;
  ExpectStatus(validator_.ValidateNode(node_def, {}, TrtPrecisionMode::FP32,
                                       graph_properties),
               error::INVALID_ARGUMENT);

  // Test quantization ops, they're only supported in INT8 mode. The success
  // case is tested in OpConverterTest.ConvertQuantize.
  node_def = MakeNodeDef("my_op", "FakeQuantWithMinMaxArgs", {});
  ExpectStatus(validator_.ValidateNode(node_def, {}, TrtPrecisionMode::FP32,
                                       graph_properties),
               error::UNIMPLEMENTED,
               "Op type FakeQuantWithMinMaxArgs is not supported.");
}

class ConverterTest : public ::testing::Test {
 public:
  ConverterTest() {
    builder_.reset(nvinfer1::createInferBuilder(logger_));
    network_.reset(builder_->createNetwork());
    converter_.reset(new Converter(network_.get(), TrtPrecisionMode::FP32,
                                   /*use_calibration=*/false));
    weight_store_ = &converter_->weight_store_;
  }

  void AddOpConverter(const string& op_name, OpConverter op_converter) {
    converter_->op_registry_[op_name] = op_converter;
  }

  // Below we expose private methods of Converter for testing.

  Status MaybeUpdateBatchSize(int batch_size) {
    return converter_->MaybeUpdateBatchSize(batch_size);
  }

  Status AddTensorOrWeights(const string& name, TRT_TensorOrWeights input) {
    return converter_->AddTensorOrWeights(name, input);
  }

  Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output) {
    return converter_->GetTensorOrWeights(name, output);
  }

  Status GetInputs(const NodeDef& node_def,
                   std::vector<TRT_TensorOrWeights>* inputs) const {
    return converter_->GetInputs(node_def, inputs);
  }

  Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min,
                        float* out_max) const {
    return converter_->GetWeightRange(weights, out_min, out_max);
  }

  void PropagateQuantizationRanges() {
    converter_->PropagateQuantizationRanges();
  }

  int batch_size() const { return converter_->batch_size_; }

  std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
    return converter_->quantization_ranges_;
  }

 private:
  Logger logger_;
  // These members are ordered in a way such that the destruction order is:
  // converter_ -> network_ -> builder_
  TrtUniquePtrType<nvinfer1::IBuilder> builder_;
  TrtUniquePtrType<nvinfer1::INetworkDefinition> network_;

 protected:
  std::unique_ptr<Converter> converter_;
  TrtWeightStore* weight_store_;
};

TEST_F(ConverterTest, ConvertNode) {
  FakeITensor output_tensors[2];
  auto op_converter = [&output_tensors](OpConverterParams* params) -> Status {
    nvinfer1::Dims dims = params->inputs[0].tensor()->getDimensions();
    for (int i = 0; i < 2; ++i) {
      dims.d[0] += 1;
      output_tensors[i].setDimensions(dims);
      params->outputs->push_back(TRT_TensorOrWeights(&output_tensors[i]));
    }
    return Status::OK();
  };
  NodeDef node_def = MakeNodeDef("my_op", "MyOp", {"my_input"});
  TF_EXPECT_OK(converter_->AddInputTensor(
      "my_input", nvinfer1::DataType::kFLOAT, GetTestDims({123}), 1));

  // Converter not registered.
  ExpectStatus(converter_->ConvertNode(node_def), error::UNIMPLEMENTED,
               "No converter registered for op: MyOp");

  // Register the converter and retry.
  AddOpConverter("MyOp", op_converter);
  TF_EXPECT_OK(converter_->ConvertNode(node_def));

  TRT_TensorOrWeights actual_output_1;
  TF_EXPECT_OK(GetTensorOrWeights("my_op", &actual_output_1));
  EXPECT_EQ(&output_tensors[0], actual_output_1.tensor());
  EXPECT_EQ(124, actual_output_1.tensor()->getDimensions().d[0]);

  TRT_TensorOrWeights actual_output_2;
  TF_EXPECT_OK(GetTensorOrWeights("my_op:1", &actual_output_2));
  EXPECT_EQ(&output_tensors[1], actual_output_2.tensor());
  EXPECT_EQ(125, actual_output_2.tensor()->getDimensions().d[0]);
}

TEST_F(ConverterTest, AddAndGetInputs) {
  NodeDef node_def;
  node_def.add_input("^control_input");
  node_def.add_input("input");
  node_def.add_input("input:0");
  node_def.add_input("input:1");
  node_def.add_input("weird_input:2:3:4:0");

  TF_EXPECT_OK(converter_->AddInputTensor("input", nvinfer1::DataType::kFLOAT,
                                          GetTestDims({1}), 1));
  TF_EXPECT_OK(converter_->AddInputTensor("input:1", nvinfer1::DataType::kINT32,
                                          GetTestDims({2, 3}), 1));
  TF_EXPECT_OK(converter_->AddInputTensor(
      "weird_input:2:3:4", nvinfer1::DataType::kHALF, GetTestDims({5, 3}), 1));

  std::vector<TRT_TensorOrWeights> inputs;
  TF_EXPECT_OK(GetInputs(node_def, &inputs));

  EXPECT_EQ(4, inputs.size());
  EXPECT_EQ(inputs[0].tensor(), inputs[1].tensor());

  EXPECT_EQ(nvinfer1::DataType::kFLOAT, inputs[0].tensor()->getType());
  EXPECT_EQ(nvinfer1::DataType::kINT32, inputs[2].tensor()->getType());
  EXPECT_EQ(nvinfer1::DataType::kHALF, inputs[3].tensor()->getType());
  ExpectTrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions());
  ExpectTrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions());
  ExpectTrtDimsEqualsArray({5, 3}, inputs[3].tensor()->getDimensions());
}

TEST_F(ConverterTest, RenameAndMarkOutputTensors) {
  // Test that the tensors are actually named and marked as output after
  // Converter::RenameAndMarkOutputTensors() is called.

  // Register a custom converter which shuffles the input. We use it to build a
  // TRT network whose output will be later marked.
  std::vector<nvinfer1::ITensor*> output_tensors;
  auto op_converter = [&output_tensors](OpConverterParams* params) -> Status {
    nvinfer1::Permutation perm;
    perm.order[0] = 1;
    perm.order[1] = 0;
    for (int i = 0; i < 2; ++i) {
      nvinfer1::ITensor* input_tensor =
          const_cast<nvinfer1::ITensor*>(params->inputs[0].tensor());
      nvinfer1::IShuffleLayer* layer =
          params->converter->network()->addShuffle(*input_tensor);
      layer->setFirstTranspose(perm);
      nvinfer1::ITensor* output_tensor = layer->getOutput(0);
      params->outputs->emplace_back(output_tensor);
      output_tensors.push_back(output_tensor);
    }
    TRT_ShapedWeights output_weights(DT_FLOAT);
    params->outputs->emplace_back(output_weights);
    return Status::OK();
  };
  AddOpConverter("MyOp", op_converter);

  // Run the conversion.
  NodeDef node_def = MakeNodeDef("my_op", "MyOp", {"my_input"});
  TF_EXPECT_OK(converter_->AddInputTensor(
      "my_input", nvinfer1::DataType::kFLOAT, GetTestDims({1, 2}), 1));
  TF_EXPECT_OK(converter_->ConvertNode(node_def));

  // Mark a weight as output, should fail.
  ExpectStatus(
      converter_->RenameAndMarkOutputTensors({{"my_op:2", "my_output"}}),
      error::INVALID_ARGUMENT, "Output my_op:2 is weights not tensor");

  // Mark tensors as output, should pass.
  TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(
      {{"my_op", "my_output"}, {"my_op:1", "my_output_1"}}));
  EXPECT_EQ(2, output_tensors.size());
  for (auto output_tensor : output_tensors) {
    ExpectTrtDimsEqualsArray({2, 1}, output_tensor->getDimensions());
  }
  EXPECT_EQ("my_output", string(output_tensors[0]->getName()));
  EXPECT_EQ("my_output_1", string(output_tensors[1]->getName()));
}

TEST_F(ConverterTest, TransposeTensor) {
  nvinfer1::ITensor* input_tensor = converter_->network()->addInput(
      "", nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
  const nvinfer1::ITensor* output_tensor = nullptr;

  // Rank doesn't match.
  ExpectStatus(
      converter_->TransposeTensor(input_tensor, {0, 1}, &output_tensor),
      error::INVALID_ARGUMENT,
      "Rank of perm for transpose does not match with that of the input");

  // Transpose at batch dimension.
  ExpectStatus(
      converter_->TransposeTensor(input_tensor, {1, 0, 2, 3}, &output_tensor),
      error::UNIMPLEMENTED, "Transpose at batch dimension is not supported.");

  // OK.
  TF_EXPECT_OK(
      converter_->TransposeTensor(input_tensor, {0, 3, 1, 2}, &output_tensor));
  ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions());
}

TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
  nvinfer1::ITensor* input_tensor = converter_->network()->addInput(
      "", nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
  TRT_TensorOrWeights tw(input_tensor);
  const nvinfer1::ITensor* output_tensor = nullptr;

  for (bool validation_only : {false, true}) {
    // Shape size doesn't match.
    ExpectStatus(
        converter_->PrepareTensorForShape(tw, GetTestDims({2, 3, 6}),
                                          validation_only, &output_tensor),
        error::INVALID_ARGUMENT, "Reshape shapes are not compatible");

    // TODO(aaroey): we should check the case where uninferred dimensions are
    // not an exact divisor of input dimensions, e.g. for dims {-1, 7}.

    // Infer shape, ok.
    TF_EXPECT_OK(converter_->PrepareTensorForShape(
        tw, GetTestDims({-1, 2}), validation_only, &output_tensor));
    if (validation_only) {
      EXPECT_EQ(nullptr, output_tensor);
    } else {
      ExpectTrtDimsEqualsArray({15, 2}, output_tensor->getDimensions());
    }

    // Regular shape.
    TF_EXPECT_OK(converter_->PrepareTensorForShape(
        tw, GetTestDims({10, 3}), validation_only, &output_tensor));
    if (validation_only) {
      EXPECT_EQ(nullptr, output_tensor);
    } else {
      ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions());
    }
  }
}

TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
  TRT_ShapedWeights weights =
      weight_store_->GetTempWeights(DT_FLOAT, GetTestDims({2, 3, 5}));
  TRT_TensorOrWeights tw(weights);
  const nvinfer1::ITensor* output_tensor = nullptr;
  for (bool validation_only : {false, true}) {
    TF_EXPECT_OK(converter_->PrepareTensorForShape(
        tw, GetTestDims({10, 3}), validation_only, &output_tensor));
    if (validation_only) {
      EXPECT_EQ(nullptr, output_tensor);
    } else {
      ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions());
    }
  }
}

TEST_F(ConverterTest, MaybeUpdateBatchSize) {
  EXPECT_EQ(-1, batch_size());

  TF_EXPECT_OK(MaybeUpdateBatchSize(-1));
  EXPECT_EQ(-1, batch_size());

  TF_EXPECT_OK(MaybeUpdateBatchSize(123));
  EXPECT_EQ(123, batch_size());

  TF_EXPECT_OK(MaybeUpdateBatchSize(123));
  EXPECT_EQ(123, batch_size());

  TF_EXPECT_OK(MaybeUpdateBatchSize(-1));
  EXPECT_EQ(123, batch_size());

  ExpectStatus(MaybeUpdateBatchSize(124), error::INVALID_ARGUMENT,
               "Provided batch size does not match converter batch size");
}

TEST_F(ConverterTest, AddAndGetTensorOrWeights) {
  // Add a tensor.
  FakeITensor fake_tensor;
  TRT_TensorOrWeights tensor(&fake_tensor);
  EXPECT_EQ(-1, tensor.batch_size());
  TF_EXPECT_OK(MaybeUpdateBatchSize(123));
  TF_EXPECT_OK(AddTensorOrWeights("my_tensor", tensor));

  // Get the added tensor.
  TRT_TensorOrWeights added_tensor;
  TF_EXPECT_OK(GetTensorOrWeights("my_tensor", &added_tensor));
  EXPECT_EQ(123, added_tensor.batch_size());

  // Add the same tensor again.
  ExpectStatus(AddTensorOrWeights("my_tensor", tensor), error::ALREADY_EXISTS,
               "tensor/weights my_tensor already exist");
}

template <typename T>
void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) {
  TRT_ShapedWeights weights =
      weight_store->GetTempWeights(DataTypeToEnum<T>::v(), GetTestDims({2, 3}));
  const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)};
  memcpy(const_cast<void*>(weights.GetValues()), values.data(),
         weights.size_bytes());

  float out_min = 0.0f;
  float out_max = 0.0f;
  TF_EXPECT_OK(test->GetWeightRange(weights, &out_min, &out_max));
  EXPECT_EQ(1.0f, out_min);
  EXPECT_EQ(6.0f, out_max);
}

TEST_F(ConverterTest, GetWeightRange) {
  TestGetWeightRange<float>(this, weight_store_);
  TestGetWeightRange<Eigen::half>(this, weight_store_);
  TestGetWeightRange<int32>(this, weight_store_);
}

TEST_F(ConverterTest, ProvideQuantizationRange) {
  FakeITensor fake_tensor;
  // Asymmetric range
  converter_->ProvideQuantizationRange(&fake_tensor, 0.0f, 6.0f);
  EXPECT_EQ(6.0f, quantization_ranges()[&fake_tensor]);
  converter_->ProvideQuantizationRange(&fake_tensor, 1.0f, 6.0f);
  EXPECT_EQ(6.0f, quantization_ranges()[&fake_tensor]);
  converter_->ProvideQuantizationRange(&fake_tensor, -8.0f, 6.0f);
  EXPECT_EQ(8.0f, quantization_ranges()[&fake_tensor]);
  converter_->ProvideQuantizationRange(&fake_tensor, -8.123f, -6.123f);
  EXPECT_EQ(8.123f, quantization_ranges()[&fake_tensor]);
  // Symmetric range
  converter_->ProvideQuantizationRange(&fake_tensor, -6.123f, 6.123f);
  EXPECT_EQ(6.123f, quantization_ranges()[&fake_tensor]);
}

TEST_F(ConverterTest, MaybeApplyQuantizationRanges) {
  // input -> infer1 -> infer2 -> infer3
  FakeITensor input, infer_1, infer_2, infer_3;
  FakeITensor not_infer;
  Converter int8_converter(/*trt_network=*/nullptr, TrtPrecisionMode::INT8,
                           /*use_calibration=*/true);
  int8_converter.ProvideQuantizationRange(&input, -5.0f, 5.0f);
  int8_converter.ProvideQuantizationRange(&not_infer, -100.0f, 100.0f);
  int8_converter.MarkQuantizationRangesAsInferrable(&input, &infer_1);
  int8_converter.MarkQuantizationRangesAsInferrable(&infer_1, &infer_2);
  int8_converter.MarkQuantizationRangesAsInferrable(&infer_2, &infer_3);

  // Input range should be inferred along the chain and applied to tensors.
  int8_converter.MaybeApplyQuantizationRanges();
#if IS_TRT_VERSION_GE(5, 0, 0)
  EXPECT_EQ(input.getDynamicRange(), 5.0f);
  EXPECT_EQ(infer_1.getDynamicRange(), 5.0f);
  EXPECT_EQ(infer_2.getDynamicRange(), 5.0f);
  EXPECT_EQ(infer_3.getDynamicRange(), 5.0f);
  EXPECT_EQ(not_infer.getDynamicRange(), 100.0f);
#endif
}

TEST_F(ConverterTest, PropagateQuantizationRanges) {
  // infer0 <-> infer1 <-> infer2 <-> infer3
  //              |
  //              infer4 <-> infer5
  FakeITensor infer[6];
  FakeITensor not_infer;
  converter_->ProvideQuantizationRange(&infer[4], -5.0f, 5.0f);
  converter_->MarkQuantizationRangesAsInferrable(&infer[0], &infer[1]);
  converter_->MarkQuantizationRangesAsInferrable(&infer[1], &infer[2]);
  converter_->MarkQuantizationRangesAsInferrable(&infer[3], &infer[2]);
  converter_->MarkQuantizationRangesAsInferrable(&infer[4], &infer[1]);
  converter_->MarkQuantizationRangesAsInferrable(&infer[4], &infer[5]);

  // Input range should be inferred along the chain.
  PropagateQuantizationRanges();
  auto ranges = quantization_ranges();
  for (int i = 0; i < 6; ++i) {
    EXPECT_EQ(5.0f, ranges[&infer[i]]);
  }
  EXPECT_EQ(ranges.count(&not_infer), 0);
}

TEST_F(ConverterTest, GetTrtBroadcastShape) {
  const bool kIsTensor = true;
  const bool kIsNotTensor = false;
  auto symmetric_test = [this](const std::vector<int>& operand_1_shape,
                               const std::vector<int>& operand_2_shape,
                               const bool operand_1_is_tensor,
                               const bool operand_2_is_tensor,
                               const std::vector<int>& expected_operand_1_shape,
                               const std::vector<int>& expected_operand_2_shape,
                               error::Code expected_code = error::OK,
                               const char* expected_error_msg_substr = nullptr,
                               const int operand_1_batch_size = -1,
                               const int operand_2_batch_size = -1) {
    auto create_tensor_or_weights = [](const std::vector<int>& shape,
                                       bool is_tensor, int batch_size = -1) {
      if (is_tensor) {
        return TRT_TensorOrWeights{nvinfer1::DataType::kFLOAT,
                                   GetTestDims(shape), batch_size};
      }
      TRT_ShapedWeights weights;
      weights.shape_ = GetTestDims(shape);
      return TRT_TensorOrWeights(weights);
    };

    nvinfer1::Dims operand_1_new_dims, operand_2_new_dims;
    TRT_TensorOrWeights operand_1 = create_tensor_or_weights(
        operand_1_shape, operand_1_is_tensor, operand_1_batch_size);
    TRT_TensorOrWeights operand_2 = create_tensor_or_weights(
        operand_2_shape, operand_2_is_tensor, operand_2_batch_size);

    // operand_1 broadcast operand_2
    ExpectStatus(
        this->converter_->GetTrtBroadcastShape(
            operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims),
        expected_code, expected_error_msg_substr);
    if (expected_code == error::OK) {
      ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
      ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims);
    }
    // operand_2 broadcast operand_1
    ExpectStatus(
        this->converter_->GetTrtBroadcastShape(
            operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims),
        expected_code, expected_error_msg_substr);
    if (expected_code == error::OK) {
      ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
      ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims);
    }
  };

  // Both inputs are weights.
  symmetric_test(
      {1}, {1}, kIsNotTensor, kIsNotTensor, {}, {}, error::INVALID_ARGUMENT,
      "Broadcasting requires at least one of the operands be tensors");

  // One tensor and one weights.
  symmetric_test({1, 1, 1}, {2}, kIsTensor, kIsNotTensor, {1, 1, 1}, {1, 1, 2});
  symmetric_test({1, 1, 2}, {2}, kIsTensor, kIsNotTensor, {1, 1, 2}, {1, 1, 2});
  symmetric_test({1, 3, 2}, {1}, kIsTensor, kIsNotTensor, {1, 3, 2}, {1, 1, 1});
  symmetric_test({1, 1, 1}, {2, 3}, kIsTensor, kIsNotTensor, {1, 1, 1},
                 {1, 2, 3});
  symmetric_test({1, 1, 1}, {2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1},
                 {2, 3, 4});
  symmetric_test({1, 1, 1}, {1, 2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1},
                 {2, 3, 4});
  symmetric_test({1, 3, 4}, {1, 2, 1, 4}, kIsTensor, kIsNotTensor, {1, 3, 4},
                 {2, 1, 4});
  symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
                 error::INVALID_ARGUMENT, "Infeasible broadcast scheme");
  symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
                 error::INVALID_ARGUMENT, "Infeasible broadcast scheme",
                 /*operand_1_batch_size=*/2);
  symmetric_test({1, 1, 1}, {1, 1, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
                 error::INVALID_ARGUMENT,
                 "Broadcasting beyond batch dimension is not supported "
                 "(tensor #dims 4 vs broadcast #dims 5)");

  // Both inputs are tensors.
  symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {},
                 error::INVALID_ARGUMENT,
                 "Broadcasting beyond batch dimension is not supported "
                 "(tensor #dims 3 vs broadcast #dims 4)");
  symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4},
                 {2, 1, 4});
  symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {},
                 error::INVALID_ARGUMENT,
                 "Broadcasting beyond batch dimension is not supported "
                 "(tensor #dims 4 vs broadcast #dims 5)");
}

TEST_F(ConverterTest, CreateConstantLayer) {
  for (auto dtype : {DT_FLOAT, DT_INT32}) {
    TRT_ShapedWeights weights =
        weight_store_->GetTempWeights(dtype, GetTestDims({2, 3, 5}));
    nvinfer1::ITensor* tensor =
        converter_->CreateConstantLayer(weights, GetTestDims({3, 10}));
    ASSERT_NE(nullptr, tensor);
    EXPECT_EQ(TfDataTypeToTrt(dtype), tensor->getType())
        << "Expected " << DebugString(TfDataTypeToTrt(dtype)) << " vs. actual "
        << DebugString(tensor->getType());
    ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions());
  }
}

class ConvertGraphDefToEngineTest : public ::testing::Test {
 public:
  Status RunConvertGraphDefToEngine(Scope* s) {
    GraphDef gdef;
    TF_EXPECT_OK(s->ToGraphDef(&gdef));
    std::vector<PartialTensorShape> input_shapes;
    int batch_size = -1;
    for (const NodeDef& node : gdef.node()) {
      absl::string_view node_name(node.name());
      if (str_util::ConsumePrefix(&node_name, kInputPHName)) {
        int port = -1;
        EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
        if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
        input_shapes[port] =
            PartialTensorShape(node.attr().at("shape").shape());
        if (batch_size == -1) {
          batch_size = input_shapes[port].dim_size(0);
        } else {
          EXPECT_EQ(batch_size, input_shapes[port].dim_size(0));
        }
      }
    }
    // TODO(laigd): execute the engine and get outputs.
    return ConvertGraphDefToEngine(
        gdef, TrtPrecisionMode::FP32, /*max_batch_size=*/1,
        /*max_workspace_size_bytes=*/64 << 20, input_shapes, &logger_,
        /*allocator=*/nullptr, /*calibrator=*/nullptr, &engine_,
        /*use_calibration=*/false, /*convert_successfully=*/nullptr);
  }

 protected:
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;

 private:
  Logger logger_;
};

TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) {
  Scope s = Scope::NewRootScope();
  auto input = ops::Placeholder(s.WithOpName(StrCat(kInputPHName, 0)), DT_FLOAT,
                                ops::Placeholder::Shape({1, 1}));
  auto output = ops::Identity(s.WithOpName("identity1"), input);
  output = ops::Identity(s.WithOpName("identity2"), output);
  output = ops::Identity(s.WithOpName(StrCat(kOutputPHName, 0)), output);
  // If the converter marks the input tensor as output tensor, the conversion
  // below will fail with:
  // > TensorRTOutputPH_0 cannot be both input and output
  // > Network must have at least one output
  TF_EXPECT_OK(RunConvertGraphDefToEngine(&s));
}

// Input/output data format for OpConverterTest::BuildAndRun().
struct InputOutputData {
  void* Buffer() const {
    return const_cast<char*>(tensor.tensor_data().data());
  }

  size_t TotalBytes() const { return tensor.TotalBytes(); }

  const char* name;
  Tensor tensor;
};

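// Creates a 1-D tensor with `data_size` copies of `value`; used below to
// allocate output buffers for BuildAndRun().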
template <typename T>
Tensor ConstructTensor(int data_size, const T& value = T()) {
  std::vector<T> values(data_size, value);
  return test::AsTensor<T>(values);
}

using DataVec = std::vector<InputOutputData>;

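// Returns a read-only span over the flattened elements of `data.tensor`.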
template <typename T>
inline absl::Span<const T> GetSpanForData(const InputOutputData& data) {
  const auto& tensor_map = data.tensor.flat<T>();
  return absl::Span<const T>(tensor_map.data(), tensor_map.size());
}

// Class to test various op converters, using both a TrtNodeValidator and
// Converter.
class OpConverterTest : public ::testing::Test {
 public:
  OpConverterTest() : scope_(Scope::NewRootScope()) {
    QCHECK_EQ(0, cudaStreamCreate(&stream_));
    Reset();
  }

  ~OpConverterTest() override { QCHECK_EQ(0, cudaStreamDestroy(stream_)); }

  Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output) {
    return converter_->GetTensorOrWeights(name, output);
  }

  void Reset() {
    validator_.reset(nullptr);
    converter_.reset(nullptr);

    // Reset the INetworkDefinition.
    engine_.reset(nullptr);
    network_.reset(nullptr);
    builder_.reset(nvinfer1::createInferBuilder(logger_));
    network_.reset(builder_->createNetwork());
    builder_->setMaxBatchSize(1);
    builder_->setMaxWorkspaceSize(1 << 26);

    // Reset the validator and converter.
    validator_.reset(new TrtNodeValidator);
    converter_.reset(new Converter(network_.get(), precision_mode_to_test_,
                                   /*use_calibration=*/false));

    // Reset other related artifacts.
    scope_ = Scope::NewRootScope();
    validator_inputs_.clear();
  }

  void CheckDataTypeMatches(const DataVec& datas) {
    for (const auto& data : datas) {
      const int input_index = engine_->getBindingIndex(data.name);
      ASSERT_NE(-1, input_index);
      const nvinfer1::DataType trt_dtype =
          engine_->getBindingDataType(input_index);
      const DataType tf_dtype = TrtDataTypeToTf(trt_dtype);
      ASSERT_EQ(data.tensor.dtype(), tf_dtype)
          << DataTypeString(data.tensor.dtype()) << " vs. "
          << DataTypeString(tf_dtype);
    }
  }

  // TODO(laigd): test fp16 and int8 support for more converters.
  void BuildAndRun(const DataVec& input_data, DataVec* output_data,
                   TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32) {
    // Mark the output tensor as TRT engine output.
    std::vector<Converter::EngineOutputInfo> output_info;
    for (const auto& data : *output_data) {
      output_info.push_back(
          {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())});
    }
    TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));

    // Build the TRT engine.
    if (precision_mode == TrtPrecisionMode::FP16) {
      builder_->setFp16Mode(true);
    } else if (precision_mode == TrtPrecisionMode::INT8) {
      // Setting FP16 mode as well allows TRT to also consider FP16 kernels and
      // use them in situations where they are faster than INT8 or where INT8 is
      // not supported for a given layer.
      builder_->setFp16Mode(true);
      builder_->setInt8Mode(true);
    }
    ASSERT_EQ(nullptr, engine_.get());
    engine_.reset(builder_->buildCudaEngine(*converter_->network()));
    CHECK_NOTNULL(engine_.get());
    CheckDataTypeMatches(input_data);
    CheckDataTypeMatches(*output_data);

    // Execute the TRT engine.
    const int num_bindings = input_data.size() + output_data->size();
    std::vector<void*> buffers(num_bindings);

    for (const auto& data : input_data) {
      const int input_index = engine_->getBindingIndex(data.name);
      ASSERT_EQ(0, cudaMalloc(&buffers[input_index], data.TotalBytes()));
      ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], data.Buffer(),
                                   data.TotalBytes(), cudaMemcpyHostToDevice,
                                   stream_));
    }
    struct SizeAndIndex {
      SizeAndIndex(int in_size, int in_index)
          : size(in_size), index(in_index) {}
      int size;
      int index;
    };
    std::vector<SizeAndIndex> output_infos;
    for (const auto& data : *output_data) {
      const int output_index = engine_->getBindingIndex(data.name);
      output_infos.emplace_back(data.TotalBytes(), output_index);
      ASSERT_EQ(0, cudaMalloc(&buffers[output_index], data.TotalBytes()));
    }

    ASSERT_EQ(engine_->getNbBindings(), num_bindings);
    TrtUniquePtrType<nvinfer1::IExecutionContext> execution_context(
        engine_->createExecutionContext());
    execution_context->enqueue(/*batchSize=*/1, buffers.data(), stream_,
                               nullptr);

    for (int i = 0; i < output_infos.size(); ++i) {
      const auto& output_info = output_infos[i];
      ASSERT_EQ(0, cudaMemcpyAsync(output_data->at(i).Buffer(),
                                   buffers[output_info.index], output_info.size,
                                   cudaMemcpyDeviceToHost, stream_));
    }
    cudaStreamSynchronize(stream_);

    for (int i = 0; i < num_bindings; ++i) {
      ASSERT_EQ(0, cudaFree(buffers[i]));
    }
  }

  bool HasStaticShape(const nvinfer1::Dims& dims) const {
    if (dims.nbDims < 0) return false;
    for (int i = 0; i < dims.nbDims; ++i) {
      if (dims.d[i] < 0) return false;
    }
    return true;
  }

  // Add ITensor for both validation and conversion.
  void AddTestTensor(
      const char* name, const std::vector<int32>& dims, int batch_size = 1,
      nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) {
    DataType tf_dtype = TrtDataTypeToTf(trt_dtype);
    ops::Placeholder::Attrs attrs;
    TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_));
    attrs.shape_.InsertDim(0, batch_size);
    auto input = ops::Placeholder(scope_.WithOpName(name), tf_dtype, attrs);
    validator_inputs_[name] = input.operation.node()->def();

    // Add a real ITensor for conversion conditionally.
    const nvinfer1::Dims trt_dims = GetTestDims(dims);
    if (HasStaticShape(trt_dims)) {
      TF_EXPECT_OK(
          converter_->AddInputTensor(name, trt_dtype, trt_dims, batch_size));
      ASSERT_EQ(batch_size, converter_->batch_size_);
    }
  }

  // Add weights for both validation and conversion.
  template <typename T>
  void AddTestWeights(const char* name, const std::vector<int>& dims,
                      const std::vector<T>& values) {
    const DataType dtype = DataTypeToEnum<T>::v();
    const nvinfer1::Dims trt_dims = GetTestDims(dims);
    const int64_t num_elements = TrtDimsNumElements(trt_dims);
    QCHECK_EQ(num_elements, values.size())
        << num_elements << " vs " << values.size();
    TRT_ShapedWeights weights(dtype);
    if (num_elements) {
      weights = converter_->weight_store_.GetTempWeights(dtype, trt_dims);
      QCHECK_EQ(weights.size_bytes(), sizeof(T) * values.size())
          << weights.size_bytes() << " vs " << sizeof(T) * values.size();
      memcpy(const_cast<void*>(weights.GetValues()), values.data(),
             weights.size_bytes());
    }
    // Add weights for validation.
    TensorShape shape;
    TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &shape));
    validator_inputs_[name] = MakeConstNodeDef<T>(name, values, shape);
    // Add weights for conversion.
    TF_EXPECT_OK(
        converter_->AddTensorOrWeights(name, TRT_TensorOrWeights{weights}));
  }

  // Test validation in validation-only mode.
  void RunValidation(const NodeDef& node_def,
                     error::Code expected_code = error::OK,
                     const char* expected_msg_substr = nullptr) {
    std::vector<std::pair<const NodeDef*, int>> input_node_and_ports;
    for (const string& input : node_def.input()) {
      input_node_and_ports.emplace_back(&validator_inputs_[input], 0);
    }
    grappler::GrapplerItem item;
    TF_EXPECT_OK(scope_.ToGraphDef(&item.graph));
    grappler::GraphProperties graph_properties(item);
    TF_EXPECT_OK(graph_properties.InferStatically(true));

    ExpectStatus(
        validator_->ValidateNode(node_def, input_node_and_ports,
                                 precision_mode_to_test_, graph_properties),
        expected_code, expected_msg_substr);
  }

  void RunConversion(const NodeDef& node_def,
                     error::Code expected_code = error::OK,
                     const char* expected_msg_substr = nullptr) {
    ExpectStatus(converter_->ConvertNode(node_def), expected_code,
                 expected_msg_substr);
  }

  // Helper method to run both validation and conversion, when the expected
  // outcomes are the same.
  void RunValidationAndConversion(const NodeDef& node_def,
                                  error::Code expected_code = error::OK,
                                  const char* expected_msg_substr = nullptr,
                                  bool should_run_conversion = true) {
    RunValidation(node_def, expected_code, expected_msg_substr);
    if (should_run_conversion) {
      RunConversion(node_def, expected_code, expected_msg_substr);
    }
  }

  // Expose quantization_ranges_ for tests
  std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
    return converter_->quantization_ranges_;
  }

  std::unique_ptr<Converter> converter_;
  std::unique_ptr<TrtNodeValidator> validator_;

 protected:
  // TODO(laigd): parameterize the test and make the precision mode a parameter.
  TrtPrecisionMode precision_mode_to_test_ = TrtPrecisionMode::FP32;

 private:
  Logger logger_;
  TrtUniquePtrType<nvinfer1::IBuilder> builder_;
  TrtUniquePtrType<nvinfer1::INetworkDefinition> network_;
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
  cudaStream_t stream_;
  // Used to create placeholders with shape and data type information. The
  // created placeholders will be used as inputs to the node to be verified,
  // thus we need the shape and data type information to get a non-empty
  // GraphProperties.
  // TODO(laigd): consider using this Scope to create the NodeDef to verify.
  Scope scope_;
  std::unordered_map<string, NodeDef> validator_inputs_;
};

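// Copies the elements of `tensor` into `out`, omitting trailing elements that
// repeat the last distinct value, since TensorProto allows such truncation.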
1304 template <typename T>
CopyTensorElements(const Tensor & tensor,protobuf::RepeatedField<T> * out)1305 void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) {
1306 out->Clear();
1307 if (tensor.NumElements() == 0) return;
1308
1309 // TensorProto does not need to have all the elements present and can truncate
1310 // trailing elements with the same value for compressed representation. Such
1311 // elements are derived based on the tensor shape.
1312 const auto flat = tensor.flat<T>();
1313 int64 last_index = 0;
1314 for (int64 i = 0; i < tensor.NumElements(); ++i) {
1315 if (flat(i) != flat(last_index)) {
1316 last_index = i;
1317 }
1318 }
1319
1320 int num_out_elements = last_index + 1;
1321 out->Reserve(num_out_elements);
1322 out->AddNAlreadyReserved(num_out_elements);
1323 const T* src = flat.data();
1324 T* dst = out->mutable_data();
1325 std::copy(src, src + num_out_elements, dst);
1326 }
1327
1328 template <DataType dtype, typename InputCType, typename OutputCType>
TestConvertConst(OpConverterTest * test)1329 void TestConvertConst(OpConverterTest* test) {
1330 NodeDef node_def;
1331 node_def.set_name("my_const");
1332 node_def.set_op("Const");
1333
1334 auto reset_and_test = [&node_def, test](
1335 const Tensor& tensor, const bool as_tensor_content,
1336 const std::vector<int>& expected_dims,
1337 const std::vector<OutputCType>& expected_value) {
1338 test->Reset();
1339
1340 TensorProto* tensor_attr =
1341 (*node_def.mutable_attr())["value"].mutable_tensor();
1342 tensor_attr->Clear();
1343
1344 if (as_tensor_content) {
1345 tensor.AsProtoTensorContent(tensor_attr);
1346 } else {
1347 tensor.shape().AsProto(tensor_attr->mutable_tensor_shape());
1348 tensor_attr->set_dtype(tensor.dtype());
1349
1350 if (tensor.dtype() == DT_FLOAT) {
1351 CopyTensorElements<float>(tensor, tensor_attr->mutable_float_val());
1352 } else if (tensor.dtype() == DT_INT32) {
1353 CopyTensorElements<int32>(tensor, tensor_attr->mutable_int_val());
1354 } else {
1355 tensor.AsProtoField(tensor_attr);
1356 }
1357 }
1358 test->RunValidationAndConversion(node_def);
1359 TRT_TensorOrWeights output;
1360 TF_EXPECT_OK(test->GetTensorOrWeights("my_const", &output));
1361 ValidateWeights(output.weights(), expected_dims, expected_value);
1362 };
1363
1364 auto& attr = *node_def.mutable_attr();
1365 attr["dtype"].set_type(dtype);
1366 {
1367 // By default empty tensor will pick DT_FLOAT as data type and we fix it
1368 // here.
1369 Tensor t(dtype); // Empty tensor.
1370 reset_and_test(t, false, {}, {});
1371 }
1372 {
1373 Tensor t = test::AsScalar<InputCType>(12);
1374 reset_and_test(t, false, {1}, {12});
1375 reset_and_test(t, true, {1}, {12});
1376 }
1377 {
1378 Tensor t = test::AsTensor<InputCType>({1, 2});
1379 reset_and_test(t, false, {2}, {1, 2});
1380 reset_and_test(t, true, {2}, {1, 2});
1381 }
1382 {
1383 Tensor t =
1384 test::AsTensor<InputCType>({1, 2, 3, 4, 5, 6}, TensorShape({2, 3}));
1385 reset_and_test(t, false, {2, 3}, {1, 2, 3, 4, 5, 6});
1386 reset_and_test(t, true, {2, 3}, {1, 2, 3, 4, 5, 6});
1387 }
1388 {
1389 // Set all tensor elements to the same value. Such tensors are encoded
1390 // using a single element list in tensor proto.
1391 Tensor t =
1392 test::AsTensor<InputCType>({1, 1, 1, 1, 1, 1}, TensorShape({2, 3}));
1393 reset_and_test(t, false, {2, 3}, {1, 1, 1, 1, 1, 1});
1394 reset_and_test(t, true, {2, 3}, {1, 1, 1, 1, 1, 1});
1395 }
1396 {
1397 // Set trailing tensor elements to the same value. Such tensors are
1398 // encoded by truncating the repeated trailing elements to a single occurrence.
1399 Tensor t =
1400 test::AsTensor<InputCType>({2, 2, 1, 1, 1, 1}, TensorShape({2, 3}));
1401 reset_and_test(t, false, {2, 3}, {2, 2, 1, 1, 1, 1});
1402 reset_and_test(t, true, {2, 3}, {2, 2, 1, 1, 1, 1});
1403 }
1404 }
1405
1406 TEST_F(OpConverterTest, ConvertConst) {
1407 {
1408 Reset();
1409 NodeDef node_def = MakeNodeDef("my_const", "Const", {"input"});
1410 AddTestTensor("input", {1});
1411 RunValidationAndConversion(
1412 node_def, error::INVALID_ARGUMENT,
1413 "Constant node is expected to have empty input list: my_const");
1414 }
1415 {
1416 Reset();
1417 NodeDef node_def = MakeConstNodeDef<double>("my_const", {});
1418 RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
1419 "Unsupported data type double");
1420 }
1421
1422 TestConvertConst<DT_FLOAT, float, float>(this);
1423 TestConvertConst<DT_INT8, int8, int32>(this);
1424 TestConvertConst<DT_INT32, int32, int32>(this);
1425 }
1426
1427 TEST_F(OpConverterTest, ConvertTranspose) {
1428 {
1429 // Input list is empty, should fail.
1430 NodeDef node_def = MakeNodeDef("my_transpose", "Transpose", {});
1431 RunValidationAndConversion(
1432 node_def, error::INVALID_ARGUMENT,
1433 "Transpose got 0 inputs but expected 2, at my_transpose");
1434 }
1435
1436 // Get the NodeDef for Transpose.
1437 Scope s = Scope::NewRootScope();
1438 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
1439 auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
1440 auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights);
1441 const NodeDef& node_def = transpose.operation.node()->def();
1442
1443 {
1444 // Permutation is a tensor, should fail.
1445 Reset();
1446 AddTestTensor("input", {1, 2, 3});
1447 AddTestTensor("weights", {3});
1448 RunValidationAndConversion(
1449 node_def, error::UNIMPLEMENTED,
1450 "The input \"perm\" for Transpose must be a constant, at my_transpose");
1451 }
1452 {
1453 // Transpose at batch dimension, should fail.
1454 Reset();
1455 AddTestTensor("input", {1, 2, 3});
1456 AddTestWeights<int32>("weights", {4}, {1, 0, 2, 3});
1457 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
1458 "Transpose at batch dimension is not supported");
1459 }
1460 {
1461 // Permutation rank doesn't match, should fail.
1462 Reset();
1463 AddTestTensor("input", {1, 2, 3});
1464 AddTestWeights<int32>("weights", {3}, {0, 1, 2});
1465 RunValidationAndConversion(
1466 node_def, error::INVALID_ARGUMENT,
1467 "Rank of perm for transpose does not match with that of the input.");
1468 }
1469 {
1470 // Ok.
1471 Reset();
1472 AddTestTensor("input", {1, 2, 3});
1473 AddTestWeights<int32>("weights", {4}, {0, 3, 1, 2});
1474 RunValidationAndConversion(node_def);
1475 TRT_TensorOrWeights output;
1476 TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output));
1477 EXPECT_TRUE(output.is_tensor());
1478 ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions());
1479
1480 const DataVec input_data{
1481 {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
1482 DataVec output_data{{"my_transpose", ConstructTensor<float>(6)}};
1483 BuildAndRun(input_data, &output_data);
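// With perm {0, 3, 1, 2} the last axis of the [1, 1, 2, 3] input (batch
// included) is moved right after the batch dim, so the 2x3 matrix
// {{1, 2, 3}, {4, 5, 6}} is read out column by column: 1, 4, 2, 5, 3, 6.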
1484 EXPECT_THAT(GetSpanForData<float>(output_data[0]),
1485 ElementsAre(1, 4, 2, 5, 3, 6));
1486 }
1487 }
1488
1489 TEST_F(OpConverterTest, ConvertReshape) {
1490 {
1491 // Input list is empty, should fail.
1492 NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {});
1493 RunValidationAndConversion(
1494 node_def, error::INVALID_ARGUMENT,
1495 "Reshape got 0 inputs but expected 2, at my_reshape");
1496 }
1497
1498 // Get the NodeDef for Reshape.
1499 Scope s = Scope::NewRootScope();
1500 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
1501 auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
1502 auto reshape = ops::Reshape(s.WithOpName("my_reshape"), input, weights);
1503 const NodeDef& node_def = reshape.operation.node()->def();
1504
1505 {
1506 // Shape is a tensor, should fail.
1507 Reset();
1508 AddTestTensor("input", {1, 2, 3});
1509 AddTestTensor("weights", {3});
1510 RunValidationAndConversion(
1511 node_def, error::UNIMPLEMENTED,
1512 "The input \"shape\" for Reshape must be a constant, at my_reshape");
1513 }
1514 {
1515 // Reshape to scalar, should fail.
1516 Reset();
1517 AddTestTensor("input", {1, 2, 3});
1518 AddTestWeights<int32>("weights", {0}, {});
1519 RunValidationAndConversion(
1520 node_def, error::UNIMPLEMENTED,
1521 "Reshape to shape=[] is not supported, at my_reshape");
1522 }
1523
1524 struct TestParams {
1525 int batch_size;
1526 std::vector<int> tensor_dims;
1527 std::vector<int> shape;
1528 };
1529
1530 // Reshape at batch dimension, should fail.
1531 const int kReshapeBatchDimsCases = 5;
1532 TestParams params[kReshapeBatchDimsCases] = {
1533 TestParams{1, {1, 2, 3}, {3, 1, 1, 2}},
1534 TestParams{1, {1, 2, -1}, {-1, 1, 1, 2}},
1535 TestParams{1, {1, 2, 3}, {-1, 1, 1, 2}},
1536 TestParams{-1, {1, 2, 3}, {1, 1, 1, 2}},
1537 TestParams{-1, {-1, 2, 3}, {1, 1, 1, 6}}, // TODO(laigd): it should pass.
1538 };
1539 for (int i = 0; i < kReshapeBatchDimsCases; ++i) {
1540 Reset();
1541 const std::vector<int>& dims = params[i].tensor_dims;
1542 AddTestTensor("input", dims, params[i].batch_size);
1543 AddTestWeights<int32>("weights", {4}, params[i].shape);
1544 RunValidationAndConversion(
1545 node_def, error::UNIMPLEMENTED,
1546 "Reshape on batch dimension is not supported, at my_reshape",
1547 /*should_run_conversion=*/(dims[0] > 0 && dims[1] > 0 && dims[2] > 0));
1548 }
1549
1550 // Reshape on non-batch dimensions, ok.
1551 const int kReshapeOKCases = 3;
1552 TestParams ok_params[kReshapeOKCases] = {
1553 TestParams{-1, {1, 2, 3}, {-1, 1, 3, 2}},
1554 TestParams{1, {1, 2, 3}, {-1, 1, 3, 2}},
1555 TestParams{1, {1, 2, 3}, {1, 1, 3, 2}},
1556 };
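// In each shape above the first value is the batch dimension (-1 lets it be
// inferred), so e.g. {-1, 1, 3, 2} keeps the batch and rearranges the six
// non-batch elements of a {1, 2, 3} tensor into a 1x3x2 layout.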
1557 for (int i = 0; i < kReshapeOKCases; ++i) {
1558 Reset();
1559 AddTestTensor("input", ok_params[i].tensor_dims, ok_params[i].batch_size);
1560 AddTestWeights<int32>("weights", {4}, ok_params[i].shape);
1561 RunValidationAndConversion(node_def);
1562 TRT_TensorOrWeights output;
1563 TF_EXPECT_OK(GetTensorOrWeights("my_reshape", &output));
1564 EXPECT_TRUE(output.is_tensor());
1565 ExpectTrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions());
1566
1567 const DataVec input_data{
1568 {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
1569 DataVec output_data{{"my_reshape", ConstructTensor<float>(6)}};
1570 BuildAndRun(input_data, &output_data);
1571 EXPECT_THAT(GetSpanForData<float>(output_data[0]),
1572 ElementsAre(1, 2, 3, 4, 5, 6));
1573 }
1574 }
1575
1576 TEST_F(OpConverterTest, ConvertMatMul) {
1577 {
1578 // Input list is empty, should fail.
1579 NodeDef node_def = MakeNodeDef("my_matmul", "MatMul", {});
1580 RunValidationAndConversion(
1581 node_def, error::INVALID_ARGUMENT,
1582 "MatMul got 0 inputs but expected 2, at my_matmul");
1583 }
1584
1585 // Get the NodeDef for MatMul.
1586 auto get_matmul_nodedef = [](DataType dtype, bool transpose_a,
1587 bool transpose_b) -> NodeDef {
1588 Scope s = Scope::NewRootScope();
1589 auto input = ops::Placeholder(s.WithOpName("input"), dtype);
1590 auto weights = ops::Placeholder(s.WithOpName("weights"), dtype);
1591 const auto matmul_attrs =
1592 ops::MatMul::TransposeA(transpose_a).TransposeB(transpose_b);
1593 auto matmul =
1594 ops::MatMul(s.WithOpName("my_matmul"), input, weights, matmul_attrs);
1595 return matmul.operation.node()->def();
1596 };
1597
1598 {
1599 // Unsupported data type.
1600 Reset();
1601 NodeDef node_def = get_matmul_nodedef(DT_INT32, false, false);
1602 AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32);
1603 AddTestWeights<int32>("weights", {2, 1}, {3, 5});
1604 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
1605 "Data type int32 is not supported for MatMul, "
1606 "must be one of [float, half], at my_matmul");
1607 }
1608 // transpose_a is set.
1609 for (bool transpose_b : {false, true}) {
1610 Reset();
1611 NodeDef node_def =
1612 get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b);
1613 AddTestTensor("input", {2}, /*batch_size=*/1);
1614 AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
1615 RunValidationAndConversion(
1616 node_def, error::INVALID_ARGUMENT,
1617 "transpose_a is not supported for TensorRT FullyConnected");
1618 }
1619 // OK.
1620 for (bool transpose_b : {false, true}) {
1621 Reset();
1622 NodeDef node_def =
1623 get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b);
1624 AddTestTensor("input", {2}, /*batch_size=*/1);
1625 AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
1626 RunValidationAndConversion(node_def);
1627 TRT_TensorOrWeights output;
1628 TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
1629 EXPECT_TRUE(output.is_tensor());
1630 ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
1631
1632 const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}};
1633 DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}};
1634 BuildAndRun(input_data, &output_data);
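// Worked example: input [0, 1] times weights [[0, 1], [2, 3]] gives [2, 3];
// with transpose_b the weights are transposed to [[0, 2], [1, 3]], giving
// [1, 3].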
1635 if (transpose_b) {
1636 EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
1637 } else {
1638 EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(2, 3));
1639 }
1640 }
1641 }
1642
1643 template <DataType dtype>
1644 void TestConvertBiasAdd(OpConverterTest* test) {
1645 // Get the NodeDef for BiasAdd.
1646 auto get_biasadd_nodedef = [](const string& data_format) -> NodeDef {
1647 Scope s = Scope::NewRootScope();
1648 auto input = ops::Placeholder(s.WithOpName("input"), dtype);
1649 auto weights = ops::Placeholder(s.WithOpName("weights"), dtype);
1650 const auto biasadd_attrs = ops::BiasAdd::DataFormat(data_format);
1651 auto biasadd =
1652 ops::BiasAdd(s.WithOpName("my_biasadd"), input, weights, biasadd_attrs);
1653 return biasadd.operation.node()->def();
1654 };
1655
1656 typedef typename EnumToDataType<dtype>::Type CType;
1657 for (const string& data_format : {"NHWC", "NCHW"}) {
1658 for (const int trt_input_rank : {1, 2, 3, 4}) {
1659 test->Reset();
1660 NodeDef node_def = get_biasadd_nodedef(data_format);
1661
1662 // Add input; dims_array will look like {2, 1, ..., 1, 3}.
1663 std::vector<int32> dims_array(trt_input_rank, 1);
1664 if (trt_input_rank == 1) {
1665 dims_array[0] = (data_format == "NHWC" ? 3 : 2);
1666 } else {
1667 dims_array[0] = 2;
1668 dims_array[trt_input_rank - 1] = 3;
1669 }
1670 test->AddTestTensor("input", dims_array, /*batch_size=*/1,
1671 TfDataTypeToTrt(dtype));
1672
1673 // Add bias weights.
1674 const int channel_size = (data_format == "NHWC" ? 3 : 2);
1675 std::vector<CType> bias(channel_size);
1676 for (int i = 0; i < channel_size; ++i) {
1677 bias[i] = CType(i + 1); // bias will be {1, 2, 3, ...}
1678 }
1679 test->AddTestWeights<CType>("weights", {channel_size}, bias);
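// Since the input is all zeros (see below), the engine output is just the
// broadcast bias: along the last dim for NHWC ({1, 2, 3, 1, 2, 3} when rank
// > 1) and along the channel (first) dim for NCHW ({1, 1, 1, 2, 2, 2}).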
1680
1681 // Run the conversion.
1682 test->RunValidationAndConversion(node_def);
1683 TRT_TensorOrWeights output;
1684 TF_EXPECT_OK(test->GetTensorOrWeights("my_biasadd", &output));
1685 EXPECT_TRUE(output.is_tensor());
1686 ExpectTrtDimsEqualsArray(dims_array, output.tensor()->getDimensions());
1687
1688 // Build and run the engine.
1689 const int num_input = TrtDimsNumElements(GetTestDims(dims_array));
1690 ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 3 : 2),
1691 num_input);
1692
1693 const DataVec input_data{
1694 {"input", ConstructTensor<CType>(num_input, CType(0))}};
1695 DataVec output_data{{"my_biasadd", ConstructTensor<CType>(num_input)}};
1696 test->BuildAndRun(input_data, &output_data);
1697 if (trt_input_rank == 1) {
1698 if (data_format == "NHWC") {
1699 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1700 ElementsAre(CType(1), CType(2), CType(3)));
1701 } else {
1702 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1703 ElementsAre(CType(1), CType(2)));
1704 }
1705 } else {
1706 if (data_format == "NHWC") {
1707 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1708 ElementsAre(CType(1), CType(2), CType(3), CType(1),
1709 CType(2), CType(3)));
1710 } else {
1711 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1712 ElementsAre(CType(1), CType(1), CType(1), CType(2),
1713 CType(2), CType(2)));
1714 }
1715 }
1716 }
1717 }
1718 }
1719
1720 TEST_F(OpConverterTest, ConvertBiasAdd) {
1721 {
1722 // Input list is empty, should fail.
1723 NodeDef node_def = MakeNodeDef("my_biasadd", "BiasAdd", {});
1724 RunValidationAndConversion(
1725 node_def, error::INVALID_ARGUMENT,
1726 "BiasAdd got 0 inputs but expected 2, at my_biasadd");
1727 }
1728
1729 // OK. Note that kINT32 is not supported by IScaleLayer, so we don't test
1730 // the DT_INT32 type here.
1731 TestConvertBiasAdd<DT_FLOAT>(this);
1732 TestConvertBiasAdd<DT_HALF>(this);
1733 }
1734
1735 template <typename OpType>
1736 NodeDef GetBinaryOpNodeDef(const string& input_name_l,
1737 const string& input_name_r, DataType dtype) {
1738 Scope s = Scope::NewRootScope();
1739 auto input_l = ops::Placeholder(s.WithOpName(input_name_l), dtype);
1740 auto input_r = ops::Placeholder(s.WithOpName(input_name_r), dtype);
1741 auto op = OpType(s.WithOpName("my_binary"), input_l, input_r);
1742 return op.operation.node()->def();
1743 }
1744
1745 void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) {
1746 bool element_wise_layer_found = false;
1747 bool scale_layer_found = false;
1748 for (int i = 0; i < test->converter_->network()->getNbLayers(); i++) {
1749 nvinfer1::ILayer* layer = test->converter_->network()->getLayer(i);
1750 if (dynamic_cast<nvinfer1::IScaleLayer*>(layer)) {
1751 scale_layer_found = true;
1752 } else if (dynamic_cast<nvinfer1::IElementWiseLayer*>(layer)) {
1753 element_wise_layer_found = true;
1754 }
1755 }
1756 EXPECT_EQ(expect_scale_layer, scale_layer_found);
1757 EXPECT_NE(expect_scale_layer, element_wise_layer_found);
1758 }
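
// CheckAddedLayers() distinguishes the two conversion paths used by the tests
// below: BinaryTensorOpWeight() is expected to emit an IScaleLayer, while
// BinaryTensorOpTensor() (and the fallback cases) should emit an
// IElementWiseLayer instead.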
1759
1760 template <typename OpType, DataType dtype>
1761 void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) {
1762 typedef typename EnumToDataType<dtype>::Type CType;
1763 for (auto swap_inputs : {false, true}) {
1764 test->Reset();
1765 NodeDef node_def;
1766 if (swap_inputs) {
1767 node_def = GetBinaryOpNodeDef<OpType>("weights", "input", dtype);
1768 } else {
1769 node_def = GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
1770 }
1771
1772 const std::vector<CType> operand1{CType(3), CType(7.5)};
1773 const std::vector<CType> operand2{CType(2), CType(3)};
1774
1775 // The dims must be of rank at least 3 for an IScaleLayer to be applicable.
1776 test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1,
1777 TfDataTypeToTrt(dtype));
1778 test->AddTestWeights<CType>("weights", /*dims=*/{1, 1, 2},
1779 /*values=*/swap_inputs ? operand1 : operand2);
1780 test->RunValidationAndConversion(node_def);
1781
1782 // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
1783 CheckAddedLayers(test, /*expect_scale_layer=*/true);
1784
1785 // Check the dims of the output ITensor.
1786 TRT_TensorOrWeights output;
1787 TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
1788 EXPECT_TRUE(output.is_tensor());
1789 ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions());
1790
1791 const DataVec input_data{
1792 {"input", test::AsTensor<CType>(swap_inputs ? operand2 : operand1)}};
1793 DataVec output_data{{"my_binary", ConstructTensor<CType>(2)}};
1794 test->BuildAndRun(
1795 input_data, &output_data,
1796 dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
1797 if (node_def.op() == "Add") {
1798 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1799 ElementsAre(CType(5), CType(10.5)));
1800 } else if (node_def.op() == "Sub") {
1801 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1802 ElementsAre(CType(1), CType(4.5)));
1803 } else if (node_def.op() == "Mul") {
1804 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1805 ElementsAre(CType(6), CType(22.5)));
1806 } else if (node_def.op() == "Div") {
1807 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1808 ElementsAre(CType(1.5), CType(2.5)));
1809 } else if (node_def.op() == "RealDiv") {
1810 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1811 ElementsAre(CType(1.5), CType(2.5)));
1812 } else {
1813 ASSERT_TRUE(false);
1814 }
1815 }
1816 }
1817
1818 template <DataType dtype>
1819 void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) {
1820 typedef typename EnumToDataType<dtype>::Type CType;
1821 const NodeDef node_def =
1822 GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
1823 const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
1824 const std::vector<CType> weights{CType(10), CType(20)};
1825 // There are two types of valid dim pairs which require channel-wise
1826 // broadcasting:
1827 // - input dims (X Y Z) vs weights dims (X 1 1)
1828 // - input dims (X Y Z) vs weights dims (Z)
1829 // Here X=Z=2 and Y=1.
1830 for (auto weights_dims : std::vector<std::vector<int>>{{2, 1, 1}, {2}}) {
1831 test->Reset();
1832 test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
1833 TfDataTypeToTrt(dtype));
1834 test->AddTestWeights<CType>("weights", weights_dims, weights);
1835 test->RunValidationAndConversion(node_def);
1836
1837 // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
1838 CheckAddedLayers(test, /*expect_scale_layer=*/true);
1839
1840 // Check the dims of the output ITensor.
1841 TRT_TensorOrWeights output;
1842 TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
1843 EXPECT_TRUE(output.is_tensor());
1844 ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
1845
1846 const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
1847 DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
1848 test->BuildAndRun(input_data, &output_data);
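// For weights of shape {2} the values broadcast along the last (Z) axis:
// {1+10, 2+20, 3+10, 4+20}. For shape {2, 1, 1} they broadcast per channel
// (X axis): {1+10, 2+10, 3+20, 4+20}.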
1849 if (weights_dims.size() == 1) {
1850 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1851 ElementsAre(CType(11), CType(22), CType(13), CType(24)));
1852 } else {
1853 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1854 ElementsAre(CType(11), CType(12), CType(23), CType(24)));
1855 }
1856 }
1857 }
1858
1859 template <DataType dtype>
1860 void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) {
1861 typedef typename EnumToDataType<dtype>::Type CType;
1862 const NodeDef node_def =
1863 GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
1864 const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
1865 const std::vector<CType> weights{CType(10)};
1866 test->Reset();
1867 test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
1868 TfDataTypeToTrt(dtype));
1869 test->AddTestWeights<CType>("weights", {1, 1, 1, 1}, weights);
1870 test->RunValidationAndConversion(node_def);
1871
1872 // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
1873 CheckAddedLayers(test, /*expect_scale_layer=*/true);
1874
1875 // Check the dims of the output ITensor.
1876 TRT_TensorOrWeights output;
1877 TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
1878 EXPECT_TRUE(output.is_tensor());
1879 ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
1880
1881 const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
1882 DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
1883 test->BuildAndRun(input_data, &output_data);
1884 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1885 ElementsAre(CType(11), CType(12), CType(13), CType(14)));
1886 }
1887
1888 template <typename OpType>
1889 void TestBinaryTensorOpWeightFallback(OpConverterTest* test,
1890 const std::vector<int32>& input_dims,
1891 const std::vector<int>& weights_dims,
1892 error::Code code = error::OK,
1893 const char* error_msg_substr = nullptr,
1894 const int input_batch_size = 1) {
1895 const DataType dtype = DT_FLOAT;
1896 typedef typename EnumToDataType<dtype>::Type CType;
1897 const size_t num_inputs = TrtDimsNumElements(GetTestDims(input_dims));
1898 const size_t num_weights = TrtDimsNumElements(GetTestDims(weights_dims));
1899
1900 test->Reset();
1901 const NodeDef node_def =
1902 GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
1903 test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size,
1904 TfDataTypeToTrt(dtype));
1905 test->AddTestWeights<CType>(
1906 "weights", /*dims=*/weights_dims,
1907 /*values=*/std::vector<CType>(num_weights, CType(1)));
1908 test->RunValidationAndConversion(node_def, code, error_msg_substr);
1909 if (code != error::OK) return;
1910
1911 // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
1912 CheckAddedLayers(test, /*expect_scale_layer=*/false);
1913
1914 TRT_TensorOrWeights output;
1915 TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
1916 EXPECT_TRUE(output.is_tensor());
1917
1918 // Check the dims of the output ITensor.
1919 std::vector<int> expected_output_dims = input_dims;
1920 for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1;
1921 i >= 0 && j >= 0; --i, --j) {
1922 if (expected_output_dims[i] == 1) {
1923 expected_output_dims[i] = weights_dims[j];
1924 }
1925 }
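// The loop above mimics right-aligned broadcasting on the known dims, e.g.
// input dims {1, 2, 1} with weights dims {2} yield expected output dims
// {1, 2, 2}.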
1926 ExpectTrtDimsEqualsArray(expected_output_dims,
1927 output.tensor()->getDimensions());
1928
1929 // Check the result of running the engine.
1930 const int expected_num_outputs =
1931 TrtDimsNumElements(GetTestDims(expected_output_dims));
1932 const DataVec input_data{
1933 {"input", ConstructTensor<CType>(num_inputs, CType(2))}};
1934 DataVec output_data{
1935 {"my_binary", ConstructTensor<CType>(expected_num_outputs)}};
1936 test->BuildAndRun(input_data, &output_data);
1937 if (node_def.op() == "Add") {
1938 EXPECT_THAT(
1939 GetSpanForData<CType>(output_data[0]),
1940 ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(3))));
1941 } else if (node_def.op() == "Minimum") {
1942 EXPECT_THAT(
1943 GetSpanForData<CType>(output_data[0]),
1944 ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(1))));
1945 } else {
1946 ASSERT_TRUE(false);
1947 }
1948 }
1949
1950 template <typename OpType, DataType dtype>
1951 void TestBinaryTensorOpTensor(OpConverterTest* test) {
1952 typedef typename EnumToDataType<dtype>::Type CType;
1953 test->Reset();
1954 const NodeDef node_def =
1955 GetBinaryOpNodeDef<OpType>("input1", "input2", dtype);
1956 test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1,
1957 TfDataTypeToTrt(dtype));
1958 test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1,
1959 TfDataTypeToTrt(dtype));
1960 test->RunValidationAndConversion(node_def);
1961
1962 // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
1963 CheckAddedLayers(test, /*expect_scale_layer=*/false);
1964
1965 // Check output dims.
1966 TRT_TensorOrWeights output;
1967 TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
1968 EXPECT_TRUE(output.is_tensor());
1969 ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions());
1970
1971 const DataVec input_data{
1972 {"input1", test::AsTensor<CType>({CType(3), CType(6)})},
1973 {"input2", test::AsTensor<CType>({CType(2), CType(3)})}};
1974 DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
1975 // After broadcasting first input becomes {3, 6, 3, 6} and second input
1976 // becomes {2, 3, 2, 3}.
1977 test->BuildAndRun(
1978 input_data, &output_data,
1979 dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
1980 if (node_def.op() == "Add") {
1981 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1982 ElementsAre(CType(5), CType(8), CType(6), CType(9)));
1983 } else if (node_def.op() == "Sub") {
1984 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1985 ElementsAre(CType(1), CType(4), CType(0), CType(3)));
1986 } else if (node_def.op() == "Mul") {
1987 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1988 ElementsAre(CType(6), CType(12), CType(9), CType(18)));
1989 } else if (node_def.op() == "Div") {
1990 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1991 ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
1992 } else if (node_def.op() == "RealDiv") {
1993 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1994 ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
1995 } else if (node_def.op() == "Minimum") {
1996 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
1997 ElementsAre(CType(2), CType(2), CType(3), CType(3)));
1998 } else if (node_def.op() == "Maximum") {
1999 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
2000 ElementsAre(CType(3), CType(6), CType(3), CType(6)));
2001 } else if (node_def.op() == "Pow") {
2002 ExpectArrayNear(
2003 std::vector<CType>{CType(9), CType(36), CType(27), CType(216)},
2004 GetSpanForData<CType>(output_data[0]));
2005 } else {
2006 ASSERT_TRUE(false);
2007 }
2008 }
2009
2010 TEST_F(OpConverterTest, ConvertBinary) {
2011 AttrValue dtype;
2012 dtype.set_type(DT_FLOAT);
2013 // Input size doesn't match, should fail.
2014 for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) {
2015 Reset();
2016 NodeDef node_def =
2017 MakeNodeDef("my_add", "Add", {num_inputs, "input"}, {{"T", dtype}});
2018 AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT);
2019 RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
2020 StrCat("Add got ", std::to_string(num_inputs),
2021 " inputs but expected 2, at my_add")
2022 .c_str());
2023 }
2024 {
2025 // Both inputs are weights.
2026 Reset();
2027 NodeDef node_def =
2028 MakeNodeDef("my_add", "Add", {"weights1", "weights2"}, {{"T", dtype}});
2029 AddTestWeights<float>("weights1", {1}, {1});
2030 AddTestWeights<float>("weights2", {1}, {1});
2031 RunValidationAndConversion(
2032 node_def, error::UNIMPLEMENTED,
2033 "Constant folding is falled back to TensorFlow, binary op received "
2034 "both input as constant at: my_add");
2035 }
2036
2037 // Test BinaryTensorOpWeight() without broadcasting.
2038 TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_FLOAT>(this);
2039 TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_FLOAT>(this);
2040 TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_FLOAT>(this);
2041 TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_FLOAT>(this);
2042 TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_FLOAT>(this);
2043
2044 TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_HALF>(this);
2045 TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_HALF>(this);
2046 TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_HALF>(this);
2047 TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_HALF>(this);
2048 TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_HALF>(this);
2049
2050 // Test BinaryTensorOpWeight() with channel-wise broadcasting.
2051 TestBinaryTensorOpWeightWithChannelWiseBroadcast<DT_FLOAT>(this);
2052
2053 // Test BinaryTensorOpWeight() with uniform broadcasting.
2054 TestBinaryTensorOpWeightWithUniformlyBroadcast<DT_FLOAT>(this);
2055
2056 // Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor().
2057 // Unsupported op.
2058 TestBinaryTensorOpWeightFallback<ops::Minimum>(this, {1, 1, 1}, {1});
2059 // Rank of the input tensor is < 3.
2060 TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1}, {1});
2061 // Broadcast on batch dimension, should fail.
2062 TestBinaryTensorOpWeightFallback<ops::Add>(
2063 this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT,
2064 "Unsupported binary op broadcast scheme for op my_binary",
2065 /*input_batch_size=*/2);
2066 // Incompatible dims with per-channel mode.
2067 TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1, 1}, {1, 2, 1});
2068 // Incompatible dims.
2069 TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 2, 1}, {2});
2070
2071 // Test BinaryTensorOpTensor() with broadcasting.
2072 TestBinaryTensorOpTensor<ops::Add, DT_FLOAT>(this);
2073 TestBinaryTensorOpTensor<ops::Sub, DT_FLOAT>(this);
2074 TestBinaryTensorOpTensor<ops::Mul, DT_FLOAT>(this);
2075 TestBinaryTensorOpTensor<ops::Div, DT_FLOAT>(this);
2076 TestBinaryTensorOpTensor<ops::RealDiv, DT_FLOAT>(this);
2077 TestBinaryTensorOpTensor<ops::Minimum, DT_FLOAT>(this);
2078 TestBinaryTensorOpTensor<ops::Maximum, DT_FLOAT>(this);
2079 TestBinaryTensorOpTensor<ops::Pow, DT_FLOAT>(this);
2080
2081 TestBinaryTensorOpTensor<ops::Add, DT_HALF>(this);
2082 TestBinaryTensorOpTensor<ops::Sub, DT_HALF>(this);
2083 TestBinaryTensorOpTensor<ops::Mul, DT_HALF>(this);
2084 TestBinaryTensorOpTensor<ops::Div, DT_HALF>(this);
2085 TestBinaryTensorOpTensor<ops::RealDiv, DT_HALF>(this);
2086 TestBinaryTensorOpTensor<ops::Minimum, DT_HALF>(this);
2087 TestBinaryTensorOpTensor<ops::Maximum, DT_HALF>(this);
2088 TestBinaryTensorOpTensor<ops::Pow, DT_HALF>(this);
2089 }
2090
2091 TEST_F(OpConverterTest, ConvertQuantize) {
2092 precision_mode_to_test_ = TrtPrecisionMode::INT8;
2093 const std::pair<string, int> op_with_num_inputs[4] = {
2094 {"FakeQuantWithMinMaxArgs", 1},
2095 {"FakeQuantWithMinMaxVars", 3},
2096 {"QuantizeAndDequantizeV2", 3},
2097 {"QuantizeAndDequantizeV3", 4}};
2098 for (const auto& pair : op_with_num_inputs) {
2099 // Input list is empty, should fail.
2100 NodeDef node_def = MakeNodeDef("my_quantize", pair.first, {});
2101 RunValidationAndConversion(
2102 node_def, error::INVALID_ARGUMENT,
2103 StrCat(pair.first, " got 0 inputs but expected ",
2104 std::to_string(pair.second), ", at my_quantize")
2105 .c_str());
2106 }
2107 {
2108 // FakeQuantWithMinMaxArgs attributes are empty, should fail.
2109 NodeDef node_def =
2110 MakeNodeDef("my_quantize", "FakeQuantWithMinMaxArgs", {"input"});
2111 AddTestTensor("input", {1, 2, 3});
2112 RunValidationAndConversion(
2113 node_def, error::INVALID_ARGUMENT,
2114 "Min or max attribute not found for FakeQuantWithMinMaxArgs "
2115 "at my_quantize");
2116 }
2117 {
2118 // FakeQuantWithMinMaxArgs ranges set via attributes, ok.
2119 Reset();
2120 Scope s = Scope::NewRootScope();
2121 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2122 auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f);
2123 auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("my_quantize"),
2124 input, quantize_attrs);
2125 const NodeDef& node_def = quantize.operation.node()->def();
2126 AddTestTensor("input", {1, 2, 3});
2127 RunValidationAndConversion(node_def);
2128 TRT_TensorOrWeights output;
2129 TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
2130 EXPECT_TRUE(output.is_tensor());
2131 auto ranges = quantization_ranges();
2132 EXPECT_EQ(1, ranges.count(output.tensor()));
2133 EXPECT_EQ(6.0f, ranges[output.tensor()]);
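// The [-6, 6] range is recorded as a single symmetric bound (6.0f); the same
// pattern is checked for the other quantize ops below.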
2134 }
2135 {
2136 // FakeQuantWithMinMaxVars ranges set via inputs, ok.
2137 Reset();
2138 Scope s = Scope::NewRootScope();
2139 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2140 auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
2141 auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
2142 auto quantize = ops::FakeQuantWithMinMaxVars(
2143 s.WithOpName("my_quantize"), input, weights_min, weights_max);
2144 const NodeDef& node_def = quantize.operation.node()->def();
2145 AddTestTensor("input", {1, 2, 3});
2146 AddTestWeights<float>("weights_min", {1}, {-6.0f});
2147 AddTestWeights<float>("weights_max", {1}, {6.0f});
2148 RunValidationAndConversion(node_def);
2149 TRT_TensorOrWeights output;
2150 TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
2151 EXPECT_TRUE(output.is_tensor());
2152 auto ranges = quantization_ranges();
2153 EXPECT_EQ(1, ranges.count(output.tensor()));
2154 EXPECT_EQ(6.0f, ranges[output.tensor()]);
2155 }
2156 {
2157 // QuantizeAndDequantizeV2 ranges set via inputs, ok.
2158 Reset();
2159 Scope s = Scope::NewRootScope();
2160 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2161 auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
2162 auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
2163 auto quantize = ops::QuantizeAndDequantizeV2(
2164 s.WithOpName("my_quantize"), input, weights_min, weights_max);
2165 const NodeDef& node_def = quantize.operation.node()->def();
2166 AddTestTensor("input", {1, 2, 3});
2167 AddTestWeights<float>("weights_min", {1}, {-6.0f});
2168 AddTestWeights<float>("weights_max", {1}, {6.0f});
2169 RunValidationAndConversion(node_def);
2170 TRT_TensorOrWeights output;
2171 TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
2172 EXPECT_TRUE(output.is_tensor());
2173 auto ranges = quantization_ranges();
2174 EXPECT_EQ(1, ranges.count(output.tensor()));
2175 EXPECT_EQ(6.0f, ranges[output.tensor()]);
2176 }
2177 {
2178 // QuantizeAndDequantizeV2 Range inputs are tensors, should fail.
2179 Reset();
2180 Scope s = Scope::NewRootScope();
2181 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2182 auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
2183 auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
2184 auto quantize = ops::QuantizeAndDequantizeV2(
2185 s.WithOpName("my_quantize"), input, weights_min, weights_max);
2186 const NodeDef& node_def = quantize.operation.node()->def();
2187 AddTestTensor("input", {1, 2, 3});
2188 AddTestTensor("weights_min", {1});
2189 AddTestTensor("weights_max", {1});
2190 RunValidationAndConversion(
2191 node_def, error::UNIMPLEMENTED,
2192 "The input \"input_min\" for QuantizeAndDequantizeV2 must be a constant"
2193 ", at my_quantize");
2194 }
2195 {
2196 // QuantizeAndDequantizeV3 ranges set via inputs, ok.
2197 Reset();
2198 Scope s = Scope::NewRootScope();
2199 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2200 auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
2201 auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
2202 auto num_bits = ops::Placeholder(s.WithOpName("num_bits"), DT_INT32);
2203 auto quantize = ops::QuantizeAndDequantizeV3(
2204 s.WithOpName("my_quantize"), input, weights_min, weights_max, num_bits);
2205 const NodeDef& node_def = quantize.operation.node()->def();
2206 AddTestTensor("input", {1, 2, 3});
2207 AddTestWeights<float>("weights_min", {1}, {-6.0f});
2208 AddTestWeights<float>("weights_max", {1}, {6.0f});
2209 AddTestWeights<int>("num_bits", {1}, {8});
2210 RunValidationAndConversion(node_def);
2211 TRT_TensorOrWeights output;
2212 TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
2213 EXPECT_TRUE(output.is_tensor());
2214 auto ranges = quantization_ranges();
2215 EXPECT_EQ(1, ranges.count(output.tensor()));
2216 EXPECT_EQ(6.0f, ranges[output.tensor()]);
2217 }
2218 }
2219
2220 template <DataType dtype>
2221 void TestConvertSquare(OpConverterTest* test) {
2222 test->Reset();
2223 typedef typename EnumToDataType<dtype>::Type CType;
2224
2225 Scope s = Scope::NewRootScope();
2226 auto input = ops::Placeholder(s.WithOpName("input"), dtype);
2227 auto square = ops::Square(s.WithOpName("my_square"), input);
2228 NodeDef node_def = square.operation.node()->def();
2229
2230 test->AddTestTensor("input", {1, 20}, /*batch_size=*/1,
2231 TfDataTypeToTrt(dtype));
2232 test->RunValidationAndConversion(node_def);
2233 TRT_TensorOrWeights output;
2234 TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output));
2235 EXPECT_TRUE(output.is_tensor());
2236 ExpectTrtDimsEqualsArray({1, 20}, output.tensor()->getDimensions());
2237
2238 const int num_inputs = 20;
2239 std::vector<CType> inputs(num_inputs);
2240 std::vector<CType> expected_outputs(num_inputs);
2241 for (int i = 0; i < num_inputs; ++i) {
2242 const CType value = CType(i - 9);
2243 inputs[i] = value;
2244 expected_outputs[i] = value * value;
2245 }
2246 const DataVec input_data{{"input", test::AsTensor<CType>(inputs)}};
2247 // Engine outputs are converted to FP16 automatically if we set FP16 mode in
2248 // the builder.
2249 DataVec output_data{{"my_square", ConstructTensor<CType>(num_inputs)}};
2250 test->BuildAndRun(
2251 input_data, &output_data,
2252 dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
2253 ExpectArrayNear(expected_outputs, GetSpanForData<CType>(output_data[0]));
2254 }
2255
2256 TEST_F(OpConverterTest, ConvertSquare) {
2257 {
2258 // Input list is empty, should fail.
2259 NodeDef node_def = MakeNodeDef("my_square", "Square", {});
2260 RunValidationAndConversion(
2261 node_def, error::INVALID_ARGUMENT,
2262 "Square got 0 inputs but expected 1, at my_square");
2263 }
2264 {
2265 // Input is weights, should fail.
2266 Reset();
2267 Scope s = Scope::NewRootScope();
2268 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2269 auto square = ops::Square(s.WithOpName("my_square"), input);
2270 NodeDef node_def = square.operation.node()->def();
2271 AddTestWeights<float>("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6});
2272 RunValidationAndConversion(
2273 node_def, error::UNIMPLEMENTED,
2274 "The input \"x\" for Square must be a tensor, at my_square");
2275 }
2276
2277 // OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't
2278 // test the DT_INT32 type here.
2279 TestConvertSquare<DT_FLOAT>(this);
2280 TestConvertSquare<DT_HALF>(this);
2281 }
2282
2283 TEST_F(OpConverterTest, ConvertActivation) {
2284 {
2285 // Input list is empty, should fail.
2286 NodeDef node_def = MakeNodeDef("my_act", "Relu", {});
2287 RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
2288 "Relu got 0 inputs but expected 1, at my_act");
2289 }
2290 {
2291 // Input is weights, should fail.
2292 Reset();
2293 Scope s = Scope::NewRootScope();
2294 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2295 auto relu = ops::Relu(s.WithOpName("my_act"), input);
2296 const NodeDef& node_def = relu.operation.node()->def();
2297 AddTestWeights<int32>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
2298 RunValidationAndConversion(
2299 node_def, error::UNIMPLEMENTED,
2300 "The input \"input\" for Relu must be a tensor, at my_act");
2301 }
2302
2303 constexpr float kAlpha = 0.2f;
2304
2305 // Get nodedef for activation layer.
2306 auto get_act_nodedef = [](string op_name) -> NodeDef {
2307 Scope s = Scope::NewRootScope();
2308 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2309 if (op_name == "LeakyRelu") {
2310 auto act =
2311 ops::internal::LeakyRelu(s.WithOpName("my_act"), input,
2312 ops::internal::LeakyRelu::Alpha(kAlpha));
2313 return act.operation.node()->def();
2314 } else if (op_name == "Relu") {
2315 auto act = ops::Relu(s.WithOpName("my_act"), input);
2316 return act.operation.node()->def();
2317 } else if (op_name == "Relu6") {
2318 auto act = ops::Relu6(s.WithOpName("my_act"), input);
2319 return act.operation.node()->def();
2320 } else if (op_name == "Sigmoid") {
2321 auto act = ops::Sigmoid(s.WithOpName("my_act"), input);
2322 return act.operation.node()->def();
2323 } else if (op_name == "Tanh") {
2324 auto act = ops::Tanh(s.WithOpName("my_act"), input);
2325 return act.operation.node()->def();
2326 }
2327 EXPECT_TRUE(false);
2328 return NodeDef();
2329 };
2330 // Get expected output for activation layer.
2331 auto get_act_output = [](string op_name, float input) -> float {
2332 if (op_name == "LeakyRelu") {
2333 return (input > 0.0f) ? input : input * kAlpha;
2334 } else if (op_name == "Relu") {
2335 return (input > 0.0f) ? input : 0.0f;
2336 } else if (op_name == "Relu6") {
2337 return std::min(std::max(input, 0.0f), 6.0f);
2338 } else if (op_name == "Sigmoid") {
2339 return 1.0f / (1.0f + std::exp(-input));
2340 } else if (op_name == "Tanh") {
2341 return std::tanh(input);
2342 }
2343 EXPECT_TRUE(false);
2344 return 0;
2345 };
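// Spot checks of the reference outputs above: LeakyRelu(-2) = -0.4 with
// kAlpha = 0.2, Relu6(100) = 6, and Sigmoid(0) = 0.5.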
2346
2347 // Ok.
2348 for (const string& op_name :
2349 {"LeakyRelu", "Relu", "Relu6", "Sigmoid", "Tanh"}) {
2350 Reset();
2351 NodeDef node_def = get_act_nodedef(op_name);
2352 AddTestTensor("input", {1, 2, 3});
2353 RunValidationAndConversion(node_def);
2354 TRT_TensorOrWeights output;
2355 TF_EXPECT_OK(GetTensorOrWeights("my_act", &output));
2356 EXPECT_TRUE(output.is_tensor());
2357 ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions());
2358 if (op_name == "Relu6") {
2359 // Relu6 should set quantization range automatically.
2360 auto ranges = quantization_ranges();
2361 EXPECT_EQ(ranges[output.tensor()], 6.0f);
2362 }
2363
2364 const std::vector<float> input = {-100, -2, -1, 0, 1, 100};
2365 const DataVec input_data{{"input", test::AsTensor<float>(input)}};
2366 DataVec output_data{{"my_act", ConstructTensor<float>(6)}};
2367 BuildAndRun(input_data, &output_data);
2368 for (int i = 0; i < input.size(); i++) {
2369 const float expected_output = get_act_output(op_name, input[i]);
2370 EXPECT_FLOAT_EQ(GetSpanForData<float>(output_data[0])[i],
2371 expected_output);
2372 }
2373 }
2374 }
2375
2376 TEST_F(OpConverterTest, ConvertExpandDims) {
2377 {
2378 // Input list is empty, should fail.
2379 NodeDef node_def = MakeNodeDef("my_expanddims", "ExpandDims", {});
2380 RunValidationAndConversion(
2381 node_def, error::INVALID_ARGUMENT,
2382 "ExpandDims got 0 inputs but expected 2, at my_expanddims");
2383 }
2384
2385 // Get the NodeDef for ExpandDims.
2386 Scope s = Scope::NewRootScope();
2387 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2388 auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
2389 auto expanddims =
2390 ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights);
2391 const NodeDef& node_def = expanddims.operation.node()->def();
2392 {
2393 // Input is weights, should fail.
2394 Reset();
2395 AddTestWeights<int32>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
2396 AddTestWeights<int32>("weights", {1}, {1});
2397 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2398 "The input \"input\" for ExpandDims must be a "
2399 "tensor, at my_expanddims");
2400 }
2401 {
2402 // Axis is a tensor, should fail.
2403 Reset();
2404 AddTestTensor("input", {1, 2, 3});
2405 AddTestTensor("weights", {3});
2406 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2407 "The input \"axis\" for ExpandDims must be a "
2408 "constant, at my_expanddims");
2409 }
2410 {
2411 // Add dim at batch dimension, should fail.
2412 Reset();
2413 AddTestTensor("input", {1, 2, 3});
2414 AddTestWeights<int32>("weights", {1}, {0});
2415 RunValidationAndConversion(
2416 node_def, error::UNIMPLEMENTED,
2417 "Modifying batch dimension is not supported for ExpandDims, at "
2418 "my_expanddims");
2419 }
2420 {
2421 // Add dim at batch dimension via negative axis, should fail.
2422 Reset();
2423 AddTestTensor("input", {1, 2, 3});
2424 // Input is rank 4 (batch dim included)
2425 AddTestWeights<int32>("weights", {1}, {-5});
2426 RunValidationAndConversion(
2427 node_def, error::UNIMPLEMENTED,
2428 "Modifying batch dimension is not supported for ExpandDims, at "
2429 "my_expanddims");
2430 }
2431 {
2432 // Axis > rank(input), should fail.
2433 Reset();
2434 AddTestTensor("input", {1, 2, 3});
2435 // Input is rank 4 (batch dim included)
2436 AddTestWeights<int32>("weights", {1}, {5});
2437 RunValidationAndConversion(
2438 node_def, error::INVALID_ARGUMENT,
2439 "Axis for ExpandDims is invalid, must be in the range "
2440 "[-rank(input) - 1, rank(input)], at my_expanddims");
2441 }
2442 {
2443 // Axis < -rank(input)-1, should fail.
2444 Reset();
2445 AddTestTensor("input", {1, 2, 3});
2446 // Input is rank 4 (batch dim included)
2447 AddTestWeights<int32>("weights", {1}, {-6});
2448 RunValidationAndConversion(
2449 node_def, error::INVALID_ARGUMENT,
2450 "Axis for ExpandDims is invalid, must be in the range "
2451 "[-rank(input) - 1, rank(input)], at my_expanddims");
2452 }
2453
2454 struct TestParams {
2455 std::vector<int> input_dims;
2456 int axis;
2457 std::vector<int> expected_output_dims;
2458 };
2459
2460 // Ok.
2461 const int kExpandDimsOKCases = 8;
2462 TestParams ok_params[kExpandDimsOKCases] = {
2463 TestParams{{2, 3}, 1, {1, 2, 3}}, TestParams{{2, 3}, -3, {1, 2, 3}},
2464 TestParams{{2, 3}, 3, {2, 3, 1}}, TestParams{{2, 3}, -1, {2, 3, 1}},
2465 TestParams{{2, 3}, 2, {2, 1, 3}}, TestParams{{2, 3}, -2, {2, 1, 3}},
2466 TestParams{{6}, 1, {1, 6}}, TestParams{{6}, -1, {6, 1}},
2467 };
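// Note: axis is interpreted against the full shape including the implicit
// batch dimension (axis 0), so e.g. axis 1 on a {2, 3} tensor inserts a 1
// right after the batch dim, giving {1, 2, 3}; negative axes count back from
// the end of the expanded shape.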
2468 for (int i = 0; i < kExpandDimsOKCases; ++i) {
2469 Reset();
2470 AddTestTensor("input", ok_params[i].input_dims);
2471 AddTestWeights<int32>("weights", {1}, {ok_params[i].axis});
2472 RunValidationAndConversion(node_def);
2473 TRT_TensorOrWeights output;
2474 TF_EXPECT_OK(GetTensorOrWeights("my_expanddims", &output));
2475 EXPECT_TRUE(output.is_tensor());
2476 ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
2477 output.tensor()->getDimensions());
2478
2479 const DataVec input_data{
2480 {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
2481 DataVec output_data{{"my_expanddims", ConstructTensor<float>(6)}};
2482 BuildAndRun(input_data, &output_data);
2483 EXPECT_THAT(GetSpanForData<float>(output_data[0]),
2484 ElementsAre(1, 2, 3, 4, 5, 6));
2485 }
2486 }
2487
2488 TEST_F(OpConverterTest, ConvertSqueeze) {
2489 {
2490 // Input list is empty, should fail.
2491 NodeDef node_def = MakeNodeDef("my_squeeze", "Squeeze", {});
2492 RunValidationAndConversion(
2493 node_def, error::INVALID_ARGUMENT,
2494 "Squeeze got 0 inputs but expected 1, at my_squeeze");
2495 }
2496 {
2497 // No attrs, should fail.
2498 Reset();
2499 Scope s = Scope::NewRootScope();
2500 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2501 auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input);
2502 const NodeDef& node_def = squeeze.operation.node()->def();
2503 AddTestTensor("input", {1, 2, 3});
2504 RunValidationAndConversion(
2505 node_def, error::UNIMPLEMENTED,
2506 "Squeeze is only implemented for explicit dims, at my_squeeze");
2507 }
2508
2509 // Get the NodeDef for Squeeze.
2510 auto get_squeeze_nodedef = [](std::vector<int> axis) -> NodeDef {
2511 Scope s = Scope::NewRootScope();
2512 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2513 ops::Squeeze::Attrs squeeze_attrs;
2514 squeeze_attrs.axis_ = gtl::ArraySlice<int>(axis); // non-absl ok
2515 auto squeeze =
2516 ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs);
2517 return squeeze.operation.node()->def();
2518 };
2519
2520 {
2521 // Input is weights, should fail.
2522 Reset();
2523 NodeDef node_def = get_squeeze_nodedef({0});
2524 AddTestWeights<float>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
2525 RunValidationAndConversion(
2526 node_def, error::UNIMPLEMENTED,
2527 "The input \"input\" for Squeeze must be a tensor, at my_squeeze");
2528 }
2529 {
2530 // Squeeze batch dim, should fail.
2531 Reset();
2532 NodeDef node_def = get_squeeze_nodedef({0});
2533 AddTestTensor("input", {1, 2, 3});
2534 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2535 "Cannot squeeze batch dimension, at my_squeeze");
2536 }
2537 {
2538 // Squeeze batch dim via negative axis, should fail.
2539 Reset();
2540 NodeDef node_def = get_squeeze_nodedef({-4});
2541 AddTestTensor("input", {1, 2, 3});
2542 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2543 "Cannot squeeze batch dimension, at my_squeeze");
2544 }
2545 {
2546 // Squeeze >= rank(input), should fail.
2547 Reset();
2548 NodeDef node_def = get_squeeze_nodedef({4});
2549 AddTestTensor("input", {1, 2, 3});
2550 RunValidationAndConversion(
2551 node_def, error::INVALID_ARGUMENT,
2552 "Axis for Squeeze is invalid, must be in the range "
2553 "[-rank(input), rank(input)), at my_squeeze");
2554 }
2555 {
2556 // Squeeze < -rank(input), should fail.
2557 Reset();
2558 NodeDef node_def = get_squeeze_nodedef({-5});
2559 AddTestTensor("input", {1, 2, 3});
2560 RunValidationAndConversion(
2561 node_def, error::INVALID_ARGUMENT,
2562 "Axis for Squeeze is invalid, must be in the range "
2563 "[-rank(input), rank(input)), at my_squeeze");
2564 }
2565
2566 struct TestParams {
2567 std::vector<int> input_dims;
2568 std::vector<int> axis;
2569 std::vector<int> expected_output_dims;
2570 };
2571
2572 // Ok.
2573 const int kSqueezeOKCases = 10;
2574 TestParams ok_params[kSqueezeOKCases] = {
2575 TestParams{{1, 2, 3}, {1}, {2, 3}},
2576 TestParams{{1, 2, 3}, {-3}, {2, 3}},
2577 TestParams{{2, 3, 1}, {3}, {2, 3}},
2578 TestParams{{2, 3, 1}, {-1}, {2, 3}},
2579 TestParams{{1, 2, 1, 3, 1}, {1, 3, 5}, {2, 3}},
2580 TestParams{{1, 2, 1, 3, 1}, {3, 1, 5}, {2, 3}},
2581 TestParams{{1, 2, 1, 3, 1}, {-1, -3, -5}, {2, 3}},
2582 TestParams{{1, 2, 1, 3, 1}, {1, -3, 5}, {2, 3}},
2583 TestParams{{1, 6}, {1}, {6}},
2584 TestParams{{6, 1}, {2}, {6}},
2585 };
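// As with ExpandDims, axis values count the implicit batch dimension as axis
// 0, so axis 1 refers to the first dim listed in input_dims above.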
2586 for (int i = 0; i < kSqueezeOKCases; ++i) {
2587 Reset();
2588 NodeDef node_def = get_squeeze_nodedef(ok_params[i].axis);
2589 AddTestTensor("input", ok_params[i].input_dims);
2590 RunValidationAndConversion(node_def);
2591 TRT_TensorOrWeights output;
2592 TF_EXPECT_OK(GetTensorOrWeights("my_squeeze", &output));
2593 EXPECT_TRUE(output.is_tensor());
2594 ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
2595 output.tensor()->getDimensions());
2596
2597 const DataVec input_data{
2598 {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
2599 DataVec output_data{{"my_squeeze", ConstructTensor<float>(6)}};
2600 BuildAndRun(input_data, &output_data);
2601 EXPECT_THAT(GetSpanForData<float>(output_data[0]),
2602 ElementsAre(1, 2, 3, 4, 5, 6));
2603 }
2604 }
2605
2606 TEST_F(OpConverterTest, ConvertStridedSlice) {
2607 {
2608 // Input list is empty, should fail.
2609 NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {});
2610 RunValidationAndConversion(
2611 node_def, error::INVALID_ARGUMENT,
2612 "StridedSlice got 0 inputs but expected 4, at my_strided_slice");
2613 }
2614
2615 // Get nodedef for StridedSlice layer.
2616 auto get_strided_slice_nodedef =
2617 [](int64 begin_mask = 0, int64 end_mask = 0, int64 ellipsis_mask = 0,
2618 int64 new_axis_mask = 0, int64 shrink_axis_mask = 0) -> NodeDef {
2619 Scope s = Scope::NewRootScope();
2620 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2621 auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32);
2622 auto end = ops::Placeholder(s.WithOpName("end"), DT_INT32);
2623 auto strides = ops::Placeholder(s.WithOpName("strides"), DT_INT32);
2624 ops::StridedSlice::Attrs attrs = ops::StridedSlice::Attrs()
2625 .BeginMask(begin_mask)
2626 .EndMask(end_mask)
2627 .EllipsisMask(ellipsis_mask)
2628 .NewAxisMask(new_axis_mask)
2629 .ShrinkAxisMask(shrink_axis_mask);
2630 auto strided_slice = ops::StridedSlice(s.WithOpName("my_strided_slice"),
2631 input, begin, end, strides, attrs);
2632 return strided_slice.operation.node()->def();
2633 };
2634
2635 {
2636 // Input is weights, should fail.
2637 Reset();
2638 NodeDef node_def = get_strided_slice_nodedef();
2639 AddTestWeights<int32>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
2640 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2641 AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
2642 AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2643 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2644 "The input \"input\" for StridedSlice must be a "
2645 "tensor, at my_strided_slice");
2646 }
2647 {
2648 // Begin, end, strides are tensors, should fail.
2649 Reset();
2650 NodeDef node_def = get_strided_slice_nodedef();
2651 AddTestTensor("input", {1, 2, 3});
2652 AddTestTensor("begin", {4});
2653 AddTestTensor("end", {4});
2654 AddTestTensor("strides", {4});
2655 RunValidationAndConversion(
2656 node_def, error::UNIMPLEMENTED,
2657 "The input \"begin\" for StridedSlice must be a constant, at "
2658 "my_strided_slice");
2659 }
2660 {
2661 // Non-zero ellipsis_mask, should fail.
2662 Reset();
2663 NodeDef node_def = get_strided_slice_nodedef(
2664 /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/2,
2665 /*new_axis_mask=*/0, /*shrink_axis_mask=*/0);
2666 AddTestTensor("input", {1, 2, 3});
2667 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2668 AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
2669 AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2670 RunValidationAndConversion(
2671 node_def, error::UNIMPLEMENTED,
2672 "ellipsis_mask is not supported for StridedSlice, at "
2673 "my_strided_slice");
2674 }
2675 {
2676 // Modify batch dim, should fail.
2677 Reset();
2678 NodeDef node_def = get_strided_slice_nodedef();
2679 AddTestTensor("input", {1, 2, 3});
2680 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2681 AddTestWeights<int32>("end", {4}, {0, 1, 2, 3});
2682 AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2683 RunValidationAndConversion(
2684 node_def, error::UNIMPLEMENTED,
2685 "TensorRT does not allow modifications to the batch dimension, at "
2686 "my_strided_slice");
2687 }
2688 {
2689 // Dynamic batch size without end_mask, should fail.
2690 Reset();
2691 NodeDef node_def = get_strided_slice_nodedef();
2692 AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1);
2693 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2694 AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
2695 AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2696 RunValidationAndConversion(
2697 node_def, error::UNIMPLEMENTED,
2698 "TensorRT does not allow modifications to the batch dimension, at "
2699 "my_strided_slice");
2700 }
2701 {
2702 // Dynamic batch size but using end_mask, ok.
2703 Reset();
2704 NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0,
2705 /*end_mask=*/1);
2706 AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1);
2707 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2708 AddTestWeights<int32>("end", {4}, {0, 1, 2, 2});
2709 AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2710 RunValidationAndConversion(node_def);
2711 }
2712 // TRT 5.1+ supports strides.
2713 #if IS_TRT_VERSION_GE(5, 1, 0)
2714 {
2715 // Negative strides, should fail.
2716 Reset();
2717 NodeDef node_def = get_strided_slice_nodedef();
2718 AddTestTensor("input", {1, 2, 3});
2719 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2720 AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
2721 AddTestWeights<int32>("strides", {4}, {1, 1, 1, -1});
2722 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2723 "Negative or zero stride values are not "
2724 "supported for StridedSlice, at "
2725 "my_strided_slice");
2726 }
2727 #else
2728 {
2729 // Stride is not 1, should fail.
2730 Reset();
2731 NodeDef node_def = get_strided_slice_nodedef();
2732 AddTestTensor("input", {1, 2, 3});
2733 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2734 AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
2735 AddTestWeights<int32>("strides", {4}, {1, 2, 1, 3});
2736 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
2737 "Strides other than 1 are not supported with "
2738 "this version of TRT, at my_strided_slice");
2739 }
2740 #endif
2741 {
2742 // Size of sliced dim is negative, should fail.
2743 Reset();
2744 NodeDef node_def = get_strided_slice_nodedef();
2745 AddTestTensor("input", {1, 2, 3});
2746 AddTestWeights<int32>("begin", {4}, {0, 0, 2, 0});
2747 AddTestWeights<int32>("end", {4}, {1, 1, 0, 3});
2748 AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
2749 RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
2750 "\"size\" cannot be negative or zero for "
2751 "StridedSlice, at my_strided_slice");
2752 }
2753
2754 struct TestParams {
2755 std::vector<int> input_dims;
2756 std::vector<int> begin;
2757 std::vector<int> end;
2758 std::vector<int> strides;
2759 int begin_mask;
2760 int end_mask;
2761 std::vector<int> expected_output_dims;
2762 std::vector<float> expected_output;
2763 };
2764
2765 auto get_mask = [](const std::vector<int>& mask) {
2766 int result = 0;
2767 for (int i = 0; i < mask.size(); i++) {
2768 if (mask[i]) result += (1 << i);
2769 }
2770 return result;
2771 };
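// For example, get_mask({1, 1, 0, 0}) == 0b0011 == 3; a set bit in
// begin_mask/end_mask tells StridedSlice to ignore the corresponding
// begin/end value and use the full extent of that dimension instead.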
2772
2773 // Same input is used for all tests.
2774 const std::vector<float> ok_input = {1, 2, 3, 4, 5, 6};
2775
2776 #if IS_TRT_VERSION_GE(5, 1, 0)
2777 const int kStridedSliceOKCases = 23;
2778 #else
2779 const int kStridedSliceOKCases = 19;
2780 #endif
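// The extra cases guarded by IS_TRT_VERSION_GE(5, 1, 0) below exercise
// non-unit strides, which earlier TRT versions reject during conversion.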
2781 // Ok.
2782 TestParams ok_params[kStridedSliceOKCases] = {
2783 // 2D Crop.
2784 TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, 0, 0},
2785 /*end=*/{0, 0, 1, 2}, /*strides=*/{1, 1, 1, 1},
2786 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2787 /*end_mask=*/get_mask({1, 1, 0, 0}),
2788 /*expected_output_dims=*/{1, 1, 2}, /*expected_output=*/{1, 2}},
2789 TestParams{
2790 /*input_dims=*/{1, 2, 3},
2791 /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1},
2792 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2793 /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 2},
2794 /*expected_output=*/{5, 6}},
2795 TestParams{
2796 /*input_dims=*/{1, 2, 3},
2797 /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, /*strides=*/{1, 1, 1, 1},
2798 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2799 /*end_mask=*/get_mask({1, 1, 0, 0}), /*expected_output_dims=*/{1, 1, 2},
2800 /*expected_output=*/{5, 6}},
2801 // 2D Crop, with transpose.
2802 TestParams{
2803 /*input_dims=*/{2, 3, 1},
2804 /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, /*strides=*/{1, 1, 1, 1},
2805 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2806 /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1},
2807 /*expected_output=*/{1, 2}},
2808 TestParams{
2809 /*input_dims=*/{2, 3, 1},
2810 /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, /*strides=*/{1, 1, 1, 1},
2811 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2812 /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1},
2813 /*expected_output=*/{5, 6}},
2814 TestParams{
2815 /*input_dims=*/{2, 1, 3},
2816 /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, /*strides=*/{1, 1, 1, 1},
2817 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2818 /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2},
2819 /*expected_output=*/{1, 2}},
2820 TestParams{
2821 /*input_dims=*/{2, 1, 3},
2822 /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, /*strides=*/{1, 1, 1, 1},
2823 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2824 /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2},
2825 /*expected_output=*/{5, 6}},
2826 // 2D Crop, with reshape.
2827 TestParams{/*input_dims=*/{2, 3},
2828 /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, /*strides=*/{1, 1, 1},
2829 /*begin_mask=*/get_mask({0, 0, 0}),
2830 /*end_mask=*/get_mask({1, 0, 0}),
2831 /*expected_output_dims=*/{1, 2},
2832 /*expected_output=*/{1, 2}},
2833 TestParams{/*input_dims=*/{2, 3},
2834 /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, /*strides=*/{1, 1, 1},
2835 /*begin_mask=*/get_mask({0, 0, 0}),
2836 /*end_mask=*/get_mask({1, 1, 1}),
2837 /*expected_output_dims=*/{1, 2},
2838 /*expected_output=*/{5, 6}},
2839 // 1D Crop.
2840 TestParams{
2841 /*input_dims=*/{1, 2, 3},
2842 /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, /*strides=*/{1, 1, 1, 1},
2843 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2844 /*end_mask=*/get_mask({1, 1, 1, 0}), /*expected_output_dims=*/{1, 2, 2},
2845 /*expected_output=*/{1, 2, 4, 5}},
2846 TestParams{
2847 /*input_dims=*/{1, 2, 3},
2848 /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1},
2849 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2850 /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 3},
2851 /*expected_output=*/{4, 5, 6}},
2852 // 1D Crop, with transpose.
2853 TestParams{
2854 /*input_dims=*/{2, 3, 1},
2855 /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, /*strides=*/{1, 1, 1, 1},
2856 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2857 /*end_mask=*/get_mask({1, 0, 1, 1}), /*expected_output_dims=*/{1, 3, 1},
2858 /*expected_output=*/{1, 2, 3}},
2859 TestParams{
2860 /*input_dims=*/{2, 3, 1},
2861 /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1},
2862 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2863 /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 3, 1},
2864 /*expected_output=*/{4, 5, 6}},
2865 // 1D Crop, with reshape.
2866 TestParams{/*input_dims=*/{6},
2867 /*begin=*/{0, 0}, /*end=*/{0, 3}, /*strides=*/{1, 1},
2868 /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}),
2869 /*expected_output_dims=*/{3},
2870 /*expected_output=*/{1, 2, 3}},
2871 TestParams{/*input_dims=*/{1, 6},
2872 /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, /*strides=*/{1, 1, 1},
2873 /*begin_mask=*/get_mask({0, 0, 0}),
2874 /*end_mask=*/get_mask({1, 1, 0}),
2875 /*expected_output_dims=*/{1, 3},
2876 /*expected_output=*/{3, 4, 5}},
2877 TestParams{/*input_dims=*/{6, 1},
2878 /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, /*strides=*/{1, 1, 1},
2879 /*begin_mask=*/get_mask({0, 0, 0}),
2880 /*end_mask=*/get_mask({1, 0, 1}),
2881 /*expected_output_dims=*/{3, 1},
2882 /*expected_output=*/{3, 4, 5}},
2883 // Negative begin and end indices.
2884 TestParams{/*input_dims=*/{6, 1},
2885 /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, /*strides=*/{1, 1, 1},
2886 /*begin_mask=*/get_mask({0, 0, 0}),
2887 /*end_mask=*/get_mask({1, 0, 1}),
2888 /*expected_output_dims=*/{3, 1},
2889 /*expected_output=*/{1, 2, 3}},
2890 TestParams{/*input_dims=*/{6, 1},
2891 /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, /*strides=*/{1, 1, 1},
2892 /*begin_mask=*/get_mask({0, 0, 0}),
2893 /*end_mask=*/get_mask({1, 0, 1}),
2894 /*expected_output_dims=*/{5, 1},
2895 /*expected_output=*/{1, 2, 3, 4, 5}},
2896 // Clamp out of bounds begin and end.
2897 TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, -9999, -9},
2898 /*end=*/{0, 1, 1000, 4}, /*strides=*/{1, 1, 1, 1},
2899 /*begin_mask=*/get_mask({0, 0, 0, 0}),
2900 /*end_mask=*/get_mask({1, 0, 0, 0}),
2901 /*expected_output_dims=*/{1, 2, 3},
2902 /*expected_output=*/{1, 2, 3, 4, 5, 6}},
2903 #if IS_TRT_VERSION_GE(5, 1, 0)
2904 // Strides
2905 TestParams{/*input_dims=*/{6},
2906 /*begin=*/{0, 0}, /*end=*/{0, 5}, /*strides=*/{1, 2},
2907 /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}),
2908 /*expected_output_dims=*/{3},
2909 /*expected_output=*/{1, 3, 5}},
2910 TestParams{/*input_dims=*/{6},
2911 /*begin=*/{0, 0}, /*end=*/{0, 6}, /*strides=*/{1, 2},
2912 /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}),
2913 /*expected_output_dims=*/{3},
2914 /*expected_output=*/{1, 3, 5}},
2915 TestParams{/*input_dims=*/{6},
2916 /*begin=*/{0, 1}, /*end=*/{0, 6}, /*strides=*/{1, 2},
2917 /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}),
2918 /*expected_output_dims=*/{3},
2919 /*expected_output=*/{2, 4, 6}},
2920 TestParams{/*input_dims=*/{6},
2921 /*begin=*/{0, 2}, /*end=*/{0, 6}, /*strides=*/{1, 3},
2922 /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}),
2923 /*expected_output_dims=*/{2},
2924 /*expected_output=*/{3, 6}},
2925 #endif
2926 };
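// Worked example for the first "2D Crop" case above (illustrative): the
// end_mask get_mask({1, 1, 0, 0}) makes StridedSlice ignore "end" for the
// batch dimension and the first non-batch dimension, so only the last two
// dimensions are sliced, to [0, 1) and [0, 2), yielding values {1, 2} with
// shape {1, 1, 2}.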
2927
2928 for (int i = 0; i < kStridedSliceOKCases; i++) {
2929 Reset();
2930 NodeDef node_def = get_strided_slice_nodedef(ok_params[i].begin_mask,
2931 ok_params[i].end_mask);
2932 AddTestTensor("input", ok_params[i].input_dims);
2933 AddTestWeights<int32>("begin",
2934 {static_cast<int>(ok_params[i].begin.size())},
2935 ok_params[i].begin);
2936 AddTestWeights<int32>("end", {static_cast<int>(ok_params[i].end.size())},
2937 ok_params[i].end);
2938 AddTestWeights<int32>("strides",
2939 {static_cast<int>(ok_params[i].strides.size())},
2940 ok_params[i].strides);
2941 RunValidationAndConversion(node_def);
2942
2943 TRT_TensorOrWeights output;
2944 TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output));
2945 EXPECT_TRUE(output.is_tensor());
2946 ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
2947 output.tensor()->getDimensions());
2948
2949 const DataVec input_data{{"input", test::AsTensor<float>(ok_input)}};
2950 DataVec output_data{
2951 {"my_strided_slice",
2952 ConstructTensor<float>(ok_params[i].expected_output.size())}};
2953 BuildAndRun(input_data, &output_data);
2954 EXPECT_THAT(GetSpanForData<float>(output_data[0]),
2955 ElementsAreArray(ok_params[i].expected_output));
2956 }
2957 }
2958
2959 TEST_F(OpConverterTest, ConvertSlice) {
2960 // Get nodedef for Slice layer.
2961 auto get_slice_nodedef = []() -> NodeDef {
2962 Scope s = Scope::NewRootScope();
2963 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
2964 auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32);
2965 auto size = ops::Placeholder(s.WithOpName("size"), DT_INT32);
2966 auto slice = ops::Slice(s.WithOpName("my_slice"), input, begin, size);
2967 return slice.operation.node()->def();
2968 };
2969
2970 {
2971 // Begin is below bounds, should fail.
2972 Reset();
2973 NodeDef node_def = get_slice_nodedef();
2974 AddTestTensor("input", {1, 2, 3});
2975 AddTestWeights<int32>("begin", {4}, {0, 0, -1, 0});
2976 AddTestWeights<int32>("size", {4}, {1, 1, 2, 3});
2977 RunValidationAndConversion(
2978 node_def, error::INVALID_ARGUMENT,
2979 "\"begin\" for dimension 2 in Slice is out of range, at my_slice");
2980 }
2981 {
2982 // Begin is above bounds, should fail.
2983 Reset();
2984 NodeDef node_def = get_slice_nodedef();
2985 AddTestTensor("input", {1, 2, 3});
2986 AddTestWeights<int32>("begin", {4}, {0, 0, 3, 0});
2987 AddTestWeights<int32>("size", {4}, {1, 1, 2, 3});
2988 RunValidationAndConversion(
2989 node_def, error::INVALID_ARGUMENT,
2990 "\"begin\" for dimension 2 in Slice is out of range, at my_slice");
2991 }
2992 {
2993 // Size is below bounds, should fail.
2994 Reset();
2995 NodeDef node_def = get_slice_nodedef();
2996 AddTestTensor("input", {1, 2, 3});
2997 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
2998 AddTestWeights<int32>("size", {4}, {1, 1, 2, -2});
2999 RunValidationAndConversion(
3000 node_def, error::INVALID_ARGUMENT,
3001 "\"begin\" + \"size\" for dimension 3 in Slice is out of range, at "
3002 "my_slice");
3003 }
3004 {
3005 // Size is above bounds, should fail.
3006 Reset();
3007 NodeDef node_def = get_slice_nodedef();
3008 AddTestTensor("input", {1, 2, 3});
3009 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
3010 AddTestWeights<int32>("size", {4}, {1, 1, 3, 3});
3011 RunValidationAndConversion(
3012 node_def, error::INVALID_ARGUMENT,
3013 "\"begin\" + \"size\" for dimension 2 in Slice is out of range, at "
3014 "my_slice");
3015 }
3016 {
3017 // Modify batch dim, should fail.
3018 Reset();
3019 NodeDef node_def = get_slice_nodedef();
3020 AddTestTensor("input", {1, 2, 3});
3021 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
3022 AddTestWeights<int32>("size", {4}, {0, 1, 2, 3});
3023 RunValidationAndConversion(
3024 node_def, error::UNIMPLEMENTED,
3025 "TensorRT does not allow modifications to the batch dimension, at "
3026 "my_slice");
3027 }
3028 {
3029 // Dynamic batch size with size[0] not -1, should fail.
3030 Reset();
3031 NodeDef node_def = get_slice_nodedef();
3032 AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1);
3033 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
3034 AddTestWeights<int32>("size", {4}, {1, 1, 2, 3});
3035 RunValidationAndConversion(
3036 node_def, error::UNIMPLEMENTED,
3037 "TensorRT does not allow modifications to the batch dimension, at "
3038 "my_slice");
3039 }
3040 {
3041 // Dynamic batch size but using size[0] of -1, ok.
3042 Reset();
3043 NodeDef node_def = get_slice_nodedef();
3044 AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1);
3045 AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
3046 AddTestWeights<int32>("size", {4}, {-1, 1, 2, 2});
3047 RunValidationAndConversion(node_def);
3048 }
3049
3050 struct TestParams {
3051 std::vector<int> input_dims;
3052 std::vector<int> begin;
3053 std::vector<int> size;
3054 std::vector<int> expected_output_dims;
3055 std::vector<int> expected_output;
3056 };
3057
3058 // Ok.
3059 const int kSliceOKCases = 5;
3060 TestParams ok_params[kSliceOKCases] = {
3061 TestParams{{1, 2, 3},
3062 {0, 0, 0, 0},
3063 {-1, -1, -1, -1},
3064 {1, 2, 3},
3065 {1, 2, 3, 4, 5, 6}},
3066 TestParams{
3067 {1, 2, 3}, {0, 0, 0, 0}, {1, 1, 2, 3}, {1, 2, 3}, {1, 2, 3, 4, 5, 6}},
3068 TestParams{
3069 {1, 2, 3}, {0, 0, 0, 0}, {1, -1, 2, 2}, {1, 2, 2}, {1, 2, 4, 5}},
3070 TestParams{{6}, {0, 1}, {1, 5}, {5}, {2, 3, 4, 5, 6}},
3071 TestParams{{6}, {0, 1}, {-1, 3}, {3}, {2, 3, 4}},
3072 };
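// Slice semantics recap (illustrative): a size of -1 means "to the end of
// that dimension". E.g. in the third case above, begin {0, 0, 0, 0} with size
// {1, -1, 2, 2} keeps the whole second dimension and the first two elements
// of each of the last two dimensions, giving output dims {1, 2, 2} and values
// {1, 2, 4, 5}.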
3073
3074 for (int i = 0; i < kSliceOKCases; i++) {
3075 Reset();
3076 NodeDef node_def = get_slice_nodedef();
3077 AddTestTensor("input", ok_params[i].input_dims);
3078 AddTestWeights<int32>("begin",
3079 {static_cast<int>(ok_params[i].begin.size())},
3080 ok_params[i].begin);
3081 AddTestWeights<int32>("size", {static_cast<int>(ok_params[i].size.size())},
3082 ok_params[i].size);
3083 RunValidationAndConversion(node_def);
3084
3085 TRT_TensorOrWeights output;
3086 TF_EXPECT_OK(GetTensorOrWeights("my_slice", &output));
3087 EXPECT_TRUE(output.is_tensor());
3088 ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
3089 output.tensor()->getDimensions());
3090
3091 const DataVec input_data{
3092 {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
3093 DataVec output_data{{"my_slice", ConstructTensor<float>(
3094 ok_params[i].expected_output.size())}};
3095 BuildAndRun(input_data, &output_data);
3096 EXPECT_THAT(GetSpanForData<float>(output_data[0]),
3097 ElementsAreArray(ok_params[i].expected_output));
3098 }
3099 }
3100
3101 TEST_F(OpConverterTest, ConvertConv2D) {
3102 {
3103 // Input list is empty, should fail.
3104 NodeDef node_def = MakeNodeDef("my_conv2d", "Conv2D", {});
3105 RunValidationAndConversion(
3106 node_def, error::INVALID_ARGUMENT,
3107 "Conv2D got 0 inputs but expected 2, at my_conv2d");
3108 }
3109
3110 // Get nodedef for Conv2D layer.
3111 auto get_conv2d_nodedef =
3112 [](std::vector<int> strides = {1, 1, 1, 1}, string padding = "SAME",
3113 string data_format = "NCHW", std::vector<int> dilations = {1, 1, 1, 1},
3114 bool is_conv2d_backprop_input = false) -> NodeDef {
3115 Scope s = Scope::NewRootScope();
3116 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
3117 auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT);
3118 if (is_conv2d_backprop_input) {
3119 auto input_sizes =
3120 ops::Placeholder(s.WithOpName("input_sizes"), DT_INT32);
3121 ops::Conv2DBackpropInput::Attrs attrs = ops::Conv2DBackpropInput::Attrs()
3122 .DataFormat(data_format)
3123 .Dilations(dilations);
3124 auto conv2d =
3125 ops::Conv2DBackpropInput(s.WithOpName("my_conv2d"), input_sizes,
3126 filter, input, strides, padding, attrs);
3127 return conv2d.operation.node()->def();
3128 } else {
3129 ops::Conv2D::Attrs attrs =
3130 ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations);
3131 auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter,
3132 strides, padding, attrs);
3133 return conv2d.operation.node()->def();
3134 }
3135 };
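// Note: depending on is_conv2d_backprop_input, the helper above returns a
// NodeDef for either Conv2D or Conv2DBackpropInput (conv2d_transpose); the
// latter takes an extra "input_sizes" input describing the shape of the
// backprop output.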
3136
3137 {
3138 // Input is weights, should fail.
3139 Reset();
3140 NodeDef node_def = get_conv2d_nodedef();
3141 AddTestWeights<float>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
3142 AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3143 RunValidationAndConversion(
3144 node_def, error::UNIMPLEMENTED,
3145 "The input \"input\" for Conv2D must be a tensor, at my_conv2d");
3146 }
3147 {
3148 // Filter is tensor, should fail.
3149 Reset();
3150 NodeDef node_def = get_conv2d_nodedef();
3151 AddTestTensor("input", {1, 2, 3});
3152 AddTestTensor("weights", {3, 3, 1, 1});
3153 RunValidationAndConversion(
3154 node_def, error::UNIMPLEMENTED,
3155 "The input \"filter\" for Conv2D must be a constant, at my_conv2d");
3156 }
3157 {
3158 // Filter is not 4D, should fail.
3159 Reset();
3160 NodeDef node_def = get_conv2d_nodedef();
3161 AddTestTensor("input", {1, 2, 3});
3162 AddTestWeights<float>("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3163 RunValidationAndConversion(
3164 node_def, error::INVALID_ARGUMENT,
3165 "Conv2D expects kernel of dimension 4, at my_conv2d");
3166 }
3167 {
3168 // Dilations is not 4D, should fail.
3169 Reset();
3170 NodeDef node_def =
3171 get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1});
3172 AddTestTensor("input", {1, 2, 3});
3173 AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3174 RunValidationAndConversion(
3175 node_def, error::INVALID_ARGUMENT,
3176 "Convolution dilations field must specify 4 dimensions, at my_conv2d");
3177 }
3178 {
3179 // Dilation value is not 1 for channel, should fail.
3180 Reset();
3181 NodeDef node_def =
3182 get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1});
3183 AddTestTensor("input", {1, 2, 3});
3184 AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3185 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
3186 "Dilation rate must be 1 for batch and channel "
3187 "dimensions, at my_conv2d");
3188 }
3189 {
3190 // Dilation value is not 1 for channel (NHWC), should fail.
3191 Reset();
3192 NodeDef node_def =
3193 get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2});
3194 AddTestTensor("input", {2, 3, 1});
3195 AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3196 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
3197 "Dilation rate must be 1 for batch and channel "
3198 "dimensions, at my_conv2d");
3199 }
3200 {
3201 // Dilation + Conv2DBackpropInput, should fail.
3202 Reset();
3203 NodeDef node_def =
3204 get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 2, 1}, true);
3205 AddTestTensor("input", {2, 3, 1});
3206 AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3207 AddTestWeights<int>("input_sizes", {4}, {1, 2, 3, 1});
3208 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
3209 "Dilation with Conv2DBackpropInput "
3210 "(conv2d_transpose) is not supported, "
3211 "at my_conv2d");
3212 }
3213 {
3214 // Strides is not 4D, should fail.
3215 Reset();
3216 NodeDef node_def =
3217 get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
3218 AddTestTensor("input", {1, 2, 3});
3219 AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3220 RunValidationAndConversion(
3221 node_def, error::INVALID_ARGUMENT,
3222 "Convolution strides field must specify 4 dimensions, at my_conv2d");
3223 }
3224 {
3225 // Stride value is not 1 for channel, should fail.
3226 Reset();
3227 NodeDef node_def =
3228 get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
3229 AddTestTensor("input", {1, 2, 3});
3230 AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
3231 RunValidationAndConversion(
3232 node_def, error::UNIMPLEMENTED,
3233 "Stride must be 1 for batch and channel dimensions, at my_conv2d");
3234 }
3235
3236 struct TestParams {
3237 std::vector<int> input_dims;
3238 std::vector<float> input;
3239 std::vector<int> filter_dims;
3240 std::vector<float> filter;
3241 std::vector<int> strides;
3242 string padding;
3243 string data_format;
3244 std::vector<int> dilations;
3245 bool is_conv2d_backprop_input;
3246 std::vector<int> expected_output_dims;
3247 std::vector<float> expected_output;
3248 };
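// Note: filter_dims follow TensorFlow's Conv2D filter layout
// [filter_height, filter_width, in_channels, out_channels].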
3249
3250 // Ok.
3251 const int kConv2DOKCases = 7;
3252 TestParams ok_params[kConv2DOKCases] = {
3253 // Basic
3254 TestParams{/*input_dims=*/{1, 2, 3},
3255 /*input=*/{0, 1, 2, 3, 3, 4},
3256 /*filter_dims=*/{1, 2, 1, 1},
3257 /*filter=*/{-1, 1},
3258 /*strides=*/{1, 1, 1, 1},
3259 /*padding=*/"VALID",
3260 /*data_format=*/"NCHW",
3261 /*dilations=*/{1, 1, 1, 1},
3262 /*is_conv2d_backprop_input=*/false,
3263 /*expected_output_dims=*/{1, 2, 2},
3264 /*expected_output=*/{1, 1, 0, 1}},
3265 // SAME padding (Asymmetric)
3266 TestParams{/*input_dims=*/{1, 2, 3},
3267 /*input=*/{0, 1, 2, 3, 3, 4},
3268 /*filter_dims=*/{1, 2, 1, 1},
3269 /*filter=*/{-1, 1},
3270 /*strides=*/{1, 1, 1, 1},
3271 /*padding=*/"SAME",
3272 /*data_format=*/"NCHW",
3273 /*dilations=*/{1, 1, 1, 1},
3274 /*is_conv2d_backprop_input=*/false,
3275 /*expected_output_dims=*/{1, 2, 3},
3276 /*expected_output=*/{1, 1, -2, 0, 1, -4}},
3277 // SAME padding (Symmetric)
3278 TestParams{/*input_dims=*/{1, 2, 3},
3279 /*input=*/{0, 1, 2, 3, 3, 4},
3280 /*filter_dims=*/{1, 3, 1, 1},
3281 /*filter=*/{-1, 0, 1},
3282 /*strides=*/{1, 1, 1, 1},
3283 /*padding=*/"SAME",
3284 /*data_format=*/"NCHW",
3285 /*dilations=*/{1, 1, 1, 1},
3286 /*is_conv2d_backprop_input=*/false,
3287 /*expected_output_dims=*/{1, 2, 3},
3288 /*expected_output=*/{1, 2, -1, 3, 1, -3}},
3289 // NHWC
3290 TestParams{/*input_dims=*/{2, 3, 1},
3291 /*input=*/{0, 1, 2, 3, 3, 4},
3292 /*filter_dims=*/{1, 2, 1, 1},
3293 /*filter=*/{-1, 1},
3294 /*strides=*/{1, 1, 1, 1},
3295 /*padding=*/"VALID",
3296 /*data_format=*/"NHWC",
3297 /*dilations=*/{1, 1, 1, 1},
3298 /*is_conv2d_backprop_input=*/false,
3299 /*expected_output_dims=*/{2, 2, 1},
3300 /*expected_output=*/{1, 1, 0, 1}},
3301 // Dilated
3302 TestParams{/*input_dims=*/{1, 2, 3},
3303 /*input=*/{0, 1, 2, 3, 3, 4},
3304 /*filter_dims=*/{1, 2, 1, 1},
3305 /*filter=*/{-1, 1},
3306 /*strides=*/{1, 1, 1, 1},
3307 /*padding=*/"VALID",
3308 /*data_format=*/"NCHW",
3309 /*dilations=*/{1, 1, 1, 2},
3310 /*is_conv2d_backprop_input=*/false,
3311 /*expected_output_dims=*/{1, 2, 1},
3312 /*expected_output=*/{2, 1}},
3313 // Strided
3314 TestParams{/*input_dims=*/{1, 2, 4},
3315 /*input=*/{0, 1, 2, 2, 3, 4, 4, 7},
3316 /*filter_dims=*/{1, 2, 1, 1},
3317 /*filter=*/{-1, 1},
3318 /*strides=*/{1, 1, 1, 2},
3319 /*padding=*/"VALID",
3320 /*data_format=*/"NCHW",
3321 /*dilations=*/{1, 1, 1, 1},
3322 /*is_conv2d_backprop_input=*/false,
3323 /*expected_output_dims=*/{1, 2, 2},
3324 /*expected_output=*/{1, 0, 1, 3}},
3325 // Transpose Strided
3326 TestParams{/*input_dims=*/{1, 2, 2},
3327 /*input=*/{0, 1, 2, 3},
3328 /*filter_dims=*/{1, 2, 1, 1},
3329 /*filter=*/{-1, 1},
3330 /*strides=*/{1, 1, 1, 2},
3331 /*padding=*/"SAME",
3332 /*data_format=*/"NCHW",
3333 /*dilations=*/{1, 1, 1, 1},
3334 /*is_conv2d_backprop_input=*/true,
3335 /*expected_output_dims=*/{1, 2, 4},
3336 /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}},
3337 };
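// Derivation of the "Basic" case above (illustrative): with NCHW input of
// shape {1, 2, 3} = [[0, 1, 2], [3, 3, 4]] and a 1x2 filter {-1, 1}, VALID
// correlation over each row computes x[w+1] - x[w], giving [1, 1] for the
// first row and [0, 1] for the second, i.e. expected_output {1, 1, 0, 1}.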
3338
3339 for (int i = 0; i < kConv2DOKCases; i++) {
3340 Reset();
3341 NodeDef node_def = get_conv2d_nodedef(
3342 ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format,
3343 ok_params[i].dilations, ok_params[i].is_conv2d_backprop_input);
3344 AddTestTensor("input", ok_params[i].input_dims);
3345 AddTestWeights<float>("weights", ok_params[i].filter_dims,
3346 ok_params[i].filter);
3347 if (ok_params[i].is_conv2d_backprop_input) {
3348 AddTestWeights<float>(
3349 "input_sizes",
3350 {static_cast<int>(ok_params[i].expected_output.size())},
3351 ok_params[i].expected_output);
3352 }
3353 RunValidationAndConversion(node_def);
3354 TRT_TensorOrWeights output;
3355 TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output));
3356 EXPECT_TRUE(output.is_tensor());
3357 ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
3358 output.tensor()->getDimensions());
3359
3360 const DataVec input_data{
3361 {"input", test::AsTensor<float>(ok_params[i].input)}};
3362 DataVec output_data{
3363 {"my_conv2d",
3364 ConstructTensor<float>(ok_params[i].expected_output.size())}};
3365 BuildAndRun(input_data, &output_data);
3366 EXPECT_THAT(GetSpanForData<float>(output_data[0]),
3367 ElementsAreArray(ok_params[i].expected_output));
3368 }
3369 }
3370
3371 TEST_F(OpConverterTest, ConvertTopK) {
3372 {
3373 // Input list is empty, should fail.
3374 NodeDef node_def = MakeNodeDef("my_topk", "TopKV2", {});
3375 RunValidationAndConversion(
3376 node_def, error::INVALID_ARGUMENT,
3377 "TopKV2 got 0 inputs but expected 2, at my_topk");
3378 }
3379
3380 for (const auto dtype : {DT_FLOAT, DT_INT32}) {
3381 // Get the NodeDef for TopKV2.
3382 Scope s = Scope::NewRootScope();
3383 auto input = ops::Placeholder(s.WithOpName("input"), dtype);
3384 auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
3385 auto topk = ops::TopK(s.WithOpName("my_topk"), input, weights);
3386 const NodeDef& node_def = topk.operation.node()->def();
3387 {
3388 // K is a tensor, should fail.
3389 Reset();
3390 AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1,
3391 /*trt_dtype=*/TfDataTypeToTrt(dtype));
3392 AddTestTensor("weights", {2});
3393 RunValidationAndConversion(
3394 node_def, error::UNIMPLEMENTED,
3395 "The input \"k\" for TopKV2 must be a constant, at my_topk");
3396 }
3397 {
3398 // Ok.
3399 Reset();
3400 AddTestTensor("input", {1, 2, 5});
3401 AddTestWeights<int32>("weights", {1}, {2});
3402 RunValidationAndConversion(node_def);
3403 TRT_TensorOrWeights outputs[2];
3404 TF_EXPECT_OK(GetTensorOrWeights("my_topk", &outputs[0]));
3405 TF_EXPECT_OK(GetTensorOrWeights("my_topk:1", &outputs[1]));
3406 for (auto& output : outputs) {
3407 EXPECT_TRUE(output.is_tensor());
3408 ExpectTrtDimsEqualsArray({1, 2, 2}, output.tensor()->getDimensions());
3409 }
3410
3411 const DataVec input_data{
3412 {"input", test::AsTensor<float>({-9, 3, 5, 1, 6, -5, 7, 1, 0, -1})}};
3413 DataVec output_data{{"my_topk", ConstructTensor<float>(4)},
3414 {"my_topk:1", ConstructTensor<int32>(4)}};
3415 BuildAndRun(input_data, &output_data);
3416 EXPECT_THAT(GetSpanForData<float>(output_data[0]),
3417 ElementsAre(6, 5, 7, 1));
3418 EXPECT_THAT(GetSpanForData<int32>(output_data[1]),
3419 ElementsAre(4, 2, 1, 2));
3420 }
3421 }
3422 }
3423
3424 template <DataType dtype>
3425 void TestConvertGather(OpConverterTest* test) {
3426 typedef typename EnumToDataType<dtype>::Type CType;
3427
3428 // Get the NodeDef for GatherV2.
3429 Scope s = Scope::NewRootScope();
3430 auto params = ops::Placeholder(s.WithOpName("params"), dtype);
3431 auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
3432 auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
3433 auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis);
3434 const NodeDef& node_def = gather.operation.node()->def();
3435
3436 struct TestParams {
3437 std::vector<int> params_dims;
3438 std::vector<int> indices_dims;
3439 std::vector<int> indices;
3440 int axis;
3441 std::vector<int> expected_output_dims;
3442 std::vector<int> expected_output;
3443 };
3444
3445 // Input is the same {1, 2, 3, 4, 5, 6} for all cases.
3446 const int kGatherOKCases = 5;
3447 const std::vector<CType> params_input = {CType(1), CType(2), CType(3),
3448 CType(4), CType(5), CType(6)};
3449 TestParams ok_params[kGatherOKCases] = {
3450 // Indices always have rank > 1 here (counting the implicit batch dimension),
3451 // and the output rank is rank(params) + rank(indices) - 1.
3452 // TODO(laigd): do we support 0-rank ITensor as indices?
3453 TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1, 1}, {1, 4}},
3454 TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1, 1}, {2, 5}},
3455 TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1, 1}, {3, 6}},
3456 TestParams{
3457 {1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 1, 3}, {3, 1, 2, 6, 4, 5}},
3458 TestParams{{3, 2},
3459 {2, 2},
3460 {0, 0, 1, 0},
3461 2,
3462 {3, 1, 2, 2},
3463 {1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}},
3464 };
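// Example derivation (illustrative): in the fourth case above, params is the
// 1x2x3 tensor [[1, 2, 3], [4, 5, 6]] per batch, and indices {2, 0, 1} gather
// along the last axis, so each row becomes {3, 1, 2} and {6, 4, 5},
// yielding expected_output {3, 1, 2, 6, 4, 5}.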
3465
3466 // Ok.
3467 for (int i = 0; i < kGatherOKCases; i++) {
3468 test->Reset();
3469 test->AddTestTensor("params", ok_params[i].params_dims, 1,
3470 TfDataTypeToTrt(dtype));
3471 test->AddTestTensor("indices", ok_params[i].indices_dims, 1,
3472 nvinfer1::DataType::kINT32);
3473 test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
3474 test->RunValidationAndConversion(node_def);
3475 TRT_TensorOrWeights output;
3476 TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output));
3477 EXPECT_TRUE(output.is_tensor());
3478 ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
3479 output.tensor()->getDimensions());
3480
3481 // Create input in CType and convert expected output to CType.
3482 std::vector<CType> converted_expected_output(
3483 ok_params[i].expected_output.begin(),
3484 ok_params[i].expected_output.end());
3485
3486 const DataVec input_data{
3487 {"params", test::AsTensor<CType>(params_input)},
3488 {"indices", test::AsTensor<int32>(ok_params[i].indices)}};
3489 DataVec output_data{
3490 {"my_gather",
3491 ConstructTensor<CType>(ok_params[i].expected_output.size())}};
3492 test->BuildAndRun(input_data, &output_data);
3493 EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
3494 ElementsAreArray(converted_expected_output));
3495 }
3496 }
3497
3498 TEST_F(OpConverterTest, ConvertGather) {
3499 {
3500 // Input list is empty, should fail.
3501 NodeDef node_def = MakeNodeDef("my_gather", "GatherV2", {});
3502 RunValidationAndConversion(
3503 node_def, error::INVALID_ARGUMENT,
3504 "GatherV2 got 0 inputs but expected 3, at my_gather");
3505 }
3506
3507 // Get the NodeDef for GatherV2.
3508 Scope s = Scope::NewRootScope();
3509 auto params = ops::Placeholder(s.WithOpName("params"), DT_FLOAT);
3510 auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
3511 auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
3512 auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis);
3513 const NodeDef& node_def = gather.operation.node()->def();
3514 {
3515 // Axis is a tensor, should fail.
3516 Reset();
3517 AddTestTensor("params", {1, 2, 3});
3518 AddTestTensor("indices", {2});
3519 AddTestTensor("axis", {1});
3520 RunValidationAndConversion(
3521 node_def, error::UNIMPLEMENTED,
3522 "The input \"axis\" for GatherV2 must be a constant, at my_gather");
3523 }
3524 {
3525 // Axis is out of bounds, should fail.
3526 Reset();
3527 AddTestTensor("params", {1, 2, 3});
3528 AddTestTensor("indices", {2});
3529 AddTestWeights<int32>("axis", {1}, {4});
3530 RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
3531 "Axis value of 4 is out of bounds, must be in "
3532 "range [-4, 4), at my_gather");
3533 }
3534 {
3535 // Axis is batch dimension, should fail.
3536 Reset();
3537 AddTestTensor("params", {1, 2, 3});
3538 AddTestTensor("indices", {2});
3539 AddTestWeights<int32>("axis", {1}, {0});
3540 RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
3541 "TensorRT does not allow manipulation of the "
3542 "batch dimension, at my_gather");
3543 }
3544
3545 Reset();
3546 TestConvertGather<DT_FLOAT>(this);
3547 TestConvertGather<DT_HALF>(this);
3548 TestConvertGather<DT_INT32>(this);
3549 }
3550
3551 TEST_F(OpConverterTest, ConvertUnary) {
3552 {
3553 // Input list is empty, should fail.
3554 NodeDef node_def = MakeNodeDef("my_unary", "Neg", {});
3555 RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
3556 "Neg got 0 inputs but expected 1, at my_unary");
3557 }
3558 {
3559 // Input is weights, should fail.
3560 Reset();
3561 Scope s = Scope::NewRootScope();
3562 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
3563 auto neg = ops::Neg(s.WithOpName("my_unary"), input);
3564 const NodeDef& node_def = neg.operation.node()->def();
3565 AddTestWeights<float>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
3566 RunValidationAndConversion(
3567 node_def, error::UNIMPLEMENTED,
3568 "The input \"x\" for Neg must be a tensor, at my_unary");
3569 }
3570
3571 // Get nodedef for unary layer.
3572 auto get_unary_nodedef = [](string op_name) -> NodeDef {
3573 Scope s = Scope::NewRootScope();
3574 auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
3575 if (op_name == "Abs") {
3576 auto unary = ops::Abs(s.WithOpName("my_unary"), input);
3577 return unary.operation.node()->def();
3578 } else if (op_name == "Acos") {
3579 auto unary = ops::Acos(s.WithOpName("my_unary"), input);
3580 return unary.operation.node()->def();
3581 } else if (op_name == "Acosh") {
3582 auto unary = ops::Acosh(s.WithOpName("my_unary"), input);
3583 return unary.operation.node()->def();
3584 } else if (op_name == "Asin") {
3585 auto unary = ops::Asin(s.WithOpName("my_unary"), input);
3586 return unary.operation.node()->def();
3587 } else if (op_name == "Asinh") {
3588 auto unary = ops::Asinh(s.WithOpName("my_unary"), input);
3589 return unary.operation.node()->def();
3590 } else if (op_name == "Atan") {
3591 auto unary = ops::Atan(s.WithOpName("my_unary"), input);
3592 return unary.operation.node()->def();
3593 } else if (op_name == "Atanh") {
3594 auto unary = ops::Atanh(s.WithOpName("my_unary"), input);
3595 return unary.operation.node()->def();
3596 } else if (op_name == "Ceil") {
3597 auto unary = ops::Ceil(s.WithOpName("my_unary"), input);
3598 return unary.operation.node()->def();
3599 } else if (op_name == "Cos") {
3600 auto unary = ops::Cos(s.WithOpName("my_unary"), input);
3601 return unary.operation.node()->def();
3602 } else if (op_name == "Cosh") {
3603 auto unary = ops::Cosh(s.WithOpName("my_unary"), input);
3604 return unary.operation.node()->def();
3605 } else if (op_name == "Exp") {
3606 auto unary = ops::Exp(s.WithOpName("my_unary"), input);
3607 return unary.operation.node()->def();
3608 } else if (op_name == "Floor") {
3609 auto unary = ops::Floor(s.WithOpName("my_unary"), input);
3610 return unary.operation.node()->def();
3611 } else if (op_name == "Log") {
3612 auto unary = ops::Log(s.WithOpName("my_unary"), input);
3613 return unary.operation.node()->def();
3614 } else if (op_name == "Neg") {
3615 auto unary = ops::Neg(s.WithOpName("my_unary"), input);
3616 return unary.operation.node()->def();
3617 } else if (op_name == "Reciprocal") {
3618 auto unary = ops::Reciprocal(s.WithOpName("my_unary"), input);
3619 return unary.operation.node()->def();
3620 } else if (op_name == "Rsqrt") {
3621 auto unary = ops::Rsqrt(s.WithOpName("my_unary"), input);
3622 return unary.operation.node()->def();
3623 } else if (op_name == "Sin") {
3624 auto unary = ops::Sin(s.WithOpName("my_unary"), input);
3625 return unary.operation.node()->def();
3626 } else if (op_name == "Sinh") {
3627 auto unary = ops::Sinh(s.WithOpName("my_unary"), input);
3628 return unary.operation.node()->def();
3629 } else if (op_name == "Sqrt") {
3630 auto unary = ops::Sqrt(s.WithOpName("my_unary"), input);
3631 return unary.operation.node()->def();
3632 } else if (op_name == "Tan") {
3633 auto unary = ops::Tan(s.WithOpName("my_unary"), input);
3634 return unary.operation.node()->def();
3635 }
3636 EXPECT_TRUE(false);
3637 return NodeDef();
3638 };
3639 // Get expected output for unary layer.
3640 auto get_unary_output = [](string op_name, float input) -> float {
3641 if (op_name == "Abs") {
3642 return std::abs(input);
3643 } else if (op_name == "Acos") {
3644 return std::acos(input);
3645 } else if (op_name == "Acosh") {
3646 return std::acosh(input);
3647 } else if (op_name == "Asin") {
3648 return std::asin(input);
3649 } else if (op_name == "Asinh") {
3650 return std::asinh(input);
3651 } else if (op_name == "Atan") {
3652 return std::atan(input);
3653 } else if (op_name == "Atanh") {
3654 return std::atanh(input);
3655 } else if (op_name == "Ceil") {
3656 return std::ceil(input);
3657 } else if (op_name == "Cos") {
3658 return std::cos(input);
3659 } else if (op_name == "Cosh") {
3660 return std::cosh(input);
3661 } else if (op_name == "Exp") {
3662 return std::exp(input);
3663 } else if (op_name == "Floor") {
3664 return std::floor(input);
3665 } else if (op_name == "Log") {
3666 return std::log(input);
3667 } else if (op_name == "Neg") {
3668 return -input;
3669 } else if (op_name == "Reciprocal") {
3670 return 1.0 / input;
3671 } else if (op_name == "Rsqrt") {
3672 return 1.0 / std::sqrt(input);
3673 } else if (op_name == "Sin") {
3674 return std::sin(input);
3675 } else if (op_name == "Sinh") {
3676 return std::sinh(input);
3677 } else if (op_name == "Sqrt") {
3678 return std::sqrt(input);
3679 } else if (op_name == "Tan") {
3680 return std::tan(input);
3681 }
3682 EXPECT_TRUE(false);
3683 return 0;
3684 };
3685
3686 // Get list of ops to test.
3687 std::vector<string> ops_to_test;
3688 // Add all ops supported by ConvertUnary.
3689 auto* map = UnaryOperationMap();
3690 ops_to_test.reserve(map->size());
3691 for (auto& pair : *map) {
3692 ops_to_test.push_back(pair.first);
3693 }
3694 // Add other unary ops to test.
3695 ops_to_test.push_back("Rsqrt");
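// (Rsqrt is listed separately because it is not part of UnaryOperationMap();
// presumably it is converted through a dedicated path.)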
3696 // Ok.
3697 for (string op_name : ops_to_test) {
3698 Reset();
3699 NodeDef node_def = get_unary_nodedef(op_name);
3700 AddTestTensor("input", {1, 2, 3});
3701 RunValidationAndConversion(node_def);
3702 TRT_TensorOrWeights output;
3703 TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output));
3704 EXPECT_TRUE(output.is_tensor());
3705 ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions());
3706
3707 const std::vector<float> input = {-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f};
3708 const DataVec input_data{{"input", test::AsTensor<float>(input)}};
3709 DataVec output_data{{"my_unary", ConstructTensor<float>(6)}};
3710 BuildAndRun(input_data, &output_data);
3711 for (int i = 0; i < input.size(); ++i) {
3712 const float expected_output = get_unary_output(op_name, input[i]);
3713 EXPECT_THAT(GetSpanForData<float>(output_data[0])[i],
3714 NanSensitiveFloatNear(expected_output, 0.0001));
3715 }
3716 }
3717 }
3718
3719 } // namespace convert
3720 } // namespace tensorrt
3721 } // namespace tensorflow
3722
3723 #endif // GOOGLE_TENSORRT
3724 #endif // GOOGLE_CUDA
3725