1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CpuOperationUtils.h"
18 #include "OperationResolver.h"
19 #include "OperationsUtils.h"
20 
21 #include <cfloat>
22 #include <cmath>
23 #include <numeric>
24 
25 #include "Tracing.h"
26 
27 namespace android {
28 namespace nn {
29 namespace bbox_ops {
30 
31 namespace {
32 
// Axis-aligned box stored as two opposite corners, (x1, y1) and (x2, y2).
struct BoxEncodingCorner {
    float x1, y1, x2, y2;
};

// Axis-aligned box stored as its extent (w, h) and its center point (x, y).
struct BoxEncodingCenter {
    float w, h, x, y;
};

// Converts a center-encoded box to corner encoding: each corner sits half of
// the box extent away from the center along both axes.
BoxEncodingCorner toBoxEncodingCorner(const BoxEncodingCenter& ctr) {
    BoxEncodingCorner corner;
    corner.x1 = ctr.x - ctr.w / 2;
    corner.y1 = ctr.y - ctr.h / 2;
    corner.x2 = ctr.x + ctr.w / 2;
    corner.y2 = ctr.y + ctr.h / 2;
    return corner;
}

// Converts a corner-encoded box to center encoding: the extent is the
// difference of the corners and the center is their midpoint.
BoxEncodingCenter toBoxEncodingCenter(const BoxEncodingCorner& cnr) {
    BoxEncodingCenter center;
    center.w = cnr.x2 - cnr.x1;
    center.h = cnr.y2 - cnr.y1;
    center.x = (cnr.x1 + cnr.x2) / 2;
    center.y = (cnr.y1 + cnr.y2) / 2;
    return center;
}
51 
bboxTransformFloat32(const float * roiData,const Shape & roiShape,const float * bboxDeltasData,const Shape & bboxDeltasShape,const int32_t * batchesData,const Shape & batchesShape,const float * imageInfoData,const Shape & imageInfoDataShape,float * outputData,const Shape & outputShape)52 inline bool bboxTransformFloat32(const float* roiData, const Shape& roiShape,
53                                  const float* bboxDeltasData, const Shape& bboxDeltasShape,
54                                  const int32_t* batchesData, const Shape& batchesShape,
55                                  const float* imageInfoData, const Shape& imageInfoDataShape,
56                                  float* outputData, const Shape& outputShape) {
57     const uint32_t roiLength = 4;
58     const uint32_t imageLength = 2;
59 
60     uint32_t numClasses = getSizeOfDimension(bboxDeltasShape, 1) / roiLength;
61     uint32_t numBatches = getSizeOfDimension(imageInfoDataShape, 0);
62 
63     const float* roiDataEnd = roiData + getNumberOfElements(roiShape);
64     const float* deltas = bboxDeltasData;
65     float* outPtr = outputData;
66     uint32_t roiIndex = 0;
67     for (const float* roiBase = roiData; roiBase < roiDataEnd; roiBase += roiLength, roiIndex++) {
68         uint32_t batchIndex = batchesData[roiIndex];
69         // Check for malformed data
70         // 1. Invalid batch id
71         // 2. Invalid region: x2 < x1 || y2 < y1
72         NN_RET_CHECK_GE(batchIndex, 0);
73         NN_RET_CHECK_LT(batchIndex, numBatches);
74         NN_RET_CHECK_LE(roiBase[0], roiBase[2]);
75         NN_RET_CHECK_LE(roiBase[1], roiBase[3]);
76 
77         const float* imageInfoBase = imageInfoData + batchIndex * imageLength;
78         float imageHeight = imageInfoBase[0];
79         float imageWidth = imageInfoBase[1];
80         auto roiBefore = toBoxEncodingCenter(
81                 {.x1 = roiBase[0], .y1 = roiBase[1], .x2 = roiBase[2], .y2 = roiBase[3]});
82         for (uint32_t i = 0; i < numClasses; i++) {
83             auto roiAfter = toBoxEncodingCorner({.w = std::exp(deltas[2]) * roiBefore.w,
84                                                  .h = std::exp(deltas[3]) * roiBefore.h,
85                                                  .x = roiBefore.x + deltas[0] * roiBefore.w,
86                                                  .y = roiBefore.y + deltas[1] * roiBefore.h});
87             BoxEncodingCorner cliped = {.x1 = std::min(std::max(roiAfter.x1, 0.0f), imageWidth),
88                                         .y1 = std::min(std::max(roiAfter.y1, 0.0f), imageHeight),
89                                         .x2 = std::min(std::max(roiAfter.x2, 0.0f), imageWidth),
90                                         .y2 = std::min(std::max(roiAfter.y2, 0.0f), imageHeight)};
91             outPtr[0] = cliped.x1;
92             outPtr[1] = cliped.y1;
93             outPtr[2] = cliped.x2;
94             outPtr[3] = cliped.y2;
95             deltas += roiLength;
96             outPtr += roiLength;
97         }
98     }
99     return true;
100 }
101 
bboxTransformFloat16(const _Float16 * roiData,const Shape & roiShape,const _Float16 * bboxDeltasData,const Shape & bboxDeltasShape,const int32_t * batchesData,const Shape & batchesShape,const _Float16 * imageInfoData,const Shape & imageInfoDataShape,_Float16 * outputData,const Shape & outputShape)102 inline bool bboxTransformFloat16(const _Float16* roiData, const Shape& roiShape,
103                                  const _Float16* bboxDeltasData, const Shape& bboxDeltasShape,
104                                  const int32_t* batchesData, const Shape& batchesShape,
105                                  const _Float16* imageInfoData, const Shape& imageInfoDataShape,
106                                  _Float16* outputData, const Shape& outputShape) {
107     std::vector<float> roi_float32(getNumberOfElements(roiShape));
108     convertFloat16ToFloat32(roiData, &roi_float32);
109     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
110     convertFloat16ToFloat32(bboxDeltasData, &delta_float32);
111     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoDataShape));
112     convertFloat16ToFloat32(imageInfoData, &imageInfo_float32);
113     std::vector<float> output_float32(getNumberOfElements(outputShape));
114     NN_RET_CHECK(bboxTransformFloat32(roi_float32.data(), roiShape, delta_float32.data(),
115                                       bboxDeltasShape, batchesData, batchesShape,
116                                       imageInfo_float32.data(), imageInfoDataShape,
117                                       output_float32.data(), outputShape));
118     convertFloat32ToFloat16(output_float32, outputData);
119     return true;
120 }
121 
bboxTransformQuant(const uint16_t * roiData,const Shape & roiShape,const uint8_t * bboxDeltasData,const Shape & bboxDeltasShape,const int32_t * batchesData,const Shape & batchesShape,const uint16_t * imageInfoData,const Shape & imageInfoDataShape,uint16_t * outputData,const Shape & outputShape)122 inline bool bboxTransformQuant(const uint16_t* roiData, const Shape& roiShape,
123                                const uint8_t* bboxDeltasData, const Shape& bboxDeltasShape,
124                                const int32_t* batchesData, const Shape& batchesShape,
125                                const uint16_t* imageInfoData, const Shape& imageInfoDataShape,
126                                uint16_t* outputData, const Shape& outputShape) {
127     std::vector<float> roi_float32(getNumberOfElements(roiShape));
128     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
129     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
130     convertQuantToFloat32(bboxDeltasData, bboxDeltasShape.scale, bboxDeltasShape.offset,
131                           &delta_float32);
132     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoDataShape));
133     convertQuantToFloat32(imageInfoData, imageInfoDataShape.scale, imageInfoDataShape.offset,
134                           &imageInfo_float32);
135     std::vector<float> output_float32(getNumberOfElements(outputShape));
136     NN_RET_CHECK(bboxTransformFloat32(roi_float32.data(), roiShape, delta_float32.data(),
137                                       bboxDeltasShape, batchesData, batchesShape,
138                                       imageInfo_float32.data(), imageInfoDataShape,
139                                       output_float32.data(), outputShape));
140     convertFloat32ToQuant(output_float32, outputShape.scale, outputShape.offset, outputData);
141     return true;
142 }
143 
// Returns the intersection-over-union (IoU) of two axis-aligned boxes. Each
// argument points to four floats laid out as [x1, y1, x2, y2].
//
// When the union area is not positive (for example both boxes are degenerate
// and have zero area) the result is 0 rather than the NaN that evaluating
// 0 / 0 would produce, so callers never compare or scale scores with NaN.
float getIoUAxisAligned(const float* roi1, const float* roi2) {
    const float area1 = (roi1[2] - roi1[0]) * (roi1[3] - roi1[1]);
    const float area2 = (roi2[2] - roi2[0]) * (roi2[3] - roi2[1]);
    // Corners of the overlap rectangle; it is empty when x2 <= x1 or y2 <= y1.
    const float x1 = std::max(roi1[0], roi2[0]);
    const float x2 = std::min(roi1[2], roi2[2]);
    const float y1 = std::max(roi1[1], roi2[1]);
    const float y2 = std::min(roi1[3], roi2[3]);
    const float w = std::max(x2 - x1, 0.0f);
    const float h = std::max(y2 - y1, 0.0f);
    const float areaIntersect = w * h;
    const float areaUnion = area1 + area2 - areaIntersect;
    // Written as a negated comparison so that a NaN union is rejected as well.
    if (!(areaUnion > 0.0f)) {
        return 0.0f;
    }
    return areaIntersect / areaUnion;
}
158 
159 }  // namespace
160 
161 namespace axis_aligned_bbox_transform {
162 
// Operation name used in error messages.
constexpr char kOperationName[]{"AXIS_ALIGNED_BBOX_TRANSFORM"};

// Number of inputs and the index of each input operand.
constexpr uint32_t kNumInputs{4};
constexpr uint32_t kRoiTensor{0};
constexpr uint32_t kDeltaTensor{1};
constexpr uint32_t kBatchesTensor{2};
constexpr uint32_t kImageInfoTensor{3};

// Number of outputs and the index of the single output operand.
constexpr uint32_t kNumOutputs{1};
constexpr uint32_t kOutputTensor{0};
173 
validate(const IOperationValidationContext * context)174 bool validate(const IOperationValidationContext* context) {
175     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
176     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
177     std::vector<OperandType> inExpectedTypes;
178     auto inputType = context->getInputType(kRoiTensor);
179     if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_FLOAT16) {
180         inExpectedTypes = {inputType, inputType, OperandType::TENSOR_INT32, inputType};
181     } else if (inputType == OperandType::TENSOR_QUANT16_ASYMM) {
182         inExpectedTypes = {OperandType::TENSOR_QUANT16_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
183                            OperandType::TENSOR_INT32, OperandType::TENSOR_QUANT16_ASYMM};
184     } else {
185         LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
186         return false;
187     }
188     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
189     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
190     return validateHalVersion(context, HalVersion::V1_2);
191 }
192 
prepare(IOperationExecutionContext * context)193 bool prepare(IOperationExecutionContext* context) {
194     Shape roiShape = context->getInputShape(kRoiTensor);
195     Shape bboxDeltasShape = context->getInputShape(kDeltaTensor);
196     Shape batchesShape = context->getInputShape(kBatchesTensor);
197     Shape imageInfoShape = context->getInputShape(kImageInfoTensor);
198     Shape outputShape = context->getOutputShape(kOutputTensor);
199 
200     NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2);
201     NN_RET_CHECK_EQ(getNumberOfDimensions(bboxDeltasShape), 2);
202     NN_RET_CHECK_EQ(getNumberOfDimensions(batchesShape), 1);
203     NN_RET_CHECK_EQ(getNumberOfDimensions(imageInfoShape), 2);
204 
205     // Only numRois can be zero.
206     const uint32_t kRoiDim = 4;
207     uint32_t numRois = getSizeOfDimension(roiShape, 0);
208     uint32_t numClasses = getSizeOfDimension(bboxDeltasShape, 1) / kRoiDim;
209     uint32_t numBatches = getSizeOfDimension(imageInfoShape, 0);
210     NN_RET_CHECK_GT(numClasses, 0);
211     NN_RET_CHECK_GT(numBatches, 0);
212     NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), kRoiDim);
213     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, 0), numRois);
214     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, 1), kRoiDim * numClasses);
215     NN_RET_CHECK_EQ(getSizeOfDimension(batchesShape, 0), numRois);
216     NN_RET_CHECK_EQ(getSizeOfDimension(imageInfoShape, 1), 2);
217 
218     if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
219         NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
220         NN_RET_CHECK_EQ(roiShape.offset, 0);
221         NN_RET_CHECK_EQ(imageInfoShape.scale, 0.125f);
222         NN_RET_CHECK_EQ(imageInfoShape.offset, 0);
223     }
224 
225     outputShape.type = roiShape.type;
226     outputShape.dimensions = {numRois, numClasses * kRoiDim};
227     outputShape.scale = 0.125f;
228     outputShape.offset = 0;
229     NN_RET_CHECK(context->setOutputShape(kOutputTensor, outputShape));
230     return true;
231 }
232 
execute(IOperationExecutionContext * context)233 bool execute(IOperationExecutionContext* context) {
234     NNTRACE_TRANS("axisAlignedBBoxTransform");
235     // Bypass execution in the case of zero-sized input.
236     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
237     switch (context->getInputType(kRoiTensor)) {
238         case OperandType::TENSOR_FLOAT16: {
239             return bboxTransformFloat16(context->getInputBuffer<_Float16>(kRoiTensor),
240                                         context->getInputShape(kRoiTensor),
241                                         context->getInputBuffer<_Float16>(kDeltaTensor),
242                                         context->getInputShape(kDeltaTensor),
243                                         context->getInputBuffer<int32_t>(kBatchesTensor),
244                                         context->getInputShape(kBatchesTensor),
245                                         context->getInputBuffer<_Float16>(kImageInfoTensor),
246                                         context->getInputShape(kImageInfoTensor),
247                                         context->getOutputBuffer<_Float16>(kOutputTensor),
248                                         context->getOutputShape(kOutputTensor));
249         }
250         case OperandType::TENSOR_FLOAT32: {
251             return bboxTransformFloat32(context->getInputBuffer<float>(kRoiTensor),
252                                         context->getInputShape(kRoiTensor),
253                                         context->getInputBuffer<float>(kDeltaTensor),
254                                         context->getInputShape(kDeltaTensor),
255                                         context->getInputBuffer<int32_t>(kBatchesTensor),
256                                         context->getInputShape(kBatchesTensor),
257                                         context->getInputBuffer<float>(kImageInfoTensor),
258                                         context->getInputShape(kImageInfoTensor),
259                                         context->getOutputBuffer<float>(kOutputTensor),
260                                         context->getOutputShape(kOutputTensor));
261         }
262         case OperandType::TENSOR_QUANT16_ASYMM: {
263             return bboxTransformQuant(context->getInputBuffer<uint16_t>(kRoiTensor),
264                                       context->getInputShape(kRoiTensor),
265                                       context->getInputBuffer<uint8_t>(kDeltaTensor),
266                                       context->getInputShape(kDeltaTensor),
267                                       context->getInputBuffer<int32_t>(kBatchesTensor),
268                                       context->getInputShape(kBatchesTensor),
269                                       context->getInputBuffer<uint16_t>(kImageInfoTensor),
270                                       context->getInputShape(kImageInfoTensor),
271                                       context->getOutputBuffer<uint16_t>(kOutputTensor),
272                                       context->getOutputShape(kOutputTensor));
273         }
274         default:
275             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
276     }
277 }
278 
279 }  // namespace axis_aligned_bbox_transform
280 
281 namespace box_with_nms_limit {
282 
// Operation name used in error messages.
constexpr char kOperationName[]{"BOX_WITH_NMS_LIMIT"};

// Number of inputs and the index of each input operand. Indices 0-2 are
// tensors; the remaining ones are named as scalar parameters.
constexpr uint32_t kNumInputs{9};
constexpr uint32_t kScoreTensor{0};
constexpr uint32_t kRoiTensor{1};
constexpr uint32_t kBatchesTensor{2};
constexpr uint32_t kScoreThresholdScalar{3};
constexpr uint32_t kMaxNumDetectionScalar{4};
constexpr uint32_t kNmsKernelScalar{5};
constexpr uint32_t kIoUThresholdScalar{6};
constexpr uint32_t kSigmaScalar{7};
constexpr uint32_t kNmsScoreThresholdScalar{8};

// Number of outputs and the index of each output operand.
constexpr uint32_t kNumOutputs{4};
constexpr uint32_t kOutputScoreTensor{0};
constexpr uint32_t kOutputRoiTensor{1};
constexpr uint32_t kOutputClassTensor{2};
constexpr uint32_t kOutputBatchesTensor{3};
301 
302 namespace {
303 
304 // TODO(xusongw): Reduce code duplication with hard/soft nms path.
305 
306 // Inplace hard NMS within range [select, select + selectLength).
hardNmsSingleClass(const float * scoresData,float iouThreshold,int32_t maxNumDetections,std::function<const float * (uint32_t)> getRoiBase,uint32_t * select,uint32_t selectLength)307 uint32_t* hardNmsSingleClass(const float* scoresData, float iouThreshold, int32_t maxNumDetections,
308                              std::function<const float*(uint32_t)> getRoiBase, uint32_t* select,
309                              uint32_t selectLength) {
310     uint32_t *selectStart = select, *selectEnd = select + selectLength, numDetections = 0;
311     if (maxNumDetections < 0) {
312         maxNumDetections = selectLength;
313     }
314     while (selectStart < selectEnd && numDetections < maxNumDetections) {
315         // find max score and swap to the front
316         auto& maxScore = *std::max_element(selectStart, selectEnd,
317                                            [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
318                                                return scoresData[lhs] < scoresData[rhs];
319                                            });
320         std::swap(maxScore, *selectStart);
321 
322         // Calculate IoU of the rest, swap to the end (disgard) if needed.
323         for (uint32_t* i = selectStart + 1; i < selectEnd; i++) {
324             float iou = getIoUAxisAligned(getRoiBase(*i), getRoiBase(*selectStart));
325             if (iou >= iouThreshold) {
326                 std::swap(*i--, *(--selectEnd));
327             }
328         }
329         selectStart++;
330         numDetections++;
331     }
332     return selectStart;
333 }
334 
hardNmsMultiClass(const float * scoresData,uint32_t numClasses,uint32_t numRois,float scoreThreshold,float iouThreshold,int32_t maxNumDetections,int32_t maxNumDetectionsPerClass,std::function<const float * (uint32_t)> getRoiBase,std::vector<uint32_t> * select)335 void hardNmsMultiClass(const float* scoresData, uint32_t numClasses, uint32_t numRois,
336                        float scoreThreshold, float iouThreshold, int32_t maxNumDetections,
337                        int32_t maxNumDetectionsPerClass,
338                        std::function<const float*(uint32_t)> getRoiBase,
339                        std::vector<uint32_t>* select) {
340     // Exclude class 0 (background)
341     for (uint32_t c = 1; c < numClasses; c++) {
342         uint32_t size = select->size();
343         for (uint32_t b = 0; b < numRois; b++) {
344             const uint32_t index = b * numClasses + c;
345             const float score = scoresData[index];
346             if (score > scoreThreshold) {
347                 select->push_back(index);
348             }
349         }
350         uint32_t* selectStart = select->data() + size;
351         uint32_t selectLength = select->size() - size;
352         uint32_t* selectEnd = hardNmsSingleClass(scoresData, iouThreshold, maxNumDetectionsPerClass,
353                                                  getRoiBase, selectStart, selectLength);
354         select->resize(selectEnd - select->data());
355     }
356 
357     // Take top maxNumDetections.
358     std::sort(select->begin(), select->end(),
359               [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
360                   return scoresData[lhs] > scoresData[rhs];
361               });
362     if (maxNumDetections < 0 || select->size() <= maxNumDetections) {
363         return;
364     }
365     select->resize(maxNumDetections);
366 }
367 
368 // Inplace soft NMS within range [select, select + selectLength).
369 using SoftNmsKernel = std::function<float(float)>;
softNmsSingleClass(float * scoresData,float scoreThreshold,int32_t maxNumDetections,std::function<const float * (uint32_t)> getRoiBase,SoftNmsKernel kernel,uint32_t * select,uint32_t selectLength)370 uint32_t* softNmsSingleClass(float* scoresData, float scoreThreshold, int32_t maxNumDetections,
371                              std::function<const float*(uint32_t)> getRoiBase, SoftNmsKernel kernel,
372                              uint32_t* select, uint32_t selectLength) {
373     uint32_t *selectStart = select, *selectEnd = select + selectLength, numDetections = 0;
374     if (maxNumDetections < 0) {
375         maxNumDetections = selectLength;
376     }
377     while (selectStart < selectEnd && numDetections < maxNumDetections) {
378         // find max score and swap to the front
379         auto& maxScore = *std::max_element(selectStart, selectEnd,
380                                            [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
381                                                return scoresData[lhs] < scoresData[rhs];
382                                            });
383         std::swap(maxScore, *selectStart);
384 
385         // Calculate IoU of the rest, swap to the end (disgard) if needed.
386         for (uint32_t* i = selectStart + 1; i < selectEnd; i++) {
387             float iou = getIoUAxisAligned(getRoiBase(*i), getRoiBase(*selectStart));
388             scoresData[*i] *= kernel(iou);
389             if (scoresData[*i] < scoreThreshold) {
390                 std::swap(*i--, *(--selectEnd));
391             }
392         }
393         selectStart++;
394         numDetections++;
395     }
396     return selectStart;
397 }
398 
softNmsMultiClass(float * scoresData,uint32_t numClasses,uint32_t numRois,float scoreThreshold,float nmsScoreThreshold,int32_t maxNumDetections,int32_t maxNumDetectionsPerClass,std::function<const float * (uint32_t)> getRoiBase,SoftNmsKernel kernel,std::vector<uint32_t> * select)399 void softNmsMultiClass(float* scoresData, uint32_t numClasses, uint32_t numRois,
400                        float scoreThreshold, float nmsScoreThreshold, int32_t maxNumDetections,
401                        int32_t maxNumDetectionsPerClass,
402                        std::function<const float*(uint32_t)> getRoiBase, SoftNmsKernel kernel,
403                        std::vector<uint32_t>* select) {
404     // Exclude class 0 (background)
405     for (uint32_t c = 1; c < numClasses; c++) {
406         uint32_t size = select->size();
407         for (uint32_t b = 0; b < numRois; b++) {
408             const uint32_t index = b * numClasses + c;
409             const float score = scoresData[index];
410             if (score > scoreThreshold) {
411                 select->push_back(index);
412             }
413         }
414         uint32_t* selectStart = select->data() + size;
415         uint32_t selectLength = select->size() - size;
416         uint32_t* selectEnd =
417                 softNmsSingleClass(scoresData, nmsScoreThreshold, maxNumDetectionsPerClass,
418                                    getRoiBase, kernel, selectStart, selectLength);
419         select->resize(selectEnd - select->data());
420     }
421 
422     // Take top maxNumDetections.
423     std::sort(select->begin(), select->end(),
424               [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
425                   return scoresData[lhs] > scoresData[rhs];
426               });
427     if (maxNumDetections < 0 || select->size() <= maxNumDetections) {
428         return;
429     }
430     select->resize(maxNumDetections);
431 }
432 
boxWithNmsLimitFloat32Compute(float * scoresData,const Shape & scoresShape,const float * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,float scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,float iouThreshold,float sigma,float nmsScoreThreshold,std::vector<uint32_t> * batchSplitIn,std::vector<uint32_t> * batchSplitOut,std::vector<uint32_t> * selected)433 bool boxWithNmsLimitFloat32Compute(float* scoresData, const Shape& scoresShape,
434                                    const float* roiData, const Shape& roiShape,
435                                    const int32_t* batchesData, const Shape& batchesShape,
436                                    float scoreThreshold, int32_t maxNumDetections,
437                                    int32_t softNmsKernel, float iouThreshold, float sigma,
438                                    float nmsScoreThreshold, std::vector<uint32_t>* batchSplitIn,
439                                    std::vector<uint32_t>* batchSplitOut,
440                                    std::vector<uint32_t>* selected) {
441     SoftNmsKernel kernel = nullptr;
442     if (softNmsKernel == 0) {
443         kernel = [&iouThreshold](float iou) { return iou < iouThreshold ? 1.0f : 0.0f; };
444     } else if (softNmsKernel == 1) {
445         kernel = [&iouThreshold](float iou) { return iou < iouThreshold ? 1.0f : 1.0f - iou; };
446     } else if (softNmsKernel == 2) {
447         kernel = [&sigma](float iou) { return std::exp(-1.0f * iou * iou / sigma); };
448     } else {
449         NN_RET_CHECK_FAIL() << "Unsupported soft NMS kernel " << softNmsKernel;
450     }
451 
452     const uint32_t kRoiDim = 4;
453     uint32_t numRois = getSizeOfDimension(scoresShape, 0);
454     uint32_t numClasses = getSizeOfDimension(scoresShape, 1);
455 
456     // We assume boxes of the same batch are grouped together.
457     std::vector<uint32_t> batch;
458     for (uint32_t i = 0, ind = -1; i < numRois; i++) {
459         if (batchesData[i] == ind) {
460             (batchSplitIn->back())++;
461         } else {
462             ind = batchesData[i];
463             batchSplitIn->push_back(1);
464         }
465     }
466 
467     float* scoresBase = scoresData;
468     const float* roiBase = roiData;
469     selected->clear();
470     for (uint32_t b = 0; b < batchSplitIn->size(); b++) {
471         for (uint32_t i = 0; i < batchSplitIn->at(b); i++) {
472             const float* roi = roiBase + i * kRoiDim;
473             // Check for malformed data: invalid region: x2 < x1 || y2 < y1
474             NN_RET_CHECK_LE(roi[0], roi[2]);
475             NN_RET_CHECK_LE(roi[1], roi[3]);
476         }
477         std::vector<uint32_t> result;
478         softNmsMultiClass(scoresBase, numClasses, batchSplitIn->at(b), scoreThreshold,
479                           nmsScoreThreshold, maxNumDetections, maxNumDetections,
480                           [&roiBase](uint32_t ind) { return roiBase + ind * kRoiDim; }, kernel,
481                           &result);
482         // Sort again by class.
483         std::sort(result.begin(), result.end(),
484                   [&scoresBase, numClasses](const uint32_t& lhs, const uint32_t& rhs) {
485                       uint32_t lhsClass = lhs % numClasses, rhsClass = rhs % numClasses;
486                       return lhsClass == rhsClass ? scoresBase[lhs] > scoresBase[rhs]
487                                                   : lhsClass < rhsClass;
488                   });
489         selected->insert(selected->end(), result.begin(), result.end());
490         batchSplitOut->push_back(result.size());
491         scoresBase += batchSplitIn->at(b) * numClasses;
492         roiBase += batchSplitIn->at(b) * numClasses * kRoiDim;
493     }
494     return true;
495 }
496 
497 template <typename T>
castTo(float val,const Shape &)498 T castTo(float val, const Shape&) {
499     return val;
500 }
501 template <>
castTo(float val,const Shape & shape)502 uint8_t castTo(float val, const Shape& shape) {
503     int32_t intVal = std::round(val / shape.scale + shape.offset);
504     intVal = std::min<int32_t>(std::max<int32_t>(intVal, std::numeric_limits<uint8_t>::min()),
505                                std::numeric_limits<uint8_t>::max());
506     return static_cast<uint8_t>(intVal);
507 }
508 
509 template <typename T_Score, typename T_Roi>
boxWithNmsLimitWriteOutput(const std::vector<uint32_t> & selected,const std::vector<uint32_t> & batchSplitIn,const std::vector<uint32_t> & batchSplitOut,const std::vector<float> & scores,IOperationExecutionContext * context)510 bool boxWithNmsLimitWriteOutput(const std::vector<uint32_t>& selected,
511                                 const std::vector<uint32_t>& batchSplitIn,
512                                 const std::vector<uint32_t>& batchSplitOut,
513                                 const std::vector<float>& scores,
514                                 IOperationExecutionContext* context) {
515     const uint32_t kRoiDim = 4;
516     Shape scoresShape = context->getInputShape(kScoreTensor);
517     uint32_t numClasses = getSizeOfDimension(scoresShape, 1);
518 
519     // Set output dimensions.
520     uint32_t numOutRois = selected.size();
521     if (numOutRois == 0) return true;
522     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
523     scoresOutShape.dimensions = {numOutRois};
524     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
525 
526     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
527     roiOutShape.dimensions = {numOutRois, 4};
528     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
529 
530     Shape classesOutShape = context->getOutputShape(kOutputClassTensor);
531     classesOutShape.dimensions = {numOutRois};
532     NN_RET_CHECK(context->setOutputShape(kOutputClassTensor, classesOutShape));
533 
534     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
535     batchesOutShape.dimensions = {numOutRois};
536     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
537 
538     // Write outputs.
539     const float* scoresBase = scores.data();
540     const T_Roi* roiBase = context->getInputBuffer<T_Roi>(kRoiTensor);
541     const int32_t* batchesInPtr = context->getInputBuffer<int32_t>(kBatchesTensor);
542     T_Score* scoresOutPtr = context->getOutputBuffer<T_Score>(kOutputScoreTensor);
543     T_Roi* roiOutPtr = context->getOutputBuffer<T_Roi>(kOutputRoiTensor);
544     int32_t* classesOutPtr = context->getOutputBuffer<int32_t>(kOutputClassTensor);
545     int32_t* batchesOutPtr = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
546     uint32_t i = 0;
547     for (uint32_t b = 0; b < batchSplitOut.size(); b++) {
548         for (uint32_t j = 0; j < batchSplitOut[b]; j++) {
549             uint32_t index = selected[i++];
550             *scoresOutPtr++ = castTo<T_Score>(scoresBase[index], scoresOutShape);
551             memcpy(roiOutPtr, roiBase + index * kRoiDim, kRoiDim * sizeof(T_Roi));
552             roiOutPtr += kRoiDim;
553             *classesOutPtr++ = index % numClasses;
554             *batchesOutPtr++ = *batchesInPtr;
555         }
556         scoresBase += batchSplitIn[b] * numClasses;
557         roiBase += batchSplitIn[b] * numClasses * kRoiDim;
558         batchesInPtr += batchSplitIn[b];
559     }
560     return true;
561 }
562 
boxWithNmsLimitFloat32(const float * scoresData,const Shape & scoresShape,const float * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,float scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,float iouThreshold,float sigma,float nmsScoreThreshold,float * scoresOutData,Shape scoresOutShape,float * roiOutData,Shape roiOutShape,int32_t * classesOutData,Shape classesOutShape,int32_t * batchesOutData,const Shape & batchSplitOutShape,IOperationExecutionContext * context)563 bool boxWithNmsLimitFloat32(const float* scoresData, const Shape& scoresShape, const float* roiData,
564                             const Shape& roiShape, const int32_t* batchesData,
565                             const Shape& batchesShape, float scoreThreshold,
566                             int32_t maxNumDetections, int32_t softNmsKernel, float iouThreshold,
567                             float sigma, float nmsScoreThreshold, float* scoresOutData,
568                             Shape scoresOutShape, float* roiOutData, Shape roiOutShape,
569                             int32_t* classesOutData, Shape classesOutShape, int32_t* batchesOutData,
570                             const Shape& batchSplitOutShape, IOperationExecutionContext* context) {
571     NNTRACE_TRANS("boxWithNmsLimit");
572     std::vector<float> scores_float32(getNumberOfElements(scoresShape));
573     for (uint32_t i = 0; i < scores_float32.size(); i++) {
574         scores_float32[i] = scoresData[i];
575     }
576     std::vector<uint32_t> selected, batchSplitIn, batchSplitOut;
577     NN_RET_CHECK(boxWithNmsLimitFloat32Compute(
578             scores_float32.data(), scoresShape, roiData, roiShape, batchesData, batchesShape,
579             scoreThreshold, maxNumDetections, softNmsKernel, iouThreshold, sigma, nmsScoreThreshold,
580             &batchSplitIn, &batchSplitOut, &selected));
581     return boxWithNmsLimitWriteOutput<float, float>(selected, batchSplitIn, batchSplitOut,
582                                                     scores_float32, context);
583 }
584 
boxWithNmsLimitFloat16(const _Float16 * scoresData,const Shape & scoresShape,const _Float16 * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,_Float16 scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,_Float16 iouThreshold,_Float16 sigma,_Float16 nmsScoreThreshold,_Float16 * scoresOutData,const Shape & scoresOutShape,_Float16 * roiOutData,const Shape & roiOutShape,int32_t * classesOutData,const Shape & classesOutShape,int32_t * batchesOutData,const Shape & batchSplitOutShape,IOperationExecutionContext * context)585 bool boxWithNmsLimitFloat16(const _Float16* scoresData, const Shape& scoresShape,
586                             const _Float16* roiData, const Shape& roiShape,
587                             const int32_t* batchesData, const Shape& batchesShape,
588                             _Float16 scoreThreshold, int32_t maxNumDetections,
589                             int32_t softNmsKernel, _Float16 iouThreshold, _Float16 sigma,
590                             _Float16 nmsScoreThreshold, _Float16* scoresOutData,
591                             const Shape& scoresOutShape, _Float16* roiOutData,
592                             const Shape& roiOutShape, int32_t* classesOutData,
593                             const Shape& classesOutShape, int32_t* batchesOutData,
594                             const Shape& batchSplitOutShape, IOperationExecutionContext* context) {
595     std::vector<float> scores_float32(getNumberOfElements(scoresShape));
596     convertFloat16ToFloat32(scoresData, &scores_float32);
597     std::vector<float> roi_float32(getNumberOfElements(roiShape));
598     convertFloat16ToFloat32(roiData, &roi_float32);
599     std::vector<uint32_t> selected, batchSplitIn, batchSplitOut;
600     NN_RET_CHECK(boxWithNmsLimitFloat32Compute(
601             scores_float32.data(), scoresShape, roi_float32.data(), roiShape, batchesData,
602             batchesShape, scoreThreshold, maxNumDetections, softNmsKernel, iouThreshold, sigma,
603             nmsScoreThreshold, &batchSplitIn, &batchSplitOut, &selected));
604     return boxWithNmsLimitWriteOutput<_Float16, _Float16>(selected, batchSplitIn, batchSplitOut,
605                                                           scores_float32, context);
606 }
607 
boxWithNmsLimitQuant(const uint8_t * scoresData,const Shape & scoresShape,const uint16_t * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,float scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,float iouThreshold,float sigma,float nmsScoreThreshold,uint8_t * scoresOutData,const Shape & scoresOutShape,uint16_t * roiOutData,const Shape & roiOutShape,int32_t * classesOutData,const Shape & classesOutShape,int32_t * batchesOutData,const Shape & batchSplitOutShape,IOperationExecutionContext * context)608 bool boxWithNmsLimitQuant(const uint8_t* scoresData, const Shape& scoresShape,
609                           const uint16_t* roiData, const Shape& roiShape,
610                           const int32_t* batchesData, const Shape& batchesShape,
611                           float scoreThreshold, int32_t maxNumDetections, int32_t softNmsKernel,
612                           float iouThreshold, float sigma, float nmsScoreThreshold,
613                           uint8_t* scoresOutData, const Shape& scoresOutShape, uint16_t* roiOutData,
614                           const Shape& roiOutShape, int32_t* classesOutData,
615                           const Shape& classesOutShape, int32_t* batchesOutData,
616                           const Shape& batchSplitOutShape, IOperationExecutionContext* context) {
617     std::vector<float> scores_float32(getNumberOfElements(scoresShape));
618     convertQuantToFloat32(scoresData, scoresShape.scale, scoresShape.offset, &scores_float32);
619     std::vector<float> roi_float32(getNumberOfElements(roiShape));
620     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
621     std::vector<uint32_t> selected, batchSplitIn, batchSplitOut;
622     NN_RET_CHECK(boxWithNmsLimitFloat32Compute(
623             scores_float32.data(), scoresShape, roi_float32.data(), roiShape, batchesData,
624             batchesShape, scoreThreshold, maxNumDetections, softNmsKernel, iouThreshold, sigma,
625             nmsScoreThreshold, &batchSplitIn, &batchSplitOut, &selected));
626     return boxWithNmsLimitWriteOutput<uint8_t, uint16_t>(selected, batchSplitIn, batchSplitOut,
627                                                          scores_float32, context);
628 }
629 
630 }  // namespace
631 
validate(const IOperationValidationContext * context)632 bool validate(const IOperationValidationContext* context) {
633     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
634     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
635     std::vector<OperandType> inExpectedTypes;
636     std::vector<OperandType> outExpectedTypes;
637     auto inputType = context->getInputType(kScoreTensor);
638     if (inputType == OperandType::TENSOR_FLOAT16) {
639         inExpectedTypes = {
640                 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32,
641                 OperandType::FLOAT16,        OperandType::INT32,          OperandType::INT32,
642                 OperandType::FLOAT16,        OperandType::FLOAT16,        OperandType::FLOAT16};
643         outExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
644                             OperandType::TENSOR_INT32, OperandType::TENSOR_INT32};
645     } else if (inputType == OperandType::TENSOR_FLOAT32) {
646         inExpectedTypes = {
647                 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32,
648                 OperandType::FLOAT32,        OperandType::INT32,          OperandType::INT32,
649                 OperandType::FLOAT32,        OperandType::FLOAT32,        OperandType::FLOAT32};
650         outExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
651                             OperandType::TENSOR_INT32, OperandType::TENSOR_INT32};
652     } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
653         inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
654                            OperandType::TENSOR_QUANT16_ASYMM,
655                            OperandType::TENSOR_INT32,
656                            OperandType::FLOAT32,
657                            OperandType::INT32,
658                            OperandType::INT32,
659                            OperandType::FLOAT32,
660                            OperandType::FLOAT32,
661                            OperandType::FLOAT32};
662         outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT16_ASYMM,
663                             OperandType::TENSOR_INT32, OperandType::TENSOR_INT32};
664     } else {
665         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
666     }
667     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
668     NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
669     return validateHalVersion(context, HalVersion::V1_2);
670 }
671 
prepare(IOperationExecutionContext * context)672 bool prepare(IOperationExecutionContext* context) {
673     Shape scoreShape = context->getInputShape(kScoreTensor);
674     Shape roiShape = context->getInputShape(kRoiTensor);
675     Shape batchesShape = context->getInputShape(kBatchesTensor);
676     Shape outputScoreShape = context->getOutputShape(kOutputScoreTensor);
677     Shape outputRoiShape = context->getOutputShape(kOutputRoiTensor);
678     Shape outputClassShape = context->getOutputShape(kOutputClassTensor);
679     Shape outputBatchSplitShape = context->getOutputShape(kOutputBatchesTensor);
680 
681     NN_RET_CHECK(getNumberOfDimensions(scoreShape) == 2);
682     NN_RET_CHECK(getNumberOfDimensions(roiShape) == 2);
683     NN_RET_CHECK(getNumberOfDimensions(batchesShape) == 1);
684 
685     // Only numRois can be zero.
686     const uint32_t kRoiDim = 4;
687     uint32_t numRois = getSizeOfDimension(scoreShape, 0);
688     uint32_t numClasses = getSizeOfDimension(scoreShape, 1);
689     NN_RET_CHECK(getSizeOfDimension(roiShape, 0) == numRois);
690     NN_RET_CHECK(getSizeOfDimension(roiShape, 1) == kRoiDim * numClasses);
691     NN_RET_CHECK(getSizeOfDimension(batchesShape, 0) == numRois);
692     NN_RET_CHECK_GT(numClasses, 1);
693 
694     if (scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
695         NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
696         NN_RET_CHECK_EQ(roiShape.offset, 0);
697     }
698 
699     outputScoreShape.type = scoreShape.type;
700     outputScoreShape.dimensions = {0};
701     outputScoreShape.scale = scoreShape.scale;
702     outputScoreShape.offset = scoreShape.offset;
703     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, outputScoreShape));
704 
705     outputRoiShape.type = roiShape.type;
706     outputRoiShape.dimensions = {0, 4};
707     outputRoiShape.scale = 0.125f;
708     outputRoiShape.offset = 0;
709     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, outputRoiShape));
710 
711     outputClassShape.type = OperandType::TENSOR_INT32;
712     outputClassShape.dimensions = {0};
713     NN_RET_CHECK(context->setOutputShape(kOutputClassTensor, outputClassShape));
714 
715     outputBatchSplitShape.type = batchesShape.type;
716     outputBatchSplitShape.dimensions = {0};
717     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, outputBatchSplitShape));
718     return true;
719 }
720 
execute(IOperationExecutionContext * context)721 bool execute(IOperationExecutionContext* context) {
722     NNTRACE_TRANS("boxWithNMSLimit");
723     // Bypass execution in the case of zero numRois.
724     if (getSizeOfDimension(context->getInputShape(kScoreTensor), 0) == 0) return true;
725     switch (context->getInputType(kScoreTensor)) {
726         case OperandType::TENSOR_FLOAT16: {
727             return boxWithNmsLimitFloat16(
728                     context->getInputBuffer<_Float16>(kScoreTensor),
729                     context->getInputShape(kScoreTensor),
730                     context->getInputBuffer<_Float16>(kRoiTensor),
731                     context->getInputShape(kRoiTensor),
732                     context->getInputBuffer<int32_t>(kBatchesTensor),
733                     context->getInputShape(kBatchesTensor),
734                     context->getInputValue<_Float16>(kScoreThresholdScalar),
735                     context->getInputValue<int32_t>(kMaxNumDetectionScalar),
736                     context->getInputValue<int32_t>(kNmsKernelScalar),
737                     context->getInputValue<_Float16>(kIoUThresholdScalar),
738                     context->getInputValue<_Float16>(kSigmaScalar),
739                     context->getInputValue<_Float16>(kNmsScoreThresholdScalar),
740                     context->getOutputBuffer<_Float16>(kOutputScoreTensor),
741                     context->getOutputShape(kOutputScoreTensor),
742                     context->getOutputBuffer<_Float16>(kOutputRoiTensor),
743                     context->getOutputShape(kOutputRoiTensor),
744                     context->getOutputBuffer<int32_t>(kOutputClassTensor),
745                     context->getOutputShape(kOutputClassTensor),
746                     context->getOutputBuffer<int32_t>(kOutputBatchesTensor),
747                     context->getOutputShape(kOutputBatchesTensor), context);
748         }
749         case OperandType::TENSOR_FLOAT32: {
750             return boxWithNmsLimitFloat32(context->getInputBuffer<float>(kScoreTensor),
751                                           context->getInputShape(kScoreTensor),
752                                           context->getInputBuffer<float>(kRoiTensor),
753                                           context->getInputShape(kRoiTensor),
754                                           context->getInputBuffer<int32_t>(kBatchesTensor),
755                                           context->getInputShape(kBatchesTensor),
756                                           context->getInputValue<float>(kScoreThresholdScalar),
757                                           context->getInputValue<int32_t>(kMaxNumDetectionScalar),
758                                           context->getInputValue<int32_t>(kNmsKernelScalar),
759                                           context->getInputValue<float>(kIoUThresholdScalar),
760                                           context->getInputValue<float>(kSigmaScalar),
761                                           context->getInputValue<float>(kNmsScoreThresholdScalar),
762                                           context->getOutputBuffer<float>(kOutputScoreTensor),
763                                           context->getOutputShape(kOutputScoreTensor),
764                                           context->getOutputBuffer<float>(kOutputRoiTensor),
765                                           context->getOutputShape(kOutputRoiTensor),
766                                           context->getOutputBuffer<int32_t>(kOutputClassTensor),
767                                           context->getOutputShape(kOutputClassTensor),
768                                           context->getOutputBuffer<int32_t>(kOutputBatchesTensor),
769                                           context->getOutputShape(kOutputBatchesTensor), context);
770         }
771         case OperandType::TENSOR_QUANT8_ASYMM: {
772             return boxWithNmsLimitQuant(context->getInputBuffer<uint8_t>(kScoreTensor),
773                                         context->getInputShape(kScoreTensor),
774                                         context->getInputBuffer<uint16_t>(kRoiTensor),
775                                         context->getInputShape(kRoiTensor),
776                                         context->getInputBuffer<int32_t>(kBatchesTensor),
777                                         context->getInputShape(kBatchesTensor),
778                                         context->getInputValue<float>(kScoreThresholdScalar),
779                                         context->getInputValue<int32_t>(kMaxNumDetectionScalar),
780                                         context->getInputValue<int32_t>(kNmsKernelScalar),
781                                         context->getInputValue<float>(kIoUThresholdScalar),
782                                         context->getInputValue<float>(kSigmaScalar),
783                                         context->getInputValue<float>(kNmsScoreThresholdScalar),
784                                         context->getOutputBuffer<uint8_t>(kOutputScoreTensor),
785                                         context->getOutputShape(kOutputScoreTensor),
786                                         context->getOutputBuffer<uint16_t>(kOutputRoiTensor),
787                                         context->getOutputShape(kOutputRoiTensor),
788                                         context->getOutputBuffer<int32_t>(kOutputClassTensor),
789                                         context->getOutputShape(kOutputClassTensor),
790                                         context->getOutputBuffer<int32_t>(kOutputBatchesTensor),
791                                         context->getOutputShape(kOutputBatchesTensor), context);
792         }
793         default:
794             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
795     }
796 }
797 
798 }  // namespace box_with_nms_limit
799 
800 namespace generate_proposals {
801 
// Operation name used in log and validation messages.
constexpr char kOperationName[] = "GENERATE_PROPOSALS";

// Input operand count and the positional index of each input operand.
constexpr uint32_t kNumInputs = 11;
constexpr uint32_t kScoreTensor = 0;
constexpr uint32_t kDeltaTensor = 1;
constexpr uint32_t kAnchorTensor = 2;
constexpr uint32_t kImageInfoTensor = 3;
// NOTE(review): "Salar" looks like a typo of "Scalar"; renaming it requires
// updating every use of this constant.
constexpr uint32_t kHeightStrideSalar = 4;
constexpr uint32_t kWidthStrideScalar = 5;
constexpr uint32_t kPreNmsMaxScalar = 6;
constexpr uint32_t kPostNmsMaxScalar = 7;
constexpr uint32_t kIoUThresholdScalar = 8;
constexpr uint32_t kMinSizeScalar = 9;
// Presumably selects between NCHW and NHWC data layout -- confirm against the
// code that reads this operand.
constexpr uint32_t kLayoutScalar = 10;

// Output operand count and the positional index of each output operand.
constexpr uint32_t kNumOutputs = 3;
constexpr uint32_t kOutputScoreTensor = 0;
constexpr uint32_t kOutputRoiTensor = 1;
constexpr uint32_t kOutputBatchesTensor = 2;
821 
822 namespace {
823 
// Compacts |select| in place so that it only keeps the indices of boxes that
// pass the size and position filter; the relative order of kept indices is
// preserved.
//
// roiBase:       box coordinates, 4 floats per box read as [x1, y1, x2, y2];
//                box k starts at roiBase + 4 * k.
// imageInfoBase: element 0 is the exclusive upper bound for the box center's
//                y coordinate, element 1 the bound for its x coordinate.
// minSize:       a box is kept only if both its width and its height are
//                strictly greater than this value.
// select:        in/out list of candidate box indices.
void filterBoxes(const float* roiBase, const float* imageInfoBase, float minSize,
                 std::vector<uint32_t>* select) {
    constexpr uint32_t kCoordsPerBox = 4;
    // Next slot to write a surviving index into; never ahead of the read position.
    auto keepEnd = select->begin();
    for (const uint32_t boxIndex : *select) {
        const float* box = roiBase + boxIndex * kCoordsPerBox;
        const float boxWidth = box[2] - box[0];
        const float boxHeight = box[3] - box[1];
        const float centerX = box[0] + boxWidth / 2.0f;
        const float centerY = box[1] + boxHeight / 2.0f;

        // Negated ">" / "<" (rather than "<=" / ">=") so NaN values are rejected too.
        if (!(boxWidth > minSize && boxHeight > minSize)) continue;
        if (!(centerX < imageInfoBase[1] && centerY < imageInfoBase[0])) continue;

        *keepEnd++ = boxIndex;
    }
    select->erase(keepEnd, select->end());
}
842 
// Core float32 implementation of GENERATE_PROPOSALS for NHWC-ordered inputs.
//
// scoresShape is read as [numBatches, height, width, numAnchors]. For each
// batch the function:
//   1. applies the bounding box deltas to the anchor boxes placed at every
//      (h, w) location (via bboxTransformFloat32),
//   2. optionally keeps only the preNmsTopN highest-scoring candidates,
//   3. drops boxes rejected by filterBoxes,
//   4. runs hard NMS through box_with_nms_limit::hardNmsSingleClass with
//      iouThreshold and postNmsTopN,
//   5. appends the surviving scores, ROIs (4 floats each) and the batch index
//      to the three output vectors.
//
// The output vectors are cleared on entry. Returns false if the bbox
// transform step fails, true otherwise.
bool generateProposalsNhwcFloat32Compute(const float* scoresData, const Shape& scoresShape,
                                         const float* bboxDeltasData, const Shape& bboxDeltasShape,
                                         const float* anchorsData, const Shape& anchorsShape,
                                         const float* imageInfoData, const Shape& imageInfoShape,
                                         float heightStride, float widthStride, int32_t preNmsTopN,
                                         int32_t postNmsTopN, float iouThreshold, float minSize,
                                         std::vector<float>* scoresOutData,
                                         std::vector<float>* roiOutData,
                                         std::vector<int32_t>* batchesOutData) {
    const uint32_t kRoiDim = 4;
    uint32_t numBatches = getSizeOfDimension(scoresShape, 0);
    uint32_t height = getSizeOfDimension(scoresShape, 1);
    uint32_t width = getSizeOfDimension(scoresShape, 2);
    uint32_t numAnchors = getSizeOfDimension(scoresShape, 3);
    uint32_t imageInfoLength = getSizeOfDimension(imageInfoShape, 1);

    // Number of candidate boxes (and scores) per batch.
    uint32_t batchSize = height * width * numAnchors;
    uint32_t roiBufferSize = batchSize * kRoiDim;
    std::vector<float> roiBuffer(roiBufferSize);
    std::vector<float> roiTransformedBuffer(roiBufferSize);
    scoresOutData->clear();
    roiOutData->clear();
    batchesOutData->clear();

    // Compute the roi region for each anchor.
    // Each anchor is shifted by (w * widthStride, h * heightStride); the result
    // is computed once and reused for every batch.
    float* roiBase = roiBuffer.data();
    for (uint32_t h = 0; h < height; h++) {
        float hShift = h * heightStride;
        for (uint32_t w = 0; w < width; w++) {
            const float* anchorsBase = anchorsData;
            float wShift = w * widthStride;
            for (uint32_t a = 0; a < numAnchors; a++, roiBase += kRoiDim, anchorsBase += kRoiDim) {
                roiBase[0] = anchorsBase[0] + wShift;
                roiBase[1] = anchorsBase[1] + hShift;
                roiBase[2] = anchorsBase[2] + wShift;
                roiBase[3] = anchorsBase[3] + hShift;
            }
        }
    }

    // Cursors into the per-batch slices of the inputs; advanced at the end of
    // each batch iteration.
    const float* scoresBase = scoresData;
    const float* bboxDeltasBase = bboxDeltasData;
    const float* imageInfoBase = imageInfoData;
    // Need to fake some data to satisfy bboxTransform.
    // Every candidate is presented as belonging to batch 0 of a one-image
    // input, so the transform is invoked once per real batch.
    Shape tempRoiShape = anchorsShape;
    tempRoiShape.dimensions = {batchSize, kRoiDim};
    Shape tempBBoxDeltasShape = bboxDeltasShape;
    tempBBoxDeltasShape.dimensions = {batchSize, kRoiDim};
    std::vector<int32_t> tempBatchSplitData(batchSize, 0);
    Shape tempbatchSplitShape = {.dimensions = {batchSize}};
    Shape tempImageInfoShape = imageInfoShape;
    tempImageInfoShape.dimensions = {1, imageInfoLength};

    for (uint32_t b = 0; b < numBatches; b++) {
        // Apply bboxDeltas to anchor locations.
        // Only the first two image info entries of this batch are forwarded.
        float tempImageInfo[] = {imageInfoBase[0], imageInfoBase[1]};
        if (!bboxTransformFloat32(roiBuffer.data(), tempRoiShape, bboxDeltasBase,
                                  tempBBoxDeltasShape, tempBatchSplitData.data(),
                                  tempbatchSplitShape, tempImageInfo, tempImageInfoShape,
                                  roiTransformedBuffer.data(), tempRoiShape)) {
            LOG(ERROR) << "BBoxTransform step failed in GENERATE_PROPOSALS op.";
            return false;
        }

        // Find the top preNmsTopN scores.
        // When preNmsTopN is non-positive or not smaller than the candidate
        // count, all candidates are kept in index order without sorting.
        std::vector<uint32_t> select(batchSize);
        std::iota(select.begin(), select.end(), 0);
        if (preNmsTopN > 0 && preNmsTopN < select.size()) {
            std::sort(select.begin(), select.end(),
                      [&scoresBase](const uint32_t lhs, const uint32_t rhs) {
                          return scoresBase[lhs] > scoresBase[rhs];
                      });
            select.resize(preNmsTopN);
        }

        // Filter boxes: discard regions whose height or width is not greater
        // than minSize, or whose center is not below the image info bounds.
        filterBoxes(roiTransformedBuffer.data(), imageInfoBase, minSize, &select);

        // Apply hard NMS.
        // The returned pointer is treated as the end of the kept prefix of
        // |select|, which is then truncated to that length.
        uint32_t* selectEnd = box_with_nms_limit::hardNmsSingleClass(
                scoresBase, iouThreshold, postNmsTopN,
                [&roiTransformedBuffer](uint32_t ind) {
                    return roiTransformedBuffer.data() + ind * kRoiDim;
                },
                select.data(), select.size());
        uint32_t selectSize = selectEnd - select.data();
        select.resize(selectSize);

        // Write output.
        for (auto i : select) {
            roiOutData->insert(roiOutData->end(), roiTransformedBuffer.begin() + i * kRoiDim,
                               roiTransformedBuffer.begin() + (i + 1) * kRoiDim);
            scoresOutData->push_back(scoresBase[i]);
            batchesOutData->push_back(b);
        }
        // Advance the input cursors to the next batch.
        scoresBase += batchSize;
        bboxDeltasBase += roiBufferSize;
        imageInfoBase += imageInfoLength;
    }
    return true;
}
944 
generateProposalsFloat32Compute(const float * scoresData,const Shape & scoresShape,const float * bboxDeltasData,const Shape & bboxDeltasShape,const float * anchorsData,const Shape & anchorsShape,const float * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,std::vector<float> * scoresOutData,std::vector<float> * roiOutData,std::vector<int32_t> * batchesOutData)945 bool generateProposalsFloat32Compute(const float* scoresData, const Shape& scoresShape,
946                                      const float* bboxDeltasData, const Shape& bboxDeltasShape,
947                                      const float* anchorsData, const Shape& anchorsShape,
948                                      const float* imageInfoData, const Shape& imageInfoShape,
949                                      float heightStride, float widthStride, int32_t preNmsTopN,
950                                      int32_t postNmsTopN, float iouThreshold, float minSize,
951                                      bool useNchw, std::vector<float>* scoresOutData,
952                                      std::vector<float>* roiOutData,
953                                      std::vector<int32_t>* batchesOutData) {
954     InputWithLayout<float> score_nhwc(useNchw), delta_nhwc(useNchw);
955     NN_RET_CHECK(score_nhwc.initialize(scoresData, scoresShape));
956     NN_RET_CHECK(delta_nhwc.initialize(bboxDeltasData, bboxDeltasShape));
957     return generateProposalsNhwcFloat32Compute(
958             score_nhwc.getNhwcBuffer(), score_nhwc.getNhwcShape(), delta_nhwc.getNhwcBuffer(),
959             delta_nhwc.getNhwcShape(), anchorsData, anchorsShape, imageInfoData, imageInfoShape,
960             heightStride, widthStride, preNmsTopN, postNmsTopN, iouThreshold, minSize,
961             scoresOutData, roiOutData, batchesOutData);
962 }
963 
generateProposalsFloat32(const float * scoresData,const Shape & scoresShape,const float * bboxDeltasData,const Shape & bboxDeltasShape,const float * anchorsData,const Shape & anchorsShape,const float * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,IOperationExecutionContext * context)964 bool generateProposalsFloat32(const float* scoresData, const Shape& scoresShape,
965                               const float* bboxDeltasData, const Shape& bboxDeltasShape,
966                               const float* anchorsData, const Shape& anchorsShape,
967                               const float* imageInfoData, const Shape& imageInfoShape,
968                               float heightStride, float widthStride, int32_t preNmsTopN,
969                               int32_t postNmsTopN, float iouThreshold, float minSize, bool useNchw,
970                               IOperationExecutionContext* context) {
971     std::vector<float> scoresOut_float32, roiOut_float32;
972     std::vector<int32_t> batchesOut;
973     NN_RET_CHECK(generateProposalsFloat32Compute(
974             scoresData, scoresShape, bboxDeltasData, bboxDeltasShape, anchorsData, anchorsShape,
975             imageInfoData, imageInfoShape, heightStride, widthStride, preNmsTopN, postNmsTopN,
976             iouThreshold, minSize, useNchw, &scoresOut_float32, &roiOut_float32, &batchesOut));
977 
978     // Set output dimensions.
979     uint32_t numOutRois = scoresOut_float32.size();
980     if (numOutRois == 0) return true;
981     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
982     scoresOutShape.dimensions = {numOutRois};
983     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
984     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
985     roiOutShape.dimensions = {numOutRois, 4};
986     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
987     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
988     batchesOutShape.dimensions = {numOutRois};
989     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
990 
991     // Write outputs.
992     float* scoresOutData = context->getOutputBuffer<float>(kOutputScoreTensor);
993     for (uint32_t i = 0; i < scoresOut_float32.size(); i++) {
994         scoresOutData[i] = scoresOut_float32[i];
995     }
996     float* roiOutData = context->getOutputBuffer<float>(kOutputRoiTensor);
997     for (uint32_t i = 0; i < roiOut_float32.size(); i++) {
998         roiOutData[i] = roiOut_float32[i];
999     }
1000     int32_t* batchesOutData = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
1001     for (uint32_t i = 0; i < batchesOut.size(); i++) {
1002         batchesOutData[i] = batchesOut[i];
1003     }
1004     return true;
1005 }
1006 
generateProposalsFloat16(const _Float16 * scoresData,const Shape & scoresShape,const _Float16 * bboxDeltasData,const Shape & bboxDeltasShape,const _Float16 * anchorsData,const Shape & anchorsShape,const _Float16 * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,IOperationExecutionContext * context)1007 bool generateProposalsFloat16(const _Float16* scoresData, const Shape& scoresShape,
1008                               const _Float16* bboxDeltasData, const Shape& bboxDeltasShape,
1009                               const _Float16* anchorsData, const Shape& anchorsShape,
1010                               const _Float16* imageInfoData, const Shape& imageInfoShape,
1011                               float heightStride, float widthStride, int32_t preNmsTopN,
1012                               int32_t postNmsTopN, float iouThreshold, float minSize, bool useNchw,
1013                               IOperationExecutionContext* context) {
1014     std::vector<float> score_float32(getNumberOfElements(scoresShape));
1015     convertFloat16ToFloat32(scoresData, &score_float32);
1016     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
1017     convertFloat16ToFloat32(bboxDeltasData, &delta_float32);
1018     std::vector<float> anchors_float32(getNumberOfElements(anchorsShape));
1019     convertFloat16ToFloat32(anchorsData, &anchors_float32);
1020     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoShape));
1021     convertFloat16ToFloat32(imageInfoData, &imageInfo_float32);
1022     std::vector<float> scoresOut_float32, roiOut_float32;
1023     std::vector<int32_t> batchesOut;
1024     NN_RET_CHECK(generateProposalsFloat32Compute(
1025             score_float32.data(), scoresShape, delta_float32.data(), bboxDeltasShape,
1026             anchors_float32.data(), anchorsShape, imageInfo_float32.data(), imageInfoShape,
1027             heightStride, widthStride, preNmsTopN, postNmsTopN, iouThreshold, minSize, useNchw,
1028             &scoresOut_float32, &roiOut_float32, &batchesOut));
1029 
1030     // Set output dimensions.
1031     uint32_t numOutRois = scoresOut_float32.size();
1032     if (numOutRois == 0) return true;
1033     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
1034     scoresOutShape.dimensions = {numOutRois};
1035     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
1036     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
1037     roiOutShape.dimensions = {numOutRois, 4};
1038     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
1039     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
1040     batchesOutShape.dimensions = {numOutRois};
1041     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
1042 
1043     // Write outputs.
1044     _Float16* scoresOutData = context->getOutputBuffer<_Float16>(kOutputScoreTensor);
1045     convertFloat32ToFloat16(scoresOut_float32, scoresOutData);
1046     _Float16* roiOutData = context->getOutputBuffer<_Float16>(kOutputRoiTensor);
1047     convertFloat32ToFloat16(roiOut_float32, roiOutData);
1048     int32_t* batchesOutData = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
1049     for (uint32_t i = 0; i < batchesOut.size(); i++) {
1050         batchesOutData[i] = batchesOut[i];
1051     }
1052     return true;
1053 }
1054 
generateProposalsQuant(const uint8_t * scoresData,const Shape & scoresShape,const uint8_t * bboxDeltasData,const Shape & bboxDeltasShape,const int16_t * anchorsData,const Shape & anchorsShape,const uint16_t * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,IOperationExecutionContext * context)1055 bool generateProposalsQuant(const uint8_t* scoresData, const Shape& scoresShape,
1056                             const uint8_t* bboxDeltasData, const Shape& bboxDeltasShape,
1057                             const int16_t* anchorsData, const Shape& anchorsShape,
1058                             const uint16_t* imageInfoData, const Shape& imageInfoShape,
1059                             float heightStride, float widthStride, int32_t preNmsTopN,
1060                             int32_t postNmsTopN, float iouThreshold, float minSize, bool useNchw,
1061                             IOperationExecutionContext* context) {
1062     std::vector<float> score_float32(getNumberOfElements(scoresShape));
1063     convertQuantToFloat32(scoresData, scoresShape.scale, scoresShape.offset, &score_float32);
1064     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
1065     convertQuantToFloat32(bboxDeltasData, bboxDeltasShape.scale, bboxDeltasShape.offset,
1066                           &delta_float32);
1067     std::vector<float> anchors_float32(getNumberOfElements(anchorsShape));
1068     convertQuantToFloat32(anchorsData, anchorsShape.scale, anchorsShape.offset, &anchors_float32);
1069     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoShape));
1070     convertQuantToFloat32(imageInfoData, imageInfoShape.scale, imageInfoShape.offset,
1071                           &imageInfo_float32);
1072     std::vector<float> scoresOut_float32, roiOut_float32;
1073     std::vector<int32_t> batchesOut;
1074     NN_RET_CHECK(generateProposalsFloat32Compute(
1075             score_float32.data(), scoresShape, delta_float32.data(), bboxDeltasShape,
1076             anchors_float32.data(), anchorsShape, imageInfo_float32.data(), imageInfoShape,
1077             heightStride, widthStride, preNmsTopN, postNmsTopN, iouThreshold, minSize, useNchw,
1078             &scoresOut_float32, &roiOut_float32, &batchesOut));
1079 
1080     // Set output dimensions.
1081     uint32_t numOutRois = scoresOut_float32.size();
1082     if (numOutRois == 0) return true;
1083     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
1084     scoresOutShape.dimensions = {numOutRois};
1085     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
1086     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
1087     roiOutShape.dimensions = {numOutRois, 4};
1088     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
1089     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
1090     batchesOutShape.dimensions = {numOutRois};
1091     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
1092 
1093     // Write outputs.
1094     uint8_t* scoresOutData = context->getOutputBuffer<uint8_t>(kOutputScoreTensor);
1095     convertFloat32ToQuant(scoresOut_float32, scoresOutShape.scale, scoresOutShape.offset,
1096                           scoresOutData);
1097     uint16_t* roiOutData = context->getOutputBuffer<uint16_t>(kOutputRoiTensor);
1098     convertFloat32ToQuant(roiOut_float32, roiOutShape.scale, roiOutShape.offset, roiOutData);
1099     int32_t* batchesOutData = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
1100     for (uint32_t i = 0; i < batchesOut.size(); i++) {
1101         batchesOutData[i] = batchesOut[i];
1102     }
1103     return true;
1104 }
1105 
1106 }  // namespace
1107 
validate(const IOperationValidationContext * context)1108 bool validate(const IOperationValidationContext* context) {
1109     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
1110     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
1111     std::vector<OperandType> inExpectedTypes;
1112     std::vector<OperandType> outExpectedTypes;
1113     auto inputType = context->getInputType(kScoreTensor);
1114     if (inputType == OperandType::TENSOR_FLOAT16) {
1115         inExpectedTypes = {OperandType::TENSOR_FLOAT16,
1116                            OperandType::TENSOR_FLOAT16,
1117                            OperandType::TENSOR_FLOAT16,
1118                            OperandType::TENSOR_FLOAT16,
1119                            OperandType::FLOAT16,
1120                            OperandType::FLOAT16,
1121                            OperandType::INT32,
1122                            OperandType::INT32,
1123                            OperandType::FLOAT16,
1124                            OperandType::FLOAT16,
1125                            OperandType::BOOL};
1126         outExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
1127                             OperandType::TENSOR_INT32};
1128     } else if (inputType == OperandType::TENSOR_FLOAT32) {
1129         inExpectedTypes = {OperandType::TENSOR_FLOAT32,
1130                            OperandType::TENSOR_FLOAT32,
1131                            OperandType::TENSOR_FLOAT32,
1132                            OperandType::TENSOR_FLOAT32,
1133                            OperandType::FLOAT32,
1134                            OperandType::FLOAT32,
1135                            OperandType::INT32,
1136                            OperandType::INT32,
1137                            OperandType::FLOAT32,
1138                            OperandType::FLOAT32,
1139                            OperandType::BOOL};
1140         outExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
1141                             OperandType::TENSOR_INT32};
1142     } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1143         inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
1144                            OperandType::TENSOR_QUANT8_ASYMM,
1145                            OperandType::TENSOR_QUANT16_SYMM,
1146                            OperandType::TENSOR_QUANT16_ASYMM,
1147                            OperandType::FLOAT32,
1148                            OperandType::FLOAT32,
1149                            OperandType::INT32,
1150                            OperandType::INT32,
1151                            OperandType::FLOAT32,
1152                            OperandType::FLOAT32,
1153                            OperandType::BOOL};
1154         outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT16_ASYMM,
1155                             OperandType::TENSOR_INT32};
1156     } else {
1157         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
1158     }
1159     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
1160     NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
1161     return validateHalVersion(context, HalVersion::V1_2);
1162 }
1163 
prepare(IOperationExecutionContext * context)1164 bool prepare(IOperationExecutionContext* context) {
1165     bool useNchw = context->getInputValue<bool>(kLayoutScalar);
1166     Shape scoreShape = context->getInputShape(kScoreTensor);
1167     Shape bboxDeltasShape = context->getInputShape(kDeltaTensor);
1168     Shape anchorsShape = context->getInputShape(kAnchorTensor);
1169     Shape imageInfoDataShape = context->getInputShape(kImageInfoTensor);
1170     Shape outputScoreShape = context->getOutputShape(kOutputScoreTensor);
1171     Shape outputRoiShape = context->getOutputShape(kOutputRoiTensor);
1172     Shape outputBatchSplitShape = context->getOutputShape(kOutputBatchesTensor);
1173 
1174     NN_RET_CHECK_EQ(getNumberOfDimensions(scoreShape), 4);
1175     NN_RET_CHECK_EQ(getNumberOfDimensions(bboxDeltasShape), 4);
1176     NN_RET_CHECK_EQ(getNumberOfDimensions(anchorsShape), 2);
1177     NN_RET_CHECK_EQ(getNumberOfDimensions(imageInfoDataShape), 2);
1178 
1179     const uint32_t kRoiDim = 4;
1180     uint32_t numBatches = getSizeOfDimension(scoreShape, 0);
1181     uint32_t height = getSizeOfDimension(scoreShape, useNchw ? 2 : 1);
1182     uint32_t width = getSizeOfDimension(scoreShape, useNchw ? 3 : 2);
1183     uint32_t numAnchors = getSizeOfDimension(scoreShape, useNchw ? 1 : 3);
1184 
1185     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, 0), numBatches);
1186     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, useNchw ? 2 : 1), height);
1187     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, useNchw ? 3 : 2), width);
1188     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, useNchw ? 1 : 3), numAnchors * kRoiDim);
1189     NN_RET_CHECK_EQ(getSizeOfDimension(imageInfoDataShape, 0), numBatches);
1190     NN_RET_CHECK_EQ(getSizeOfDimension(imageInfoDataShape, 1), 2);
1191     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 0), numAnchors);
1192     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 1), kRoiDim);
1193 
1194     if (scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
1195         NN_RET_CHECK_EQ(anchorsShape.scale, 0.125f);
1196         NN_RET_CHECK_EQ(imageInfoDataShape.scale, 0.125f);
1197         NN_RET_CHECK_EQ(imageInfoDataShape.offset, 0);
1198     }
1199 
1200     outputScoreShape.type = scoreShape.type;
1201     outputScoreShape.dimensions = {0};
1202     outputScoreShape.scale = scoreShape.scale;
1203     outputScoreShape.offset = scoreShape.offset;
1204     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, outputScoreShape));
1205 
1206     outputRoiShape.dimensions = {0, 4};
1207     if (scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
1208         outputRoiShape.scale = 0.125f;
1209         outputRoiShape.offset = 0;
1210     }
1211     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, outputRoiShape));
1212 
1213     outputBatchSplitShape.dimensions = {0};
1214     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, outputBatchSplitShape));
1215     return true;
1216 }
1217 
execute(IOperationExecutionContext * context)1218 bool execute(IOperationExecutionContext* context) {
1219     NNTRACE_TRANS("generateProposals");
1220     switch (context->getInputType(kScoreTensor)) {
1221         case OperandType::TENSOR_FLOAT16: {
1222             return generateProposalsFloat16(context->getInputBuffer<_Float16>(kScoreTensor),
1223                                             context->getInputShape(kScoreTensor),
1224                                             context->getInputBuffer<_Float16>(kDeltaTensor),
1225                                             context->getInputShape(kDeltaTensor),
1226                                             context->getInputBuffer<_Float16>(kAnchorTensor),
1227                                             context->getInputShape(kAnchorTensor),
1228                                             context->getInputBuffer<_Float16>(kImageInfoTensor),
1229                                             context->getInputShape(kImageInfoTensor),
1230                                             context->getInputValue<_Float16>(kHeightStrideSalar),
1231                                             context->getInputValue<_Float16>(kWidthStrideScalar),
1232                                             context->getInputValue<int32_t>(kPreNmsMaxScalar),
1233                                             context->getInputValue<int32_t>(kPostNmsMaxScalar),
1234                                             context->getInputValue<_Float16>(kIoUThresholdScalar),
1235                                             context->getInputValue<_Float16>(kMinSizeScalar),
1236                                             context->getInputValue<bool>(kLayoutScalar), context);
1237         }
1238         case OperandType::TENSOR_FLOAT32: {
1239             return generateProposalsFloat32(context->getInputBuffer<float>(kScoreTensor),
1240                                             context->getInputShape(kScoreTensor),
1241                                             context->getInputBuffer<float>(kDeltaTensor),
1242                                             context->getInputShape(kDeltaTensor),
1243                                             context->getInputBuffer<float>(kAnchorTensor),
1244                                             context->getInputShape(kAnchorTensor),
1245                                             context->getInputBuffer<float>(kImageInfoTensor),
1246                                             context->getInputShape(kImageInfoTensor),
1247                                             context->getInputValue<float>(kHeightStrideSalar),
1248                                             context->getInputValue<float>(kWidthStrideScalar),
1249                                             context->getInputValue<int32_t>(kPreNmsMaxScalar),
1250                                             context->getInputValue<int32_t>(kPostNmsMaxScalar),
1251                                             context->getInputValue<float>(kIoUThresholdScalar),
1252                                             context->getInputValue<float>(kMinSizeScalar),
1253                                             context->getInputValue<bool>(kLayoutScalar), context);
1254         }
1255         case OperandType::TENSOR_QUANT8_ASYMM: {
1256             return generateProposalsQuant(context->getInputBuffer<uint8_t>(kScoreTensor),
1257                                           context->getInputShape(kScoreTensor),
1258                                           context->getInputBuffer<uint8_t>(kDeltaTensor),
1259                                           context->getInputShape(kDeltaTensor),
1260                                           context->getInputBuffer<int16_t>(kAnchorTensor),
1261                                           context->getInputShape(kAnchorTensor),
1262                                           context->getInputBuffer<uint16_t>(kImageInfoTensor),
1263                                           context->getInputShape(kImageInfoTensor),
1264                                           context->getInputValue<float>(kHeightStrideSalar),
1265                                           context->getInputValue<float>(kWidthStrideScalar),
1266                                           context->getInputValue<int32_t>(kPreNmsMaxScalar),
1267                                           context->getInputValue<int32_t>(kPostNmsMaxScalar),
1268                                           context->getInputValue<float>(kIoUThresholdScalar),
1269                                           context->getInputValue<float>(kMinSizeScalar),
1270                                           context->getInputValue<bool>(kLayoutScalar), context);
1271         }
1272         default:
1273             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
1274     }
1275 }
1276 
1277 }  // namespace generate_proposals
1278 
1279 namespace detection_postprocess {
1280 
// Operation name streamed into this operation's error messages.
constexpr char kOperationName[] = "DETECTION_POSTPROCESS";

// Input operand indices. validate() requires exactly kNumInputs inputs; the scalar types noted
// below are the ones validate() expects.
constexpr uint32_t kNumInputs = 14;
constexpr uint32_t kScoreTensor = 0;   // tensor; its type selects the float16/float32 variant
constexpr uint32_t kDeltaTensor = 1;   // tensor, same type as the scores
constexpr uint32_t kAnchorTensor = 2;  // tensor, same type as the scores
constexpr uint32_t kScaleYScalar = 3;  // float scalar, must be > 0
constexpr uint32_t kScaleXScalar = 4;  // float scalar, must be > 0
constexpr uint32_t kScaleHScalar = 5;  // float scalar, must be > 0
constexpr uint32_t kScaleWScalar = 6;  // float scalar, must be > 0
constexpr uint32_t kUseRegularNmsScalar = 7;             // BOOL
constexpr uint32_t kMaxNumDetectionScalar = 8;           // INT32
constexpr uint32_t kMaxClassesPerDetectionScalar = 9;    // INT32
constexpr uint32_t kMaxNumDetectionPerClassScalar = 10;  // INT32
constexpr uint32_t kScoreThresholdScalar = 11;           // float scalar, must be >= 0
constexpr uint32_t kIoUThresholdScalar = 12;             // float scalar, must be >= 0
constexpr uint32_t kIsBGInLabelScalar = 13;              // BOOL

// Output operand indices. validate() requires exactly kNumOutputs outputs.
constexpr uint32_t kNumOutputs = 4;
constexpr uint32_t kOutputScoreTensor = 0;      // same tensor type as the input scores
constexpr uint32_t kOutputRoiTensor = 1;        // same tensor type as the input scores
constexpr uint32_t kOutputClassTensor = 2;      // TENSOR_INT32
constexpr uint32_t kOutputDetectionTensor = 3;  // TENSOR_INT32
1304 
1305 namespace {
1306 
detectionPostprocessFloat32(const float * scoreData,const Shape & scoreShape,const float * deltaData,const Shape & deltaShape,const float * anchorData,const Shape & anchorShape,float scaleY,float scaleX,float scaleH,float scaleW,bool useRegularNms,int32_t maxNumDetections,int32_t maxClassesPerDetection,int32_t maxNumDetectionsPerClass,float iouThreshold,float scoreThreshold,bool isBGInLabel,float * scoreOutData,const Shape & scoreOutShape,float * roiOutData,const Shape & roiOutShape,int32_t * classOutData,const Shape & classOutShape,int32_t * detectionOutData,const Shape & detectionOutShape)1307 bool detectionPostprocessFloat32(
1308         const float* scoreData, const Shape& scoreShape, const float* deltaData,
1309         const Shape& deltaShape, const float* anchorData, const Shape& anchorShape, float scaleY,
1310         float scaleX, float scaleH, float scaleW, bool useRegularNms, int32_t maxNumDetections,
1311         int32_t maxClassesPerDetection, int32_t maxNumDetectionsPerClass, float iouThreshold,
1312         float scoreThreshold, bool isBGInLabel, float* scoreOutData, const Shape& scoreOutShape,
1313         float* roiOutData, const Shape& roiOutShape, int32_t* classOutData,
1314         const Shape& classOutShape, int32_t* detectionOutData, const Shape& detectionOutShape) {
1315     const uint32_t kRoiDim = 4;
1316     uint32_t numBatches = getSizeOfDimension(scoreShape, 0);
1317     uint32_t numAnchors = getSizeOfDimension(scoreShape, 1);
1318     uint32_t numClasses = getSizeOfDimension(scoreShape, 2);
1319     uint32_t lengthBoxEncoding = getSizeOfDimension(deltaShape, 2);
1320     uint32_t numOutDetection = getSizeOfDimension(scoreOutShape, 1);
1321 
1322     memset(scoreOutData, 0, getNumberOfElements(scoreOutShape) * sizeof(float));
1323     memset(roiOutData, 0, getNumberOfElements(roiOutShape) * sizeof(float));
1324     memset(classOutData, 0, getNumberOfElements(classOutShape) * sizeof(int32_t));
1325     memset(detectionOutData, 0, getNumberOfElements(detectionOutShape) * sizeof(int32_t));
1326 
1327     const float* scoreBase = scoreData;
1328     const float* deltaBase = deltaData;
1329     float* scoreOutBase = scoreOutData;
1330     float* roiOutBase = roiOutData;
1331     int32_t* classOutBase = classOutData;
1332     std::vector<float> roiBuffer(numAnchors * kRoiDim);
1333     std::vector<float> scoreBuffer(numAnchors);
1334     for (uint32_t b = 0; b < numBatches; b++) {
1335         const float* anchorBase = anchorData;
1336         for (uint32_t a = 0; a < numAnchors; a++) {
1337             float yCtr = anchorBase[0] + anchorBase[2] * deltaBase[0] / scaleY;
1338             float xCtr = anchorBase[1] + anchorBase[3] * deltaBase[1] / scaleX;
1339             float hHalf = anchorBase[2] * std::exp(deltaBase[2] / scaleH) * 0.5f;
1340             float wHalf = anchorBase[3] * std::exp(deltaBase[3] / scaleW) * 0.5f;
1341             roiBuffer[a * kRoiDim] = yCtr - hHalf;
1342             roiBuffer[a * kRoiDim + 1] = xCtr - wHalf;
1343             roiBuffer[a * kRoiDim + 2] = yCtr + hHalf;
1344             roiBuffer[a * kRoiDim + 3] = xCtr + wHalf;
1345             anchorBase += kRoiDim;
1346             deltaBase += lengthBoxEncoding;
1347         }
1348 
1349         if (useRegularNms) {
1350             std::vector<uint32_t> select;
1351             box_with_nms_limit::hardNmsMultiClass(
1352                     scoreBase, numClasses, numAnchors, scoreThreshold, iouThreshold,
1353                     maxNumDetections, maxNumDetectionsPerClass,
1354                     [&roiBuffer, numClasses](uint32_t ind) {
1355                         return roiBuffer.data() + (ind / numClasses) * kRoiDim;
1356                     },
1357                     &select);
1358             for (uint32_t i = 0; i < select.size(); i++) {
1359                 uint32_t ind = select[i];
1360                 scoreOutBase[i] = scoreBase[ind];
1361                 memcpy(roiOutBase + i * kRoiDim, &roiBuffer[(ind / numClasses) * kRoiDim],
1362                        kRoiDim * sizeof(float));
1363                 classOutBase[i] = (ind % numClasses) - (isBGInLabel ? 0 : 1);
1364             }
1365             *detectionOutData++ = select.size();
1366         } else {
1367             uint32_t numOutClasses = std::min<uint32_t>(numClasses - 1, maxClassesPerDetection);
1368             std::vector<float> maxScores(numAnchors);
1369             for (uint32_t a = 0; a < numAnchors; a++) {
1370                 maxScores[a] = *std::max_element(scoreBase + a * numClasses + 1,
1371                                                  scoreBase + (a + 1) * numClasses);
1372             }
1373             std::vector<uint32_t> select;
1374             for (uint32_t a = 0; a < numAnchors; a++) {
1375                 if (maxScores[a] > scoreThreshold) {
1376                     select.push_back(a);
1377                 }
1378             }
1379             uint32_t* selectEnd = box_with_nms_limit::hardNmsSingleClass(
1380                     maxScores.data(), iouThreshold, maxNumDetections,
1381                     [&roiBuffer](uint32_t ind) { return roiBuffer.data() + ind * kRoiDim; },
1382                     select.data(), select.size());
1383             select.resize(selectEnd - select.data());
1384             float* scoreOutPtr = scoreOutBase;
1385             float* roiOutPtr = roiOutBase;
1386             int32_t* classOutPtr = classOutBase;
1387             for (auto i : select) {
1388                 const float* score = scoreBase + i * numClasses;
1389                 std::vector<uint32_t> scoreInds(numClasses - 1);
1390                 std::iota(scoreInds.begin(), scoreInds.end(), 1);
1391                 std::sort(scoreInds.begin(), scoreInds.end(),
1392                           [&score](const uint32_t lhs, const uint32_t rhs) {
1393                               return score[lhs] > score[rhs];
1394                           });
1395                 for (uint32_t c = 0; c < numOutClasses; c++) {
1396                     *scoreOutPtr++ = score[scoreInds[c]];
1397                     memcpy(roiOutPtr, &roiBuffer[i * kRoiDim], kRoiDim * sizeof(float));
1398                     roiOutPtr += kRoiDim;
1399                     *classOutPtr++ = scoreInds[c] - (isBGInLabel ? 0 : 1);
1400                 }
1401             }
1402             *detectionOutData++ = select.size() * numOutClasses;
1403         }
1404         scoreBase += numAnchors * numClasses;
1405         scoreOutBase += numOutDetection;
1406         roiOutBase += numOutDetection * kRoiDim;
1407         classOutBase += numOutDetection;
1408     }
1409     return true;
1410 }
1411 
detectionPostprocessFloat16(const _Float16 * scoreData,const Shape & scoreShape,const _Float16 * deltaData,const Shape & deltaShape,const _Float16 * anchorData,const Shape & anchorShape,float scaleY,float scaleX,float scaleH,float scaleW,bool useRegularNms,int32_t maxNumDetections,int32_t maxClassesPerDetection,int32_t maxNumDetectionsPerClass,float iouThreshold,float scoreThreshold,bool isBGInLabel,_Float16 * scoreOutData,const Shape & scoreOutShape,_Float16 * roiOutData,const Shape & roiOutShape,int32_t * classOutData,const Shape & classOutShape,int32_t * detectionOutData,const Shape & detectionOutShape)1412 bool detectionPostprocessFloat16(
1413         const _Float16* scoreData, const Shape& scoreShape, const _Float16* deltaData,
1414         const Shape& deltaShape, const _Float16* anchorData, const Shape& anchorShape, float scaleY,
1415         float scaleX, float scaleH, float scaleW, bool useRegularNms, int32_t maxNumDetections,
1416         int32_t maxClassesPerDetection, int32_t maxNumDetectionsPerClass, float iouThreshold,
1417         float scoreThreshold, bool isBGInLabel, _Float16* scoreOutData, const Shape& scoreOutShape,
1418         _Float16* roiOutData, const Shape& roiOutShape, int32_t* classOutData,
1419         const Shape& classOutShape, int32_t* detectionOutData, const Shape& detectionOutShape) {
1420     std::vector<float> scores_float32(getNumberOfElements(scoreShape));
1421     convertFloat16ToFloat32(scoreData, &scores_float32);
1422     std::vector<float> delta_float32(getNumberOfElements(deltaShape));
1423     convertFloat16ToFloat32(deltaData, &delta_float32);
1424     std::vector<float> anchor_float32(getNumberOfElements(anchorShape));
1425     convertFloat16ToFloat32(anchorData, &anchor_float32);
1426     std::vector<float> outputScore_float32(getNumberOfElements(scoreOutShape));
1427     std::vector<float> outputRoi_float32(getNumberOfElements(roiOutShape));
1428     NN_RET_CHECK(detectionPostprocessFloat32(
1429             scores_float32.data(), scoreShape, delta_float32.data(), deltaShape,
1430             anchor_float32.data(), anchorShape, scaleY, scaleX, scaleH, scaleW, useRegularNms,
1431             maxNumDetections, maxClassesPerDetection, maxNumDetectionsPerClass, iouThreshold,
1432             scoreThreshold, isBGInLabel, outputScore_float32.data(), scoreOutShape,
1433             outputRoi_float32.data(), roiOutShape, classOutData, classOutShape, detectionOutData,
1434             detectionOutShape));
1435     convertFloat32ToFloat16(outputScore_float32, scoreOutData);
1436     convertFloat32ToFloat16(outputRoi_float32, roiOutData);
1437     return true;
1438 }
1439 
1440 }  // namespace
1441 
validate(const IOperationValidationContext * context)1442 bool validate(const IOperationValidationContext* context) {
1443     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
1444     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
1445     std::vector<OperandType> inExpectedTypes;
1446     std::vector<OperandType> outExpectedTypes;
1447     auto inputType = context->getInputType(kScoreTensor);
1448     if (inputType == OperandType::TENSOR_FLOAT16) {
1449         inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
1450                            OperandType::TENSOR_FLOAT16, OperandType::FLOAT16,
1451                            OperandType::FLOAT16,        OperandType::FLOAT16,
1452                            OperandType::FLOAT16,        OperandType::BOOL,
1453                            OperandType::INT32,          OperandType::INT32,
1454                            OperandType::INT32,          OperandType::FLOAT16,
1455                            OperandType::FLOAT16,        OperandType::BOOL};
1456     } else if (inputType == OperandType::TENSOR_FLOAT32) {
1457         inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
1458                            OperandType::TENSOR_FLOAT32, OperandType::FLOAT32,
1459                            OperandType::FLOAT32,        OperandType::FLOAT32,
1460                            OperandType::FLOAT32,        OperandType::BOOL,
1461                            OperandType::INT32,          OperandType::INT32,
1462                            OperandType::INT32,          OperandType::FLOAT32,
1463                            OperandType::FLOAT32,        OperandType::BOOL};
1464     } else {
1465         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
1466     }
1467     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
1468     NN_RET_CHECK(validateOutputTypes(
1469             context, {inputType, inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32}));
1470     return validateHalVersion(context, HalVersion::V1_2);
1471 }
1472 
prepare(IOperationExecutionContext * context)1473 bool prepare(IOperationExecutionContext* context) {
1474     Shape scoreShape = context->getInputShape(kScoreTensor);
1475     Shape deltasShape = context->getInputShape(kDeltaTensor);
1476     Shape anchorsShape = context->getInputShape(kAnchorTensor);
1477     Shape outputScoreShape = context->getOutputShape(kOutputScoreTensor);
1478     Shape outputRoiShape = context->getOutputShape(kOutputRoiTensor);
1479     Shape outputClassShape = context->getOutputShape(kOutputClassTensor);
1480     Shape outputDetectionShape = context->getOutputShape(kOutputDetectionTensor);
1481 
1482     NN_RET_CHECK_EQ(getNumberOfDimensions(scoreShape), 3);
1483     NN_RET_CHECK_EQ(getNumberOfDimensions(deltasShape), 3);
1484     NN_RET_CHECK_EQ(getNumberOfDimensions(anchorsShape), 2);
1485 
1486     const uint32_t kRoiDim = 4;
1487     uint32_t numBatches = getSizeOfDimension(scoreShape, 0);
1488     uint32_t numAnchors = getSizeOfDimension(scoreShape, 1);
1489     uint32_t numClasses = getSizeOfDimension(scoreShape, 2);
1490     uint32_t lengthBoxEncoding = getSizeOfDimension(deltasShape, 2);
1491     uint32_t maxNumDetections = context->getInputValue<int32_t>(kMaxNumDetectionScalar);
1492     uint32_t maxClassesPerDetection =
1493             context->getInputValue<int32_t>(kMaxClassesPerDetectionScalar);
1494     uint32_t numOutDetections = maxNumDetections;
1495 
1496     NN_RET_CHECK_EQ(getSizeOfDimension(deltasShape, 0), numBatches);
1497     NN_RET_CHECK_EQ(getSizeOfDimension(deltasShape, 1), numAnchors);
1498     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 0), numAnchors);
1499     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 1), kRoiDim);
1500 
1501     if (scoreShape.type == OperandType::TENSOR_FLOAT32) {
1502         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleYScalar), 0);
1503         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleXScalar), 0);
1504         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleHScalar), 0);
1505         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleWScalar), 0);
1506         NN_RET_CHECK_GE(context->getInputValue<float>(kScoreThresholdScalar), 0);
1507         NN_RET_CHECK_GE(context->getInputValue<float>(kIoUThresholdScalar), 0);
1508     } else if (scoreShape.type == OperandType::TENSOR_FLOAT16) {
1509         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleYScalar) > 0);
1510         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleXScalar) > 0);
1511         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleHScalar) > 0);
1512         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleWScalar) > 0);
1513         NN_RET_CHECK(context->getInputValue<_Float16>(kScoreThresholdScalar) >= 0);
1514         NN_RET_CHECK(context->getInputValue<_Float16>(kIoUThresholdScalar) >= 0);
1515     } else {
1516         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
1517     }
1518     NN_RET_CHECK_GT(numClasses, 1);
1519     NN_RET_CHECK_GE(lengthBoxEncoding, 4);
1520     NN_RET_CHECK_GT(maxNumDetections, 0);
1521     if (context->getInputValue<bool>(kUseRegularNmsScalar)) {
1522         NN_RET_CHECK_GT(context->getInputValue<int32_t>(kMaxNumDetectionPerClassScalar), 0);
1523     } else {
1524         NN_RET_CHECK_GT(maxClassesPerDetection, 0);
1525         numOutDetections *= maxClassesPerDetection;
1526     }
1527 
1528     outputScoreShape.type = scoreShape.type;
1529     outputScoreShape.dimensions = {numBatches, numOutDetections};
1530     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, outputScoreShape));
1531 
1532     outputRoiShape.type = anchorsShape.type;
1533     outputRoiShape.dimensions = {numBatches, numOutDetections, 4};
1534     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, outputRoiShape));
1535 
1536     outputClassShape.type = OperandType::TENSOR_INT32;
1537     outputClassShape.dimensions = {numBatches, numOutDetections};
1538     NN_RET_CHECK(context->setOutputShape(kOutputClassTensor, outputClassShape));
1539 
1540     outputDetectionShape.type = OperandType::TENSOR_INT32;
1541     outputDetectionShape.dimensions = {numBatches};
1542     NN_RET_CHECK(context->setOutputShape(kOutputDetectionTensor, outputDetectionShape));
1543     return true;
1544 }
1545 
execute(IOperationExecutionContext * context)1546 bool execute(IOperationExecutionContext* context) {
1547     NNTRACE_TRANS("detectionPostProcess");
1548     switch (context->getInputType(kScoreTensor)) {
1549         case OperandType::TENSOR_FLOAT16: {
1550             return detectionPostprocessFloat16(
1551                     context->getInputBuffer<_Float16>(kScoreTensor),
1552                     context->getInputShape(kScoreTensor),
1553                     context->getInputBuffer<_Float16>(kDeltaTensor),
1554                     context->getInputShape(kDeltaTensor),
1555                     context->getInputBuffer<_Float16>(kAnchorTensor),
1556                     context->getInputShape(kAnchorTensor),
1557                     context->getInputValue<_Float16>(kScaleYScalar),
1558                     context->getInputValue<_Float16>(kScaleXScalar),
1559                     context->getInputValue<_Float16>(kScaleHScalar),
1560                     context->getInputValue<_Float16>(kScaleWScalar),
1561                     context->getInputValue<bool>(kUseRegularNmsScalar),
1562                     context->getInputValue<int32_t>(kMaxNumDetectionScalar),
1563                     context->getInputValue<int32_t>(kMaxClassesPerDetectionScalar),
1564                     context->getInputValue<int32_t>(kMaxNumDetectionPerClassScalar),
1565                     context->getInputValue<_Float16>(kIoUThresholdScalar),
1566                     context->getInputValue<_Float16>(kScoreThresholdScalar),
1567                     context->getInputValue<bool>(kIsBGInLabelScalar),
1568                     context->getOutputBuffer<_Float16>(kOutputScoreTensor),
1569                     context->getOutputShape(kOutputScoreTensor),
1570                     context->getOutputBuffer<_Float16>(kOutputRoiTensor),
1571                     context->getOutputShape(kOutputRoiTensor),
1572                     context->getOutputBuffer<int32_t>(kOutputClassTensor),
1573                     context->getOutputShape(kOutputClassTensor),
1574                     context->getOutputBuffer<int32_t>(kOutputDetectionTensor),
1575                     context->getOutputShape(kOutputDetectionTensor));
1576         }
1577         case OperandType::TENSOR_FLOAT32: {
1578             return detectionPostprocessFloat32(
1579                     context->getInputBuffer<float>(kScoreTensor),
1580                     context->getInputShape(kScoreTensor),
1581                     context->getInputBuffer<float>(kDeltaTensor),
1582                     context->getInputShape(kDeltaTensor),
1583                     context->getInputBuffer<float>(kAnchorTensor),
1584                     context->getInputShape(kAnchorTensor),
1585                     context->getInputValue<float>(kScaleYScalar),
1586                     context->getInputValue<float>(kScaleXScalar),
1587                     context->getInputValue<float>(kScaleHScalar),
1588                     context->getInputValue<float>(kScaleWScalar),
1589                     context->getInputValue<bool>(kUseRegularNmsScalar),
1590                     context->getInputValue<int32_t>(kMaxNumDetectionScalar),
1591                     context->getInputValue<int32_t>(kMaxClassesPerDetectionScalar),
1592                     context->getInputValue<int32_t>(kMaxNumDetectionPerClassScalar),
1593                     context->getInputValue<float>(kIoUThresholdScalar),
1594                     context->getInputValue<float>(kScoreThresholdScalar),
1595                     context->getInputValue<bool>(kIsBGInLabelScalar),
1596                     context->getOutputBuffer<float>(kOutputScoreTensor),
1597                     context->getOutputShape(kOutputScoreTensor),
1598                     context->getOutputBuffer<float>(kOutputRoiTensor),
1599                     context->getOutputShape(kOutputRoiTensor),
1600                     context->getOutputBuffer<int32_t>(kOutputClassTensor),
1601                     context->getOutputShape(kOutputClassTensor),
1602                     context->getOutputBuffer<int32_t>(kOutputDetectionTensor),
1603                     context->getOutputShape(kOutputDetectionTensor));
1604         }
1605         default:
1606             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
1607     }
1608 }
1609 
1610 }  // namespace detection_postprocess
1611 
1612 }  // namespace bbox_ops
1613 
1614 NN_REGISTER_OPERATION(AXIS_ALIGNED_BBOX_TRANSFORM,
1615                       bbox_ops::axis_aligned_bbox_transform::kOperationName,
1616                       bbox_ops::axis_aligned_bbox_transform::validate,
1617                       bbox_ops::axis_aligned_bbox_transform::prepare,
1618                       bbox_ops::axis_aligned_bbox_transform::execute, .allowZeroSizedInput = true);
1619 
1620 NN_REGISTER_OPERATION(BOX_WITH_NMS_LIMIT, bbox_ops::box_with_nms_limit::kOperationName,
1621                       bbox_ops::box_with_nms_limit::validate, bbox_ops::box_with_nms_limit::prepare,
1622                       bbox_ops::box_with_nms_limit::execute, .allowZeroSizedInput = true);
1623 
1624 NN_REGISTER_OPERATION(GENERATE_PROPOSALS, bbox_ops::generate_proposals::kOperationName,
1625                       bbox_ops::generate_proposals::validate, bbox_ops::generate_proposals::prepare,
1626                       bbox_ops::generate_proposals::execute);
1627 
1628 NN_REGISTER_OPERATION(DETECTION_POSTPROCESSING, bbox_ops::detection_postprocess::kOperationName,
1629                       bbox_ops::detection_postprocess::validate,
1630                       bbox_ops::detection_postprocess::prepare,
1631                       bbox_ops::detection_postprocess::execute);
1632 }  // namespace nn
1633 }  // namespace android
1634