1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
17 
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
19 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
20 
21 namespace mlir {
22 namespace TFL {
23 
24 namespace {
25 
// TODO(b/162842801): Consolidate all util definitions of kTFImplements.
// Function attribute used to tag a function as implementing a known
// composite op that TFLite can fuse.
constexpr char kTFImplements[] = "tf._implements";
// Name of the TFLite custom op that performs SSD detection post-processing.
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
// Composite-function name for the padded non-max-suppression variant.
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
30 
CustomOption(OpBuilder * builder,const std::string & content)31 inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
32                                        const std::string& content) {
33   ShapedType type = RankedTensorType::get(
34       {static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
35   return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
36                                  type,
37                                  StringRef(content.data(), content.size()));
38 }
39 
40 }  // namespace
41 
RewriteFunc()42 void ConvertNMSPaddedFunc::RewriteFunc() {
43   func_->setAttr(kTFImplements,
44                  StringAttr::get(func_.getContext(), kTfNMSPadded));
45   Value boxes = func_.getArgument(0);
46   Value scores = func_.getArgument(1);
47   Value max_output_size = func_.getArgument(2);
48   Value iou_threshold = func_.getArgument(3);
49   Value score_threshold = func_.getArgument(4);
50   auto output_type0 = func_.getType().getResult(0);
51   auto output_type1 = func_.getType().getResult(1);
52 
53   OpBuilder builder(func_.getBody());
54   auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
55       func_.getLoc(), output_type0, output_type1, boxes, scores,
56       max_output_size, iou_threshold, score_threshold);
57 
58   builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
59 }
60 
VerifySignature()61 LogicalResult ConvertNMSPaddedFunc::VerifySignature() {
62   // Verify high-level function signature.
63   // Relevant argument characteristics are checked by the TFL op definition.
64   if (func_.getNumArguments() < 5) {
65     return func_.emitError()
66            << "Invalid number of arguments to "
67               "non_max_suppression_padded_v2 (need at least 5): "
68            << func_.getNumArguments();
69   }
70   if (func_.getType().getNumResults() != 2) {
71     return func_.emitError() << "Invalid number of results from "
72                                 "non_max_suppression_padded_v2 (need 2): "
73                              << func_.getType().getNumResults();
74   }
75   // The TFLite fused op does not support batching yet.
76   // TODO(b/158709815): Add support for batches with padded NMS.
77   auto boxes_type = func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
78   if (!boxes_type.hasRank() || boxes_type.getRank() != 2) {
79     return func_.emitError() << "TFLite does not support batched input for "
80                                 "non_max_suppression_padded";
81   }
82   return success();
83 }
84 
RewriteFunc()85 LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
86   func_.eraseBody();
87   func_.addEntryBlock();
88   func_->setAttr(kTFImplements,
89                  StringAttr::get(func_.getContext(), kCustomSSDPostprocessing));
90 
91   OpBuilder builder(func_.getBody());
92   std::string custom_option_buffer;
93   if (failed(CreateNMSCustomOptions(func_, attr_.GetAttrs(),
94                                     custom_option_buffer))) {
95     return failure();
96   }
97   auto op = builder.create<CustomOp>(
98       func_.getLoc(), func_.getType().getResults(), func_.getArguments(),
99       kCustomSSDPostprocessing, CustomOption(&builder, custom_option_buffer));
100   builder.create<ReturnOp>(func_.getLoc(), op.getResults());
101 
102   return success();
103 }
104 
CreateNMSCustomOptions(FuncOp func,DictionaryAttr attrs,std::string & custom_option_buffer)105 LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions(
106     FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
107   flexbuffers::Builder fbb;
108   size_t start_map = fbb.StartMap();
109 
110   if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) ||
111       failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) ||
112       failed(AddIntAttr(func, attrs, "num_classes", &fbb)) ||
113       failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) ||
114       failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) ||
115       failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) ||
116       failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) ||
117       failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) ||
118       failed(AddFloatAttr(func, attrs, "w_scale", &fbb)))
119     return failure();
120   auto use_regular_nms =
121       attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>();
122   if (!use_regular_nms) {
123     return func.emitError()
124            << "use_regular_nms attribute is not set or not a bool";
125   }
126   fbb.Int("use_regular_nms", use_regular_nms.getValue());
127 
128   fbb.EndMap(start_map);
129   fbb.Finish();
130   custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
131   return success();
132 }
133 
AddIntAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)134 LogicalResult ConvertSSDPostProcessFunc::AddIntAttr(
135     FuncOp func, DictionaryAttr attrs, const std::string& attribute,
136     flexbuffers::Builder* builder) {
137   auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
138   if (!int_attr) {
139     return func.emitError()
140            << attribute.c_str() << " attribute is not set or not an integer";
141   }
142   builder->Int(attribute.c_str(), int_attr.getInt());
143   return success();
144 }
145 
AddFloatAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)146 LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr(
147     FuncOp func, DictionaryAttr attrs, const std::string& attribute,
148     flexbuffers::Builder* builder) {
149   auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
150   if (!float_attr) {
151     return func.emitError()
152            << attribute.c_str() << " attribute is not set or not a float";
153   }
154   builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat());
155   return success();
156 }
157 
VerifySignature()158 LogicalResult ConvertSSDPostProcessFunc::VerifySignature() {
159   // Verify high-level function signature.
160   if (func_.getNumArguments() != 3) {
161     return func_.emitError()
162            << "Invalid number of arguments to " << kCustomSSDPostprocessing
163            << ": " << func_.getNumArguments();
164   }
165   if (func_.getType().getNumResults() != 4) {
166     return func_.emitError()
167            << "Invalid number of results from " << kCustomSSDPostprocessing
168            << ": " << func_.getType().getNumResults();
169   }
170   return success();
171 }
172 
173 }  // namespace TFL
174 }  // namespace mlir
175