1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "RoiAlign.h"
20 
21 #include <algorithm>
22 #include <cfloat>
23 #include <cmath>
24 #include <vector>
25 
26 #include "OperationResolver.h"
27 #include "OperationsExecutionUtils.h"
28 #include "Tracing.h"
29 
30 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
31 #pragma clang diagnostic push
32 #pragma clang diagnostic ignored "-Wunused-parameter"
33 #include <tensorflow/lite/kernels/internal/common.h>
34 #pragma clang diagnostic pop
35 
36 #include "CpuOperationUtils.h"
37 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
38 
39 namespace android {
40 namespace nn {
41 namespace roi_align {
42 
43 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
44 namespace {
45 
46 template <typename T_Input, typename T_Roi>
roiAlignNhwc(const T_Input * inputData,const Shape & inputShape,const T_Roi * roiData,const Shape & roiShape,const int32_t * batchSplitData,const Shape &,float heightStride,float widthStride,int32_t heightSamplingRatio,int32_t widthSamplingRatio,T_Input * outputData,const Shape & outputShape)47 inline bool roiAlignNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
48                          const Shape& roiShape, const int32_t* batchSplitData,
49                          const Shape& /*batchSplitShape*/, float heightStride, float widthStride,
50                          int32_t heightSamplingRatio, int32_t widthSamplingRatio,
51                          T_Input* outputData, const Shape& outputShape) {
52     NNTRACE_TRANS("RoiAlign");
53 
54     const uint32_t kRoiDim = 4;
55     const T_Roi heightScale = 1.0f / heightStride;
56     const T_Roi widthScale = 1.0f / widthStride;
57 
58     uint32_t numBatches = getSizeOfDimension(inputShape, 0);
59     uint32_t inHeight = getSizeOfDimension(inputShape, 1);
60     uint32_t inWidth = getSizeOfDimension(inputShape, 2);
61     uint32_t inDepth = getSizeOfDimension(inputShape, 3);
62     uint32_t outHeight = getSizeOfDimension(outputShape, 1);
63     uint32_t outWidth = getSizeOfDimension(outputShape, 2);
64     uint32_t numRois = getSizeOfDimension(roiShape, 0);
65     uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);
66 
67     T_Input* outPtr = outputData;
68     const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength;
69     uint32_t roiIndex = 0;
70     for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
71         uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
72         // Check for malformed data
73         // 1. invalid batch id
74         // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
75         // 3. Invalid region: x2 < x1 || y2 < y1
76         NN_RET_CHECK_GE(batchId, 0u);
77         NN_RET_CHECK_LT(batchId, numBatches);
78         NN_RET_CHECK(roiInfo[0] >= 0);
79         NN_RET_CHECK(roiInfo[1] >= 0);
80         NN_RET_CHECK(roiInfo[2] >= 0);
81         NN_RET_CHECK(roiInfo[3] >= 0);
82         NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth);
83         NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight);
84         NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth);
85         NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight);
86         NN_RET_CHECK(roiInfo[0] <= roiInfo[2]);
87         NN_RET_CHECK(roiInfo[1] <= roiInfo[3]);
88 
89         T_Roi wRoiStart = roiInfo[0] * widthScale;
90         T_Roi hRoiStart = roiInfo[1] * heightScale;
91         T_Roi wRoiEnd = roiInfo[2] * widthScale;
92         T_Roi hRoiEnd = roiInfo[3] * heightScale;
93 
94         T_Roi roiWidth = std::max(static_cast<float>(wRoiEnd - wRoiStart), 1.0f);
95         T_Roi roiHeight = std::max(static_cast<float>(hRoiEnd - hRoiStart), 1.0f);
96         T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth);
97         T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight);
98 
99         // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height
100         uint32_t wSamplingRatio = widthSamplingRatio > 0 ? widthSamplingRatio
101                                                          : std::ceil(static_cast<float>(wStepSize));
102         uint32_t hSamplingRatio = heightSamplingRatio > 0
103                                           ? heightSamplingRatio
104                                           : std::ceil(static_cast<float>(hStepSize));
105         int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
106         T_Roi wBinSize = wStepSize / static_cast<T_Roi>(wSamplingRatio);
107         T_Roi hBinSize = hStepSize / static_cast<T_Roi>(hSamplingRatio);
108 
109         const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
110         for (uint32_t i = 0; i < outHeight; i++) {
111             for (uint32_t j = 0; j < outWidth; j++) {
112                 T_Roi wStart = wStepSize * j + wRoiStart;
113                 [[maybe_unused]] T_Roi wEnd = wStepSize * (j + 1) + wRoiStart;
114                 T_Roi hStart = hStepSize * i + hRoiStart;
115                 [[maybe_unused]] T_Roi hEnd = hStepSize * (i + 1) + hRoiStart;
116 
117                 // initialize output to zero
118                 for (uint32_t k = 0; k < inDepth; k++) outPtr[k] = 0;
119 
120                 // calculate the sum of the sampling points
121                 for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
122                     for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
123                         T_Roi y = hStart + hBinSize / 2 + hBinSize * yInd;
124                         T_Roi x = wStart + wBinSize / 2 + wBinSize * xInd;
125 
126                         // bilinear interpolation of point (x,y)
127                         // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
128                         uint32_t x1 = std::floor(static_cast<float>(x));
129                         uint32_t y1 = std::floor(static_cast<float>(y));
130                         uint32_t x2 = x1 + 1, y2 = y1 + 1;
131                         T_Roi dx1 = x - static_cast<T_Roi>(x1);
132                         T_Roi dy1 = y - static_cast<T_Roi>(y1);
133 
134                         // dealing with out of bound samples
135                         if (x1 >= inWidth - 1) {
136                             x1 = x2 = inWidth - 1;
137                             dx1 = 0;
138                         }
139                         if (y1 >= inHeight - 1) {
140                             y1 = y2 = inHeight - 1;
141                             dy1 = 0;
142                         }
143 
144                         T_Roi dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
145                         T_Roi ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
146                         uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
147                                               y1 * inWidth * inDepth + x2 * inDepth,
148                                               y2 * inWidth * inDepth + x1 * inDepth,
149                                               y2 * inWidth * inDepth + x2 * inDepth};
150 
151                         for (uint32_t k = 0; k < inDepth; k++) {
152                             T_Input interpolation = 0;
153                             for (uint32_t c = 0; c < 4; c++) {
154                                 interpolation += ws[c] * batchBase[offsets[c] + k];
155                             }
156                             outPtr[k] += interpolation;
157                         }
158                     }
159                 }
160 
161                 // take average
162                 for (uint32_t k = 0; k < inDepth; k++)
163                     outPtr[k] /= static_cast<T_Input>(numSamplingPoints);
164                 outPtr += inDepth;
165             }
166         }
167     }
168     return true;
169 }
170 
171 template <typename T_Input>
roiAlignQuantNhwc(const T_Input * inputData,const Shape & inputShape,const uint16_t * roiData,const Shape & roiShape,const int32_t * batchSplitData,const Shape &,float heightStride,float widthStride,int32_t heightSamplingRatio,int32_t widthSamplingRatio,T_Input * outputData,const Shape & outputShape)172 inline bool roiAlignQuantNhwc(const T_Input* inputData, const Shape& inputShape,
173                               const uint16_t* roiData, const Shape& roiShape,
174                               const int32_t* batchSplitData, const Shape& /*batchSplitShape*/,
175                               float heightStride, float widthStride, int32_t heightSamplingRatio,
176                               int32_t widthSamplingRatio, T_Input* outputData,
177                               const Shape& outputShape) {
178     NNTRACE_TRANS("RoiAlignQuant8");
179 
180     constexpr float wScale = 1.0f / 255.0f;
181     constexpr uint32_t kRoiDim = 4;
182     const float heightScale = 1.0f / heightStride;
183     const float widthScale = 1.0f / widthStride;
184 
185     uint32_t numBatches = getSizeOfDimension(inputShape, 0);
186     uint32_t inHeight = getSizeOfDimension(inputShape, 1);
187     uint32_t inWidth = getSizeOfDimension(inputShape, 2);
188     uint32_t inDepth = getSizeOfDimension(inputShape, 3);
189     uint32_t outHeight = getSizeOfDimension(outputShape, 1);
190     uint32_t outWidth = getSizeOfDimension(outputShape, 2);
191     uint32_t numRois = getSizeOfDimension(roiShape, 0);
192     uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);
193 
194     T_Input* outPtr = outputData;
195     const uint16_t* roiDataEnd = roiData + numRois * roiInfoLength;
196     uint32_t roiIndex = 0;
197     for (const uint16_t* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
198         uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
199         float wRoiStart = static_cast<float>(roiInfo[0]) * widthScale * 0.125f;
200         float hRoiStart = static_cast<float>(roiInfo[1]) * heightScale * 0.125f;
201         float wRoiEnd = static_cast<float>(roiInfo[2]) * widthScale * 0.125f;
202         float hRoiEnd = static_cast<float>(roiInfo[3]) * heightScale * 0.125f;
203 
204         // Check for malformed data
205         // 1. invalid batch id
206         // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
207         // 3. Invalid region: x2 < x1 || y2 < y1
208         NN_RET_CHECK_GE(batchId, 0u);
209         NN_RET_CHECK_LT(batchId, numBatches);
210         NN_RET_CHECK(wRoiStart <= inWidth);
211         NN_RET_CHECK(hRoiStart <= inHeight);
212         NN_RET_CHECK(wRoiEnd <= inWidth);
213         NN_RET_CHECK(hRoiEnd <= inHeight);
214         NN_RET_CHECK_LE(wRoiStart, wRoiEnd);
215         NN_RET_CHECK_LE(hRoiStart, hRoiEnd);
216 
217         float roiWidth = std::max(wRoiEnd - wRoiStart, 1.0f);
218         float roiHeight = std::max(hRoiEnd - hRoiStart, 1.0f);
219         float wStepSize = roiWidth / static_cast<float>(outWidth);
220         float hStepSize = roiHeight / static_cast<float>(outHeight);
221 
222         // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height
223         uint32_t wSamplingRatio =
224                 widthSamplingRatio > 0 ? widthSamplingRatio : std::ceil(wStepSize);
225         uint32_t hSamplingRatio =
226                 heightSamplingRatio > 0 ? heightSamplingRatio : std::ceil(hStepSize);
227         int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
228         float wBinSize = wStepSize / static_cast<float>(wSamplingRatio);
229         float hBinSize = hStepSize / static_cast<float>(hSamplingRatio);
230 
231         float realMultiplier = inputShape.scale * wScale / outputShape.scale / numSamplingPoints;
232         int32_t outputMultiplier = 0;
233         int32_t outputShift = 0;
234         if (!QuantizeMultiplierSmallerThanOne(realMultiplier, &outputMultiplier, &outputShift)) {
235             return false;
236         }
237 
238         const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
239         for (uint32_t i = 0; i < outHeight; i++) {
240             for (uint32_t j = 0; j < outWidth; j++) {
241                 float wStart = wStepSize * j + wRoiStart;
242                 [[maybe_unused]] float wEnd = wStepSize * (j + 1) + wRoiStart;
243                 float hStart = hStepSize * i + hRoiStart;
244                 [[maybe_unused]] float hEnd = hStepSize * (i + 1) + hRoiStart;
245 
246                 std::vector<int32_t> outTemp(inDepth, 0);
247                 // calculate the sum of the sampling points
248                 for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
249                     for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
250                         float y = hStart + hBinSize / 2 + hBinSize * yInd;
251                         float x = wStart + wBinSize / 2 + wBinSize * xInd;
252 
253                         // bilinear interpolation of point (x,y)
254                         // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
255                         uint32_t x1 = std::floor(x), y1 = std::floor(y);
256                         uint32_t x2 = x1 + 1, y2 = y1 + 1;
257                         float dx1 = x - static_cast<float>(x1);
258                         float dy1 = y - static_cast<float>(y1);
259 
260                         // dealing with out of bound samples
261                         if (x1 >= inWidth - 1) {
262                             x1 = x2 = inWidth - 1;
263                             dx1 = 0;
264                         }
265                         if (y1 >= inHeight - 1) {
266                             y1 = y2 = inHeight - 1;
267                             dy1 = 0;
268                         }
269 
270                         float dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
271                         float ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
272                         uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
273                                               y1 * inWidth * inDepth + x2 * inDepth,
274                                               y2 * inWidth * inDepth + x1 * inDepth,
275                                               y2 * inWidth * inDepth + x2 * inDepth};
276 
277                         for (uint32_t k = 0; k < inDepth; k++) {
278                             int32_t interpolation = 0;
279                             for (uint32_t c = 0; c < 4; c++) {
280                                 int32_t wQuant = static_cast<int32_t>(std::round(ws[c] / wScale));
281                                 interpolation +=
282                                         wQuant * (static_cast<int32_t>(batchBase[offsets[c] + k]) -
283                                                   inputShape.offset);
284                             }
285                             outTemp[k] += interpolation;
286                         }
287                     }
288                 }
289 
290                 // take average and cast to output quantization
291                 for (uint32_t k = 0; k < inDepth; k++) {
292                     int32_t raw_out = tflite::MultiplyByQuantizedMultiplier(
293                                               outTemp[k], outputMultiplier, -outputShift) +
294                                       outputShape.offset;
295                     outPtr[k] = saturateCast<T_Input>(raw_out);
296                 }
297                 outPtr += inDepth;
298             }
299         }
300     }
301     return true;
302 }
303 
304 template <typename T_Input, typename T_Roi>
roiAlign(const T_Input * inputData,const Shape & inputShape,const T_Roi * roiData,const Shape & roiShape,const int32_t * batchSplitData,const Shape & batchSplitShape,float heightStride,float widthStride,int32_t heightSamplingRatio,int32_t widthSamplingRatio,bool useNchw,T_Input * outputData,const Shape & outputShape)305 inline bool roiAlign(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
306                      const Shape& roiShape, const int32_t* batchSplitData,
307                      const Shape& batchSplitShape, float heightStride, float widthStride,
308                      int32_t heightSamplingRatio, int32_t widthSamplingRatio, bool useNchw,
309                      T_Input* outputData, const Shape& outputShape) {
310     InputWithLayout<T_Input> input(useNchw);
311     OutputWithLayout<T_Input> output(useNchw);
312     NN_RET_CHECK(input.initialize(inputData, inputShape));
313     NN_RET_CHECK(output.initialize(outputData, outputShape));
314     if constexpr (std::is_same_v<T_Roi, uint16_t> &&
315                   (std::is_same_v<T_Input, uint8_t> || std::is_same_v<T_Input, int8_t>)) {
316         NN_RET_CHECK(roiAlignQuantNhwc<T_Input>(
317                 input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape, batchSplitData,
318                 batchSplitShape, heightStride, widthStride, heightSamplingRatio, widthSamplingRatio,
319                 output.getNhwcBuffer(), output.getNhwcShape()));
320     } else {
321         NN_RET_CHECK(roiAlignNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape,
322                                   batchSplitData, batchSplitShape, heightStride, widthStride,
323                                   heightSamplingRatio, widthSamplingRatio, output.getNhwcBuffer(),
324                                   output.getNhwcShape()));
325     }
326     NN_RET_CHECK(output.commit());
327     return true;
328 }
329 
330 }  // namespace
331 
prepare(IOperationExecutionContext * context)332 bool prepare(IOperationExecutionContext* context) {
333     bool useNchw = context->getInputValue<bool>(kLayoutScalar);
334     Shape input = context->getInputShape(kInputTensor);
335     Shape roiShape = context->getInputShape(kRoiTensor);
336     Shape batchSplitShape = context->getInputShape(kBatchSplitTensor);
337     NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4u);
338     NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2u);
339 
340     uint32_t numBatches = getSizeOfDimension(input, 0);
341     uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
342     uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
343     uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3);
344     uint32_t numRois = getSizeOfDimension(roiShape, 0);
345     // Every dimension must be positive except for numRois.
346     NN_RET_CHECK_GT(numBatches, 0u);
347     NN_RET_CHECK_GT(inHeight, 0u);
348     NN_RET_CHECK_GT(inWidth, 0u);
349     NN_RET_CHECK_GT(inDepth, 0u);
350     NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4u);
351     NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois);
352 
353     int32_t outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar);
354     int32_t outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar);
355     int32_t heightSamplingRatio = context->getInputValue<int32_t>(kHeightSamplingRatioScalar);
356     int32_t widthSamplingRatio = context->getInputValue<int32_t>(kWidthSamplingRatioScalar);
357     float heightScale, widthScale;
358     if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
359         heightScale = context->getInputValue<_Float16>(kHeightStrideSalar);
360         widthScale = context->getInputValue<_Float16>(kWidthStrideScalar);
361     } else {
362         heightScale = context->getInputValue<float>(kHeightStrideSalar);
363         widthScale = context->getInputValue<float>(kWidthStrideScalar);
364     }
365     NN_RET_CHECK_GT(outputHeight, 0);
366     NN_RET_CHECK_GT(outputWidth, 0);
367     NN_RET_CHECK_GT(heightScale, 0);
368     NN_RET_CHECK_GT(widthScale, 0);
369     // Sampling ratio can set to 0 for adaptive value.
370     NN_RET_CHECK_GE(heightSamplingRatio, 0);
371     NN_RET_CHECK_GE(widthSamplingRatio, 0);
372 
373     if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
374         NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
375         NN_RET_CHECK_EQ(roiShape.offset, 0);
376     }
377 
378     Shape output = context->getOutputShape(kOutputTensor);
379     output.type = input.type;
380     if (useNchw) {
381         output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
382                              static_cast<uint32_t>(outputWidth)};
383     } else {
384         output.dimensions = {numRois, static_cast<uint32_t>(outputHeight),
385                              static_cast<uint32_t>(outputWidth), inDepth};
386     }
387     return context->setOutputShape(kOutputTensor, output);
388 }
389 
execute(IOperationExecutionContext * context)390 bool execute(IOperationExecutionContext* context) {
391     // Bypass execution in the case of zero-sized input.
392     if (getNumberOfElements(context->getInputShape(kRoiTensor)) == 0) return true;
393     switch (context->getInputType(kInputTensor)) {
394         case OperandType::TENSOR_FLOAT16:
395             return roiAlign(context->getInputBuffer<_Float16>(kInputTensor),
396                             context->getInputShape(kInputTensor),
397                             context->getInputBuffer<_Float16>(kRoiTensor),
398                             context->getInputShape(kRoiTensor),
399                             context->getInputBuffer<int32_t>(kBatchSplitTensor),
400                             context->getInputShape(kBatchSplitTensor),
401                             context->getInputValue<_Float16>(kHeightStrideSalar),
402                             context->getInputValue<_Float16>(kWidthStrideScalar),
403                             context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
404                             context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
405                             context->getInputValue<bool>(kLayoutScalar),
406                             context->getOutputBuffer<_Float16>(kOutputTensor),
407                             context->getOutputShape(kOutputTensor));
408         case OperandType::TENSOR_FLOAT32:
409             return roiAlign(context->getInputBuffer<float>(kInputTensor),
410                             context->getInputShape(kInputTensor),
411                             context->getInputBuffer<float>(kRoiTensor),
412                             context->getInputShape(kRoiTensor),
413                             context->getInputBuffer<int32_t>(kBatchSplitTensor),
414                             context->getInputShape(kBatchSplitTensor),
415                             context->getInputValue<float>(kHeightStrideSalar),
416                             context->getInputValue<float>(kWidthStrideScalar),
417                             context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
418                             context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
419                             context->getInputValue<bool>(kLayoutScalar),
420                             context->getOutputBuffer<float>(kOutputTensor),
421                             context->getOutputShape(kOutputTensor));
422         case OperandType::TENSOR_QUANT8_ASYMM:
423             return roiAlign(context->getInputBuffer<uint8_t>(kInputTensor),
424                             context->getInputShape(kInputTensor),
425                             context->getInputBuffer<uint16_t>(kRoiTensor),
426                             context->getInputShape(kRoiTensor),
427                             context->getInputBuffer<int32_t>(kBatchSplitTensor),
428                             context->getInputShape(kBatchSplitTensor),
429                             context->getInputValue<float>(kHeightStrideSalar),
430                             context->getInputValue<float>(kWidthStrideScalar),
431                             context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
432                             context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
433                             context->getInputValue<bool>(kLayoutScalar),
434                             context->getOutputBuffer<uint8_t>(kOutputTensor),
435                             context->getOutputShape(kOutputTensor));
436         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
437             return roiAlign(context->getInputBuffer<int8_t>(kInputTensor),
438                             context->getInputShape(kInputTensor),
439                             context->getInputBuffer<uint16_t>(kRoiTensor),
440                             context->getInputShape(kRoiTensor),
441                             context->getInputBuffer<int32_t>(kBatchSplitTensor),
442                             context->getInputShape(kBatchSplitTensor),
443                             context->getInputValue<float>(kHeightStrideSalar),
444                             context->getInputValue<float>(kWidthStrideScalar),
445                             context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
446                             context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
447                             context->getInputValue<bool>(kLayoutScalar),
448                             context->getOutputBuffer<int8_t>(kOutputTensor),
449                             context->getOutputShape(kOutputTensor));
450         default:
451             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
452     }
453 }
454 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
455 
456 }  // namespace roi_align
457 
458 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(ROI_ALIGN, roi_align::prepare, roi_align::execute,
459                                          .allowZeroSizedInput = true);
460 
461 }  // namespace nn
462 }  // namespace android
463