/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"

namespace android {
namespace nn {
namespace roi_pooling {

constexpr char kOperationName[] = "ROI_POOLING";

constexpr uint32_t kNumInputs = 8;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kRoiTensor = 1;
constexpr uint32_t kBatchSplitTensor = 2;
constexpr uint32_t kOutputHeightScalar = 3;
constexpr uint32_t kOutputWidthScalar = 4;
constexpr uint32_t kHeightStrideScalar = 5;
constexpr uint32_t kWidthStrideScalar = 6;
constexpr uint32_t kLayoutScalar = 7;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

using namespace hal;

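// Max-pools one region of interest (ROI) at a time from an NHWC input feature map into a
// fixed-size outHeight x outWidth output. Each ROI is given as [x1, y1, x2, y2]; coordinates are
// mapped onto the feature map by dividing by widthStride/heightStride, and every output cell
// takes the per-channel maximum over its corresponding input bin.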
template <typename T_Input, typename T_Roi>
inline bool roiPoolingNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                           const Shape& roiShape, const int32_t* batchSplitData,
                           const Shape& batchSplitShape, float heightStride, float widthStride,
                           T_Input* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("RoiPooling");

    const uint32_t kRoiDim = 4;
    const T_Roi heightScale = 1.0f / heightStride;
    const T_Roi widthScale = 1.0f / widthStride;

    uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    uint32_t inHeight = getSizeOfDimension(inputShape, 1);
    uint32_t inWidth = getSizeOfDimension(inputShape, 2);
    uint32_t inDepth = getSizeOfDimension(inputShape, 3);
    uint32_t outHeight = getSizeOfDimension(outputShape, 1);
    uint32_t outWidth = getSizeOfDimension(outputShape, 2);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);

    T_Input* outPtr = outputData;
    const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength;
    uint32_t roiIndex = 0;
    for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
        uint32_t batchId = batchSplitData[roiIndex];
        // Check for malformed data
        // 1. invalid batch id
        // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
        // 3. Invalid region: x2 < x1 || y2 < y1
        NN_RET_CHECK_GE(batchId, 0);
        NN_RET_CHECK_LT(batchId, numBatches);
        NN_RET_CHECK(roiInfo[0] >= 0);
        NN_RET_CHECK(roiInfo[1] >= 0);
        NN_RET_CHECK(roiInfo[2] >= 0);
        NN_RET_CHECK(roiInfo[3] >= 0);
        NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[0] <= roiInfo[2]);
        NN_RET_CHECK(roiInfo[1] <= roiInfo[3]);

        int32_t wRoiStart = std::round(static_cast<float>(roiInfo[0] * widthScale));
        int32_t hRoiStart = std::round(static_cast<float>(roiInfo[1] * heightScale));
        int32_t wRoiEnd = std::round(static_cast<float>(roiInfo[2] * widthScale));
        int32_t hRoiEnd = std::round(static_cast<float>(roiInfo[3] * heightScale));

        // Rois with width/height < 1 are considered malformed and are forced to be 1
        T_Roi roiWidth = static_cast<T_Roi>(std::max(wRoiEnd - wRoiStart + 1, 1));
        T_Roi roiHeight = static_cast<T_Roi>(std::max(hRoiEnd - hRoiStart + 1, 1));
        T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth);
        T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight);

        const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
        for (uint32_t i = 0; i < outHeight; i++) {
            for (uint32_t j = 0; j < outWidth; j++) {
                // Take floor on start, ceil on end, start included, end excluded, i.e. [start, end)
                // end is guaranteed to be larger than start by at least 1
                uint32_t wStart = std::floor(static_cast<float>(wStepSize * j + wRoiStart));
                uint32_t wEnd = std::ceil(static_cast<float>(wStepSize * (j + 1) + wRoiStart));
                uint32_t hStart = std::floor(static_cast<float>(hStepSize * i + hRoiStart));
                uint32_t hEnd = std::ceil(static_cast<float>(hStepSize * (i + 1) + hRoiStart));

                wStart = std::min(wStart, inWidth);
                wEnd = std::min(wEnd, inWidth);
                hStart = std::min(hStart, inHeight);
                hEnd = std::min(hEnd, inHeight);

                for (uint32_t k = 0; k < inDepth; k++) {
                    T_Input maxValue = static_cast<T_Input>(inputShape.offset);
                    bool first = true;
                    for (uint32_t h = hStart; h < hEnd; h++) {
                        for (uint32_t w = wStart; w < wEnd; w++) {
                            T_Input inputValue = batchBase[h * inWidth * inDepth + w * inDepth + k];
                            if (first || inputValue > maxValue) {
                                maxValue = inputValue;
                                first = false;
                            }
                        }
                    }
                    outPtr[k] = maxValue;
                }
                outPtr += inDepth;
            }
        }
    }
    return true;
}

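// Layout-agnostic wrapper: converts an NCHW input to NHWC when requested, runs the NHWC kernel,
// and commits the pooled result back to the caller's layout.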
template <typename T_Input, typename T_Roi>
inline bool roiPooling(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                       const Shape& roiShape, const int32_t* batchSplitData,
                       const Shape& batchSplitShape, float heightStride, float widthStride,
                       bool useNchw, T_Input* outputData, const Shape& outputShape) {
    InputWithLayout<T_Input> input(useNchw);
    OutputWithLayout<T_Input> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(roiPoolingNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape,
                                batchSplitData, batchSplitShape, heightStride, widthStride,
                                output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

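// Specialization for TENSOR_QUANT8_ASYMM inputs: the ROI tensor is TENSOR_QUANT16_ASYMM, so it is
// dequantized to float before reusing the generic implementation. The input values themselves are
// compared in quantized form, which preserves the max since quantization with a positive scale is
// monotonic.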
template <>
inline bool roiPooling<uint8_t, uint16_t>(const uint8_t* inputData, const Shape& inputShape,
                                          const uint16_t* roiData, const Shape& roiShape,
                                          const int32_t* batchSplitData,
                                          const Shape& batchSplitShape, float heightStride,
                                          float widthStride, bool useNchw, uint8_t* outputData,
                                          const Shape& outputShape) {
    std::vector<float> roi_float32(getNumberOfElements(roiShape));
    convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
    NN_RET_CHECK(roiPooling(inputData, inputShape, roi_float32.data(), roiShape, batchSplitData,
                            batchSplitShape, heightStride, widthStride, useNchw, outputData,
                            outputShape));
    return true;
}

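// Same as above, for TENSOR_QUANT8_ASYMM_SIGNED inputs.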
template <>
inline bool roiPooling<int8_t, uint16_t>(const int8_t* inputData, const Shape& inputShape,
                                         const uint16_t* roiData, const Shape& roiShape,
                                         const int32_t* batchSplitData,
                                         const Shape& batchSplitShape, float heightStride,
                                         float widthStride, bool useNchw, int8_t* outputData,
                                         const Shape& outputShape) {
    std::vector<float> roi_float32(getNumberOfElements(roiShape));
    convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
    NN_RET_CHECK(roiPooling(inputData, inputShape, roi_float32.data(), roiShape, batchSplitData,
                            batchSplitShape, heightStride, widthStride, useNchw, outputData,
                            outputShape));
    return true;
}

}  // namespace

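// Checks operand counts and types for the supported input tensor types (float16, float32, and
// asymmetric quant8), and requires HAL version 1.3 for signed quant8 inputs, 1.2 otherwise.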
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    std::vector<OperandType> inExpectedTypes;
    auto inputType = context->getInputType(kInputTensor);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_INT32,   OperandType::INT32,
                           OperandType::INT32,          OperandType::FLOAT32,
                           OperandType::FLOAT32,        OperandType::BOOL};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_INT32,   OperandType::INT32,
                           OperandType::INT32,          OperandType::FLOAT16,
                           OperandType::FLOAT16,        OperandType::BOOL};
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
               inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        inExpectedTypes = {inputType,
                           OperandType::TENSOR_QUANT16_ASYMM,
                           OperandType::TENSOR_INT32,
                           OperandType::INT32,
                           OperandType::INT32,
                           OperandType::FLOAT32,
                           OperandType::FLOAT32,
                           OperandType::BOOL};
    } else {
        LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
        return false;
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        return validateHalVersion(context, HalVersion::V1_3);
    } else {
        return validateHalVersion(context, HalVersion::V1_2);
    }
}

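// Validates the input, ROI, and batch-split shapes plus the scalar parameters, then derives the
// output shape: one outputHeight x outputWidth x depth feature map per ROI, in the requested
// layout.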
bool prepare(IOperationExecutionContext* context) {
    bool useNchw = context->getInputValue<bool>(kLayoutScalar);
    Shape input = context->getInputShape(kInputTensor);
    Shape roiShape = context->getInputShape(kRoiTensor);
    Shape batchSplitShape = context->getInputShape(kBatchSplitTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2);

    uint32_t numBatches = getSizeOfDimension(input, 0);
    uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
    uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
    uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4);
    NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois);

    auto outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar);
    auto outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar);
    float heightStride, widthStride;
    if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
        heightStride = context->getInputValue<_Float16>(kHeightStrideScalar);
        widthStride = context->getInputValue<_Float16>(kWidthStrideScalar);
    } else {
        heightStride = context->getInputValue<float>(kHeightStrideScalar);
        widthStride = context->getInputValue<float>(kWidthStrideScalar);
    }
    NN_RET_CHECK_GT(outputHeight, 0);
    NN_RET_CHECK_GT(outputWidth, 0);
    NN_RET_CHECK_GT(heightStride, 0);
    NN_RET_CHECK_GT(widthStride, 0);

    if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
        NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
        NN_RET_CHECK_EQ(roiShape.offset, 0);
    }

    Shape output = input;
    if (useNchw) {
        output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
                             static_cast<uint32_t>(outputWidth)};
    } else {
        output.dimensions = {numRois, static_cast<uint32_t>(outputHeight),
                             static_cast<uint32_t>(outputWidth), inDepth};
    }
    return context->setOutputShape(kOutputTensor, output);
}

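// Dispatches to the roiPooling implementation that matches the input tensor type.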
bool execute(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return roiPooling(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<_Float16>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<_Float16>(kHeightStrideScalar),
                              context->getInputValue<_Float16>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return roiPooling(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<float>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<float>(kHeightStrideScalar),
                              context->getInputValue<float>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return roiPooling(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<uint16_t>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<float>(kHeightStrideScalar),
                              context->getInputValue<float>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return roiPooling(context->getInputBuffer<int8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<uint16_t>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<float>(kHeightStrideScalar),
                              context->getInputValue<float>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<int8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}

}  // namespace roi_pooling

NN_REGISTER_OPERATION(ROI_POOLING, roi_pooling::kOperationName, roi_pooling::validate,
                      roi_pooling::prepare, roi_pooling::execute);

}  // namespace nn
}  // namespace android