1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
17
18 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
19 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
20
21 namespace mlir {
22 namespace TFL {
23
24 namespace {
25
26 // TODO(b/162842801): Consolidate all util definitions of kTFImplements.
27 constexpr char kTFImplements[] = "tf._implements";
28 constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
29 constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
30
CustomOption(OpBuilder * builder,const std::string & content)31 inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
32 const std::string& content) {
33 ShapedType type = RankedTensorType::get(
34 {static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
35 return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
36 type,
37 StringRef(content.data(), content.size()));
38 }
39
40 } // namespace
41
RewriteFunc()42 void ConvertNMSPaddedFunc::RewriteFunc() {
43 func_->setAttr(kTFImplements,
44 StringAttr::get(func_.getContext(), kTfNMSPadded));
45 Value boxes = func_.getArgument(0);
46 Value scores = func_.getArgument(1);
47 Value max_output_size = func_.getArgument(2);
48 Value iou_threshold = func_.getArgument(3);
49 Value score_threshold = func_.getArgument(4);
50 auto output_type0 = func_.getType().getResult(0);
51 auto output_type1 = func_.getType().getResult(1);
52
53 OpBuilder builder(func_.getBody());
54 auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
55 func_.getLoc(), output_type0, output_type1, boxes, scores,
56 max_output_size, iou_threshold, score_threshold);
57
58 builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
59 }
60
VerifySignature()61 LogicalResult ConvertNMSPaddedFunc::VerifySignature() {
62 // Verify high-level function signature.
63 // Relevant argument characteristics are checked by the TFL op definition.
64 if (func_.getNumArguments() < 5) {
65 return func_.emitError()
66 << "Invalid number of arguments to "
67 "non_max_suppression_padded_v2 (need at least 5): "
68 << func_.getNumArguments();
69 }
70 if (func_.getType().getNumResults() != 2) {
71 return func_.emitError() << "Invalid number of results from "
72 "non_max_suppression_padded_v2 (need 2): "
73 << func_.getType().getNumResults();
74 }
75 // The TFLite fused op does not support batching yet.
76 // TODO(b/158709815): Add support for batches with padded NMS.
77 auto boxes_type = func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
78 if (!boxes_type.hasRank() || boxes_type.getRank() != 2) {
79 return func_.emitError() << "TFLite does not support batched input for "
80 "non_max_suppression_padded";
81 }
82 return success();
83 }
84
RewriteFunc()85 LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
86 func_.eraseBody();
87 func_.addEntryBlock();
88 func_->setAttr(kTFImplements,
89 StringAttr::get(func_.getContext(), kCustomSSDPostprocessing));
90
91 OpBuilder builder(func_.getBody());
92 std::string custom_option_buffer;
93 if (failed(CreateNMSCustomOptions(func_, attr_.GetAttrs(),
94 custom_option_buffer))) {
95 return failure();
96 }
97 auto op = builder.create<CustomOp>(
98 func_.getLoc(), func_.getType().getResults(), func_.getArguments(),
99 kCustomSSDPostprocessing, CustomOption(&builder, custom_option_buffer));
100 builder.create<ReturnOp>(func_.getLoc(), op.getResults());
101
102 return success();
103 }
104
CreateNMSCustomOptions(FuncOp func,DictionaryAttr attrs,std::string & custom_option_buffer)105 LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions(
106 FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
107 flexbuffers::Builder fbb;
108 size_t start_map = fbb.StartMap();
109
110 if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) ||
111 failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) ||
112 failed(AddIntAttr(func, attrs, "num_classes", &fbb)) ||
113 failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) ||
114 failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) ||
115 failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) ||
116 failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) ||
117 failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) ||
118 failed(AddFloatAttr(func, attrs, "w_scale", &fbb)))
119 return failure();
120 auto use_regular_nms =
121 attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>();
122 if (!use_regular_nms) {
123 return func.emitError()
124 << "use_regular_nms attribute is not set or not a bool";
125 }
126 fbb.Int("use_regular_nms", use_regular_nms.getValue());
127
128 fbb.EndMap(start_map);
129 fbb.Finish();
130 custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
131 return success();
132 }
133
AddIntAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)134 LogicalResult ConvertSSDPostProcessFunc::AddIntAttr(
135 FuncOp func, DictionaryAttr attrs, const std::string& attribute,
136 flexbuffers::Builder* builder) {
137 auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
138 if (!int_attr) {
139 return func.emitError()
140 << attribute.c_str() << " attribute is not set or not an integer";
141 }
142 builder->Int(attribute.c_str(), int_attr.getInt());
143 return success();
144 }
145
AddFloatAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)146 LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr(
147 FuncOp func, DictionaryAttr attrs, const std::string& attribute,
148 flexbuffers::Builder* builder) {
149 auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
150 if (!float_attr) {
151 return func.emitError()
152 << attribute.c_str() << " attribute is not set or not a float";
153 }
154 builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat());
155 return success();
156 }
157
VerifySignature()158 LogicalResult ConvertSSDPostProcessFunc::VerifySignature() {
159 // Verify high-level function signature.
160 if (func_.getNumArguments() != 3) {
161 return func_.emitError()
162 << "Invalid number of arguments to " << kCustomSSDPostprocessing
163 << ": " << func_.getNumArguments();
164 }
165 if (func_.getType().getNumResults() != 4) {
166 return func_.emitError()
167 << "Invalid number of results from " << kCustomSSDPostprocessing
168 << ": " << func_.getType().getNumResults();
169 }
170 return success();
171 }
172
173 } // namespace TFL
174 } // namespace mlir
175