1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "OperationsUtils"
18 
19 #include "OperationsUtils.h"
20 #include "Operations.h"
21 #include "Utils.h"
22 
23 #include <cmath>
24 
25 namespace android {
26 namespace nn {
27 
28 namespace {
29 
validateOperandTypes(const std::vector<OperandType> & expectedTypes,const char * tag,uint32_t operandCount,std::function<OperandType (uint32_t)> getOperandType)30 bool validateOperandTypes(const std::vector<OperandType>& expectedTypes, const char* tag,
31                           uint32_t operandCount,
32                           std::function<OperandType(uint32_t)> getOperandType) {
33     NN_RET_CHECK_EQ(operandCount, expectedTypes.size());
34     for (uint32_t i = 0; i < operandCount; ++i) {
35         OperandType type = getOperandType(i);
36         NN_RET_CHECK(type == expectedTypes[i])
37                 << "Invalid " << tag << " tensor type " << toString(type) << " for " << tag << " "
38                 << i << ", expected " << toString(expectedTypes[i]);
39     }
40     return true;
41 }
42 
43 }  // namespace
44 
validateInputTypes(const IOperationValidationContext * context,const std::vector<OperandType> & expectedTypes)45 bool validateInputTypes(const IOperationValidationContext* context,
46                         const std::vector<OperandType>& expectedTypes) {
47     return validateOperandTypes(expectedTypes, "input", context->getNumInputs(),
48                                 [context](uint32_t index) { return context->getInputType(index); });
49 }
50 
validateOutputTypes(const IOperationValidationContext * context,const std::vector<OperandType> & expectedTypes)51 bool validateOutputTypes(const IOperationValidationContext* context,
52                          const std::vector<OperandType>& expectedTypes) {
53     return validateOperandTypes(
54             expectedTypes, "output", context->getNumOutputs(),
55             [context](uint32_t index) { return context->getOutputType(index); });
56 }
57 
validateHalVersion(const IOperationValidationContext * context,HalVersion minSupportedHalVersion)58 bool validateHalVersion(const IOperationValidationContext* context,
59                         HalVersion minSupportedHalVersion) {
60     if (context->getHalVersion() < minSupportedHalVersion) {
61         NN_RET_CHECK_FAIL() << "The given inputs and outputs are only supported in "
62                             << toString(minSupportedHalVersion) << " and later (validating using "
63                             << toString(context->getHalVersion()) << ")";
64     }
65     return true;
66 }
67 
SameShape(const Shape & in1,const Shape & in2)68 bool SameShape(const Shape& in1, const Shape& in2) {
69     if (in1.type != in2.type || in1.dimensions.size() != in2.dimensions.size()) {
70         return false;
71     }
72     for (size_t i = 0; i < in1.dimensions.size(); i++) {
73         if (in1.dimensions[i] != in2.dimensions[i]) {
74             return false;
75         }
76     }
77     return true;
78 }
79 
SetShape(const Shape & in,Shape * out)80 bool SetShape(const Shape& in, Shape* out) {
81     if (in.type != out->type) {
82         return false;
83     }
84     out->dimensions = in.dimensions;
85     return true;
86 }
87 
combineDimensions(const std::vector<uint32_t> & lhs,const std::vector<uint32_t> & rhs,std::vector<uint32_t> * combined)88 bool combineDimensions(const std::vector<uint32_t>& lhs, const std::vector<uint32_t>& rhs,
89                        std::vector<uint32_t>* combined) {
90     if (rhs.empty()) {
91         *combined = lhs;
92         return true;
93     }
94     if (lhs.empty()) {
95         *combined = rhs;
96         return true;
97     }
98     NN_RET_CHECK_EQ(lhs.size(), rhs.size()) << "incompatible ranks";
99     combined->resize(lhs.size());
100     for (uint32_t i = 0; i < lhs.size(); i++) {
101         if (lhs[i] == 0) {
102             (*combined)[i] = rhs[i];
103             continue;
104         }
105         if (rhs[i] == 0) {
106             (*combined)[i] = lhs[i];
107             continue;
108         }
109         NN_RET_CHECK_EQ(lhs[i], rhs[i]) << "incompatible dimension: " << i;
110         (*combined)[i] = lhs[i];
111     }
112     return true;
113 }
114 
getNumberOfElements(const Shape & shape)115 uint32_t getNumberOfElements(const Shape& shape) {
116     uint32_t count = 1;
117     for (size_t i = 0; i < shape.dimensions.size(); i++) {
118         count *= shape.dimensions[i];
119     }
120     return count;
121 }
122 
getNumberOfElements(const Shape & shape,size_t firstAxisInclusive,size_t lastAxisExclusive)123 uint32_t getNumberOfElements(const Shape& shape,
124                              size_t firstAxisInclusive,
125                              size_t lastAxisExclusive) {
126     nnAssert(0 <= firstAxisInclusive);
127     nnAssert(firstAxisInclusive <= lastAxisExclusive);
128     nnAssert(lastAxisExclusive <= shape.dimensions.size());
129     uint32_t count = 1;
130     for (size_t i = firstAxisInclusive; i < lastAxisExclusive; i++) {
131         count *= shape.dimensions[i];
132     }
133     return count;
134 }
135 
getNumberOfDimensions(const Shape & shape)136 uint32_t getNumberOfDimensions(const Shape& shape) {
137     return shape.dimensions.size();
138 }
139 
getSizeOfDimension(const Shape & shape,uint32_t dimensionIdx)140 uint32_t getSizeOfDimension(const Shape& shape, uint32_t dimensionIdx) {
141     nnAssert(0 <= dimensionIdx && dimensionIdx < shape.dimensions.size());
142     return shape.dimensions[dimensionIdx];
143 }
144 
handleNegativeAxis(int32_t numberOfDimensions,int32_t * axis)145 bool handleNegativeAxis(int32_t numberOfDimensions, int32_t* axis) {
146     NN_CHECK(-numberOfDimensions <= *axis && *axis < numberOfDimensions);
147     if (*axis < 0) {
148         *axis += numberOfDimensions;
149     }
150     return true;
151 }
152 
QuantizeMultiplier(double double_multiplier,int32_t * quantized_multiplier,int * shift)153 bool QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, int* shift) {
154     if (double_multiplier == 0.) {
155         *quantized_multiplier = 0;
156         *shift = 0;
157         return true;
158     }
159     const double q = std::frexp(double_multiplier, shift);
160     auto q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
161     NN_RET_CHECK(q_fixed <= (1ll << 31));
162     if (q_fixed == (1ll << 31)) {
163         q_fixed /= 2;
164         ++*shift;
165     }
166     NN_RET_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
167     *quantized_multiplier = static_cast<int32_t>(q_fixed);
168     return true;
169 }
170 
QuantizeMultiplierSmallerThanOne(double double_multiplier,int32_t * quantized_multiplier,int32_t * right_shift)171 bool QuantizeMultiplierSmallerThanOne(double double_multiplier,
172                                       int32_t* quantized_multiplier,
173                                       int32_t* right_shift) {
174     NN_OPS_CHECK(double_multiplier >= 0.);
175     NN_OPS_CHECK(double_multiplier < 1.);
176     if (double_multiplier == 0.) {
177         *quantized_multiplier = 0;
178         *right_shift = 0;
179         return true;
180     }
181     NN_OPS_CHECK(double_multiplier > 0.);
182     const double q = std::frexp(double_multiplier, right_shift);
183     *right_shift *= -1;
184     int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
185     NN_OPS_CHECK(q_fixed <= (1LL << 31));
186     if (q_fixed == (1LL << 31)) {
187         q_fixed /= 2;
188         --*right_shift;
189     }
190     NN_OPS_CHECK(*right_shift >= 0);
191     NN_OPS_CHECK(q_fixed <= std::numeric_limits<int32_t>::max());
192     *quantized_multiplier = static_cast<int32_t>(q_fixed);
193     return true;
194 }
195 
QuantizeMultiplierGreaterThanOne(double double_multiplier,int32_t * quantized_multiplier,int * left_shift)196 bool QuantizeMultiplierGreaterThanOne(double double_multiplier,
197                                       int32_t* quantized_multiplier,
198                                       int* left_shift) {
199     NN_OPS_CHECK(double_multiplier > 1.);
200     const double q = std::frexp(double_multiplier, left_shift);
201     int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
202     NN_OPS_CHECK(q_fixed <= (1LL << 31));
203     if (q_fixed == (1LL << 31)) {
204         q_fixed /= 2;
205         ++*left_shift;
206     }
207     NN_OPS_CHECK(*left_shift >= 0);
208     NN_OPS_CHECK(q_fixed <= std::numeric_limits<int32_t>::max());
209     *quantized_multiplier = static_cast<int32_t>(q_fixed);
210     return true;
211 }
212 
GetQuantizedConvolutionMultipler(const Shape & inputShape,const Shape & filterShape,const Shape & biasShape,const Shape & outputShape,double * multiplier)213 bool GetQuantizedConvolutionMultipler(const Shape& inputShape, const Shape& filterShape,
214                                       const Shape& biasShape, const Shape& outputShape,
215                                       double* multiplier) {
216     // Upcast bias and input_product to double
217     const double input_product_scale = inputShape.scale * filterShape.scale;
218     const double bias_scale = biasShape.scale;
219 
220     // The following conditions must be guaranteed by the training pipeline.
221     NN_OPS_CHECK(std::abs(input_product_scale - bias_scale) <=
222               1e-6 * std::min(input_product_scale, bias_scale));
223     NN_OPS_CHECK(input_product_scale >= 0);
224     *multiplier = input_product_scale / outputShape.scale;
225     return true;
226 }
227 
CalculateActivationRangeUint8(int32_t activation,const Shape & outputShape,int32_t * act_min,int32_t * act_max)228 void CalculateActivationRangeUint8(int32_t activation,
229                                    const Shape& outputShape,
230                                    int32_t* act_min,
231                                    int32_t* act_max) {
232     const int32_t qmin = std::numeric_limits<uint8_t>::min();
233     const int32_t qmax = std::numeric_limits<uint8_t>::max();
234 
235     const auto scale = outputShape.scale;
236     const auto zero_point = outputShape.offset;
237 
238     auto quantize = [scale, zero_point](float f) {
239         return zero_point + static_cast<int32_t>(std::round(f / scale));
240     };
241 
242     if (activation == kActivationRelu) {
243         *act_min = std::max(qmin, quantize(0.0));
244         *act_max = qmax;
245     } else if (activation == kActivationRelu6) {
246         *act_min = std::max(qmin, quantize(0.0));
247         *act_max = std::min(qmax, quantize(6.0));
248     } else if (activation == kActivationRelu1) {
249         *act_min = std::max(qmin, quantize(-1.0));
250         *act_max = std::min(qmax, quantize(1.0));
251     } else if (activation == kActivationNone){
252         *act_min = qmin;
253         *act_max = qmax;
254     } else {
255         LOG(ERROR) << "Unsupported fused activation function.";
256     }
257 }
258 
CalculateActivationRangeFloat(int32_t activation,float * activation_min,float * activation_max)259 void CalculateActivationRangeFloat(int32_t activation,
260                                    float* activation_min,
261                                    float* activation_max) {
262     if (activation == kActivationRelu) {
263         *activation_min = 0.f;
264         *activation_max = std::numeric_limits<float>::max();
265     } else if (activation == kActivationRelu6) {
266         *activation_min = 0.f;
267         *activation_max = 6.f;
268     } else if (activation == kActivationRelu1) {
269         *activation_min = -1.f;
270         *activation_max = 1.f;
271     } else if (activation == kActivationNone){
272         *activation_min = std::numeric_limits<float>::lowest();
273         *activation_max = std::numeric_limits<float>::max();
274     } else {
275         LOG(ERROR) << "Unsupported fused activation function.";
276     }
277 }
278 
// Returns
//     floor((2^input_integer_bits - 1) * 2^(31 - input_integer_bits) / 2^input_left_shift)
// as an int32_t, i.e. the fixed-point maximum for a value with `input_integer_bits` integer
// bits (out of 31), divided back down by the left shift that will later be applied to inputs.
int32_t CalculateInputRadius(int input_integer_bits, int input_left_shift) {
    const double integerPart = static_cast<double>((1 << input_integer_bits) - 1);
    const double fractionScale = static_cast<double>(1LL << (31 - input_integer_bits));
    const double shiftScale = static_cast<double>(1LL << input_left_shift);
    const double maxInputRescaled = integerPart * fractionScale / shiftScale;
    // Round down rather than to nearest: the exact bound would rescale to precisely the
    // maximum, so the returned radius must have a strictly lower magnitude.
    return static_cast<int32_t>(std::floor(maxInputRescaled));
}
288 
calculateExplicitPaddingImpl(int32_t in_size,int32_t stride,int32_t dilation_factor,int32_t filter_size,int32_t padding_implicit,bool isTransposeConv,int32_t * padding_head,int32_t * padding_tail)289 void calculateExplicitPaddingImpl(int32_t in_size, int32_t stride, int32_t dilation_factor,
290                                   int32_t filter_size, int32_t padding_implicit,
291                                   bool isTransposeConv, int32_t* padding_head,
292                                   int32_t* padding_tail) {
293     *padding_head = 0;
294     *padding_tail = 0;
295 
296     int32_t effective_filter_size = (filter_size - 1) * dilation_factor + 1;
297 
298     if (padding_implicit == kPaddingSame) {
299         int32_t out_size = (in_size + stride - 1) / stride;
300         int32_t tmp = (out_size - 1) * stride + effective_filter_size;
301         if (tmp > in_size) {
302             *padding_head = (tmp - in_size) / 2;
303             *padding_tail = (tmp - in_size) - *padding_head;
304         }
305         // For transpose conv, make padding tail fit tightly to the end of the last stride.
306         if (isTransposeConv) {
307             *padding_tail = (tmp - in_size) - *padding_head;
308         }
309     }
310 }
311 
calculateBroadcastedShape(const Shape & in1,const Shape & in2,Shape * out)312 bool calculateBroadcastedShape(const Shape& in1, const Shape& in2, Shape* out) {
313     NN_RET_CHECK(in1.type == in2.type);
314     uint32_t numberOfDims1 = getNumberOfDimensions(in1);
315     uint32_t numberOfDims2 = getNumberOfDimensions(in2);
316     uint32_t maxDims = std::max(numberOfDims1, numberOfDims2);
317     out->dimensions = std::vector<uint32_t>(maxDims);
318     for (uint32_t i = 1; i <= maxDims; i++) {
319         uint32_t dim1 = 1;
320         if (i <= numberOfDims1) {
321             dim1 = getSizeOfDimension(in1, numberOfDims1 - i);
322         }
323         uint32_t dim2 = 1;
324         if (i <= numberOfDims2) {
325             dim2 = getSizeOfDimension(in2, numberOfDims2 - i);
326         }
327         if (dim1 != dim2 && dim1 != 1 && dim2 != 1) {
328             LOG(ERROR) << "Dimensions mismatch for broadcast:\n"
329                        << "First tensor: dimension " << numberOfDims1 - i << " of size " << dim1
330                        << "\nSecond tensor: dimension " << numberOfDims2 - i << "of size " << dim2;
331             return false;
332         }
333         out->dimensions[maxDims - i] = (dim1 == 1) ? dim2 : dim1;
334     }
335     return true;
336 }
337 
requantize(uint8_t value,const Shape & oldShape,const Shape & newShape)338 uint8_t requantize(uint8_t value, const Shape& oldShape, const Shape& newShape) {
339     double doubleValue = (value - oldShape.offset) * oldShape.scale;
340     double doubleRet = doubleValue / newShape.scale + newShape.offset;
341     if (doubleRet < 0) return 0;
342     if (doubleRet > 255) return 255;
343     return static_cast<uint8_t>(std::round(doubleRet));
344 }
345 
floorPrepare(const Shape & input,Shape * output)346 bool floorPrepare(const Shape& input, Shape* output) {
347     return SetShape(input, output);
348 }
349 
depthwiseConvPrepare(const Shape & input,const Shape & filter,const Shape & bias,int32_t padding_left,int32_t padding_right,int32_t padding_top,int32_t padding_bottom,int32_t stride_width,int32_t stride_height,int32_t depth_multiplier,int32_t dilation_width_factor,int32_t dilation_height_factor,Shape * output)350 bool depthwiseConvPrepare(const Shape& input, const Shape& filter, const Shape& bias,
351                           int32_t padding_left, int32_t padding_right, int32_t padding_top,
352                           int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
353                           int32_t depth_multiplier, int32_t dilation_width_factor,
354                           int32_t dilation_height_factor, Shape* output) {
355     if (filter.type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
356         NN_OPS_CHECK(input.type == OperandType::TENSOR_QUANT8_ASYMM);
357     } else {
358         NN_OPS_CHECK(input.type == filter.type);
359     }
360     if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
361         NN_OPS_CHECK(bias.type == OperandType::TENSOR_INT32);
362     } else {
363         NN_OPS_CHECK(input.type == bias.type);
364     }
365     NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
366     NN_OPS_CHECK(getNumberOfDimensions(filter) == 4);
367     NN_OPS_CHECK(getNumberOfDimensions(bias) == 1);
368 
369     NN_OPS_CHECK(getSizeOfDimension(filter, 3) == getSizeOfDimension(bias, 0));
370 
371     uint32_t channels_out = getSizeOfDimension(filter, 3);
372     uint32_t channels_in = getSizeOfDimension(input, 3);
373     uint32_t width        = getSizeOfDimension(input, 2);
374     uint32_t height       = getSizeOfDimension(input, 1);
375     uint32_t filterWidth  = getSizeOfDimension(filter, 2);
376     uint32_t filterHeight = getSizeOfDimension(filter, 1);
377     uint32_t batches      = getSizeOfDimension(input, 0);
378 
379     NN_OPS_CHECK(depth_multiplier * channels_in == channels_out);
380     int32_t effectiveFilterWidth = (filterWidth - 1) * dilation_width_factor + 1;
381     int32_t effectiveFilterHeight = (filterHeight - 1) * dilation_height_factor + 1;
382     NN_RET_CHECK_GT(effectiveFilterWidth, padding_left);
383     NN_RET_CHECK_GT(effectiveFilterWidth, padding_right);
384     NN_RET_CHECK_GT(effectiveFilterHeight, padding_top);
385     NN_RET_CHECK_GT(effectiveFilterHeight, padding_bottom);
386 
387     uint32_t outWidth = computeOutSize(width, filterWidth, stride_width, dilation_width_factor,
388                                        padding_left, padding_right);
389     uint32_t outHeight = computeOutSize(height, filterHeight, stride_height, dilation_height_factor,
390                                         padding_top, padding_bottom);
391 
392     output->type = input.type;
393     output->dimensions = {batches, outHeight, outWidth, channels_out};
394     return true;
395 }
396 
genericActivationPrepare(const Shape & input,Shape * output)397 bool genericActivationPrepare(const Shape& input,
398                               Shape* output) {
399     NN_OPS_CHECK(getNumberOfDimensions(input) <= 4);
400     return SetShape(input, output);
401 }
402 
genericNormalizationPrepare(const Shape & input,Shape * output)403 bool genericNormalizationPrepare(const Shape& input, Shape* output) {
404     return SetShape(input, output);
405 }
406 
reshapePrepare(const Shape & input,const int32_t * targetDims,const int32_t targetDimsSize,Shape * output)407 bool reshapePrepare(const Shape& input,
408                     const int32_t* targetDims,
409                     const int32_t targetDimsSize,
410                     Shape* output) {
411     // Reshape allows one of the targetDims components to have the
412     // special -1 value, meaning it will be calculated automatically based on the
413     // input. Here we calculate what that dimension should be so that the number
414     // of output elements in the same as the number of input elements.
415     int32_t numInputElements = (int32_t) getNumberOfElements(input);
416 
417     std::vector<uint32_t> outDims(targetDimsSize);
418     int32_t numOutputElements = 1;
419     int32_t strechDim = -1;
420     for (int32_t i = 0; i < targetDimsSize; ++i) {
421         int32_t value = targetDims[i];
422         if (value == -1) {
423             NN_OPS_CHECK(strechDim == -1);
424             strechDim = i;
425         } else {
426             numOutputElements *= value;
427             outDims[i] = (uint32_t)value;
428         }
429     }
430     if (strechDim != -1) {
431         int32_t strechValue = numInputElements / numOutputElements;
432         outDims[strechDim] = (uint32_t) strechValue;
433         numOutputElements *= strechValue;
434     }
435 
436     NN_OPS_CHECK(numInputElements == numOutputElements);
437 
438     output->type = input.type;
439     output->dimensions = outDims;
440     output->offset = input.offset;
441     output->scale = input.scale;
442 
443     return true;
444 }
445 
depthToSpacePrepare(const Shape & input,int32_t blockSize,Shape * output)446 bool depthToSpacePrepare(const Shape& input,
447                          int32_t blockSize,
448                          Shape* output) {
449     NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
450     NN_OPS_CHECK(blockSize > 0);
451 
452     uint32_t batches  = getSizeOfDimension(input, 0);
453     uint32_t height   = getSizeOfDimension(input, 1);
454     uint32_t width    = getSizeOfDimension(input, 2);
455     uint32_t channels = getSizeOfDimension(input, 3);
456 
457     NN_OPS_CHECK(channels % (blockSize * blockSize) == 0);
458     output->type = input.type;
459     output->dimensions = {batches,
460                           height * blockSize,
461                           width * blockSize,
462                           channels / (blockSize * blockSize)};
463     output->offset = input.offset;
464     output->scale = input.scale;
465 
466     return true;
467 }
468 
spaceToDepthPrepare(const Shape & input,int32_t blockSize,Shape * output)469 bool spaceToDepthPrepare(const Shape& input,
470                          int32_t blockSize,
471                          Shape* output) {
472     NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
473     NN_OPS_CHECK(blockSize > 0);
474 
475     uint32_t batches  = getSizeOfDimension(input, 0);
476     uint32_t height   = getSizeOfDimension(input, 1);
477     uint32_t width    = getSizeOfDimension(input, 2);
478     uint32_t channels = getSizeOfDimension(input, 3);
479 
480     NN_OPS_CHECK(height % blockSize == 0);
481     NN_OPS_CHECK(width % blockSize == 0);
482 
483     output->type = input.type;
484     output->dimensions = {batches,
485                           height / blockSize,
486                           width / blockSize,
487                           channels * (blockSize * blockSize)};
488     output->offset = input.offset;
489     output->scale = input.scale;
490 
491     return true;
492 }
493 
embeddingLookupPrepare(const Shape & valueShape,const Shape & lookupShape,Shape * outputShape)494 bool embeddingLookupPrepare(const Shape &valueShape,
495                             const Shape &lookupShape,
496                             Shape *outputShape) {
497     NN_OPS_CHECK(getNumberOfDimensions(valueShape) >= 2);
498     NN_OPS_CHECK(getNumberOfDimensions(lookupShape) == 1);
499 
500     const uint32_t rows     = getSizeOfDimension(valueShape, 0);
501     const uint32_t columns  = getSizeOfDimension(valueShape, 1);
502 
503     const uint32_t lookups  = getSizeOfDimension(lookupShape, 0);
504 
505     outputShape->type = valueShape.type;
506     outputShape->dimensions = { lookups, columns };
507     for (uint32_t i = 2; i < getNumberOfDimensions(valueShape); i++) {
508         outputShape->dimensions.push_back(getSizeOfDimension(valueShape, i));
509     }
510     outputShape->offset = valueShape.offset;
511     outputShape->scale = valueShape.scale;
512 
513     return true;
514 }
515 
hashtableLookupPrepare(const Shape & lookupShape,const Shape & keyShape,const Shape & valueShape,Shape * outputShape,Shape * hitShape)516 bool hashtableLookupPrepare(const Shape &lookupShape,
517                             const Shape &keyShape,
518                             const Shape &valueShape,
519                             Shape *outputShape,
520                             Shape *hitShape) {
521     NN_OPS_CHECK(getNumberOfDimensions(lookupShape) == 1);
522     NN_OPS_CHECK(getNumberOfDimensions(keyShape) == 1);
523     NN_OPS_CHECK(getNumberOfDimensions(valueShape) >= 1);
524 
525     const uint32_t lookups  = getSizeOfDimension(lookupShape, 0);
526     const uint32_t keys     = getSizeOfDimension(keyShape, 0);
527     const uint32_t rows     = getSizeOfDimension(valueShape, 0);
528     outputShape->type = valueShape.type;
529     outputShape->dimensions = { lookups };
530     for (uint32_t i = 1; i < getNumberOfDimensions(valueShape); i++) {
531         outputShape->dimensions.push_back(getSizeOfDimension(valueShape, i));
532     }
533     outputShape->offset = valueShape.offset;
534     outputShape->scale = valueShape.scale;
535 
536     hitShape->type = OperandType::TENSOR_QUANT8_ASYMM;
537     hitShape->dimensions = { lookups };
538     hitShape->offset = 0;
539     hitShape->scale = 1.f;
540 
541     return true;
542 }
543 
padPrepare(const Shape & input,const int32_t * paddingsData,const Shape & paddingsShape,Shape * output)544 bool padPrepare(const Shape& input,
545                 const int32_t* paddingsData,
546                 const Shape& paddingsShape,
547                 Shape* output) {
548     uint32_t numInputDims = getNumberOfDimensions(input);
549 
550     // paddings need to be provided as a 2-D int32 tensor.
551     NN_OPS_CHECK(paddingsShape.type == OperandType::TENSOR_INT32);
552     NN_OPS_CHECK(getNumberOfDimensions(paddingsShape) == 2);
553     NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 0) == numInputDims);
554     NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 1) == 2);
555 
556     std::vector<uint32_t> outDims(numInputDims);
557     for (uint32_t i = 0; i < numInputDims; ++i) {
558         int32_t beforePadding = *paddingsData++;
559         int32_t afterPadding = *paddingsData++;
560         // Pad value has to be greater than equal to 0.
561         NN_OPS_CHECK(beforePadding >= 0 && afterPadding >= 0);
562         outDims[i] = beforePadding + getSizeOfDimension(input, i) + afterPadding;
563     }
564     output->type = input.type;
565     output->dimensions = outDims;
566     output->offset = input.offset;
567     output->scale = input.scale;
568 
569     return true;
570 }
571 
batchToSpacePrepare(const Shape & input,const int32_t * blockSizeData,const Shape & blockSizeShape,Shape * output)572 bool batchToSpacePrepare(const Shape& input,
573                          const int32_t* blockSizeData,
574                          const Shape& blockSizeShape,
575                          Shape* output) {
576     // Only 4D NHWC tensors are supported.
577     NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
578 
579     // blockSize need to be provided as a 1-D int32 tensor.
580     NN_OPS_CHECK(blockSizeShape.type == OperandType::TENSOR_INT32);
581     NN_OPS_CHECK(getNumberOfDimensions(blockSizeShape) == 1);
582     // Only applies to spatial dimensions.
583     NN_OPS_CHECK(getSizeOfDimension(blockSizeShape, 0) == 2);
584 
585     uint32_t batches  = getSizeOfDimension(input, 0);
586     uint32_t height   = getSizeOfDimension(input, 1);
587     uint32_t width    = getSizeOfDimension(input, 2);
588     uint32_t channels = getSizeOfDimension(input, 3);
589 
590     NN_OPS_CHECK(batches % (blockSizeData[0] * blockSizeData[1]) == 0);
591     output->type = input.type;
592     output->dimensions = {batches / (blockSizeData[0] * blockSizeData[1]),
593                           height * blockSizeData[0],
594                           width * blockSizeData[1],
595                           channels};
596     output->offset = input.offset;
597     output->scale = input.scale;
598 
599     return true;
600 }
601 
spaceToBatchPrepare(const Shape & input,const int32_t * blockSizeData,const Shape & blockSizeShape,const int32_t * paddingsData,const Shape & paddingsShape,Shape * output)602 bool spaceToBatchPrepare(const Shape& input,
603                          const int32_t* blockSizeData,
604                          const Shape& blockSizeShape,
605                          const int32_t* paddingsData,
606                          const Shape& paddingsShape,
607                          Shape* output) {
608     // Only 4D NHWC tensors are supported.
609     NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
610 
611     // blockSize need to be provided as a 1-D int32 tensor.
612     NN_OPS_CHECK(blockSizeShape.type == OperandType::TENSOR_INT32);
613     NN_OPS_CHECK(getNumberOfDimensions(blockSizeShape) == 1);
614     // Only applies to spatial dimensions.
615     NN_OPS_CHECK(getSizeOfDimension(blockSizeShape, 0) == 2);
616 
617     // paddings need to be provided as a 2-D int32 tensor.
618     NN_OPS_CHECK(paddingsShape.type == OperandType::TENSOR_INT32);
619     NN_OPS_CHECK(getNumberOfDimensions(paddingsShape) == 2);
620     NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 0) == 2);
621     NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 1) == 2);
622 
623     uint32_t batches  = getSizeOfDimension(input, 0);
624     uint32_t height   = getSizeOfDimension(input, 1);
625     uint32_t width    = getSizeOfDimension(input, 2);
626     uint32_t channels = getSizeOfDimension(input, 3);
627 
628     uint32_t paddedHeight = paddingsData[0] + height + paddingsData[1];
629     uint32_t paddedWidth = paddingsData[2] + width + paddingsData[3];
630 
631     NN_OPS_CHECK(paddedHeight % blockSizeData[0] == 0);
632     NN_OPS_CHECK(paddedWidth % blockSizeData[1] == 0);
633 
634     output->type = input.type;
635     output->dimensions = {batches * (blockSizeData[0] * blockSizeData[1]),
636                           paddedHeight / blockSizeData[0],
637                           paddedWidth / blockSizeData[1],
638                           channels};
639     output->offset = input.offset;
640     output->scale = input.scale;
641 
642     return true;
643 }
644 
squeezePrepare(const Shape & input,const int32_t * squeezeDims,const Shape & squeezeDimsShape,Shape * output)645 bool squeezePrepare(const Shape& input,
646                     const int32_t* squeezeDims,
647                     const Shape& squeezeDimsShape,
648                     Shape* output) {
649     int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(input));
650 
651     // squeezeDims need to be provided as a 1-D int32 tensor.
652     NN_OPS_CHECK(squeezeDimsShape.type == OperandType::TENSOR_INT32);
653     NN_OPS_CHECK(getNumberOfDimensions(squeezeDimsShape) == 1);
654 
655     int32_t squeezeDimsSize = static_cast<int32_t>(getSizeOfDimension(squeezeDimsShape, 0));
656     std::vector<bool> shouldSqueeze(numInputDims, false);
657     int32_t numDimsSqueezed = 0;
658 
659     if (squeezeDimsSize == 0) {
660         // If squeezeDimsSize is 0, all dims with value 1 will be squeezed.
661         for (int32_t idx = 0; idx < numInputDims; ++idx) {
662             if (getSizeOfDimension(input, idx) == 1) {
663                 shouldSqueeze[idx] = true;
664                 ++numDimsSqueezed;
665             }
666         }
667     } else {
668         for (int32_t idx = 0; idx < squeezeDimsSize; ++idx) {
669             int32_t current = squeezeDims[idx] < 0 ? squeezeDims[idx] + numInputDims
670                                                : squeezeDims[idx];
671             NN_OPS_CHECK(current >= 0 && current < numInputDims &&
672                          getSizeOfDimension(input, current) == 1);
673             if (!shouldSqueeze[current]) ++numDimsSqueezed;
674             shouldSqueeze[current] = true;
675       }
676     }
677 
678     // Sets output dimensions.
679     std::vector<uint32_t> outDims(numInputDims - numDimsSqueezed);
680     for (int32_t inIdx = 0, outIdx = 0; inIdx < numInputDims; ++inIdx) {
681         if (!shouldSqueeze[inIdx]) {
682             outDims[outIdx++] = getSizeOfDimension(input, inIdx);
683         }
684     }
685 
686     output->type = input.type;
687     output->dimensions = outDims;
688     output->offset = input.offset;
689     output->scale = input.scale;
690 
691     return true;
692 }
693 
meanPrepare(const Shape & input,const int32_t * axisData,const Shape & axisShape,bool keepDims,Shape * output)694 bool meanPrepare(const Shape& input,
695                  const int32_t* axisData,
696                  const Shape& axisShape,
697                  bool keepDims,
698                  Shape* output) {
699 
700     // perm need to be provided as a 1-D int32 tensor.
701     NN_OPS_CHECK(axisShape.type == OperandType::TENSOR_INT32);
702     NN_OPS_CHECK(getNumberOfDimensions(axisShape) == 1);
703 
704     int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(input));
705     int32_t axisSize = static_cast<int32_t>(getSizeOfDimension(axisShape, 0));
706 
707     // Determines size of output tensor.
708     if (keepDims) {
709         std::vector<uint32_t> outDims(numInputDims);
710         for (int32_t idx = 0; idx < numInputDims; ++idx) {
711             bool isAxis = false;
712             for (int32_t axisIdx = 0; axisIdx < axisSize; ++axisIdx) {
713                 if (axisData[axisIdx] == idx || axisData[axisIdx] + numInputDims == idx) {
714                     isAxis = true;
715                     break;
716                 }
717             }
718             if (isAxis) {
719                 outDims[idx] = 1;
720             } else {
721                 outDims[idx] = getSizeOfDimension(input, idx);
722             }
723         }
724         output->dimensions = outDims;
725     } else {
726         // Calculates size of reducing axis.
727         int32_t numReduceAxis = axisSize;
728         for (int32_t i = 0; i < axisSize; ++i) {
729             int32_t current = axisData[i];
730             if (current < 0) {
731                 current += numInputDims;
732             }
733             NN_OPS_CHECK(current >= 0 && current < numInputDims);
734             for (int32_t j = 0; j < i; ++j) {
735                 int32_t previous = axisData[j];
736                 if (previous < 0) {
737                     previous += numInputDims;
738                 }
739                 if (current == previous) {
740                     --numReduceAxis;
741                     break;
742                 }
743             }
744         }
745         // Determines output dimensions.
746         std::vector<uint32_t> outDims(numInputDims - numReduceAxis);
747         int32_t numSkipAxis = 0;
748         for (int32_t idx = 0; idx < numInputDims; ++idx) {
749             bool isAxis = false;
750             for (int32_t axisIdx = 0; axisIdx < axisSize; ++axisIdx) {
751                 if (axisData[axisIdx] == idx || axisData[axisIdx] + numInputDims == idx) {
752                     ++numSkipAxis;
753                     isAxis = true;
754                     break;
755                 }
756             }
757             if (!isAxis) {
758                 outDims[idx - numSkipAxis] = getSizeOfDimension(input, idx);
759             }
760         }
761         output->dimensions = outDims;
762     }
763 
764     output->type = input.type;
765     output->offset = input.offset;
766     output->scale = input.scale;
767 
768     return true;
769 }
770 
stridedSlicePrepare(const Shape & input,const int32_t * beginData,const Shape & beginShape,const int32_t * endData,const Shape & endShape,const int32_t * stridesData,const Shape & stridesShape,int32_t beginMask,int32_t endMask,int32_t shrinkAxisMask,Shape * output)771 bool stridedSlicePrepare(const Shape& input,
772                          const int32_t* beginData, const Shape& beginShape,
773                          const int32_t* endData, const Shape& endShape,
774                          const int32_t* stridesData, const Shape& stridesShape,
775                          int32_t beginMask, int32_t endMask, int32_t shrinkAxisMask,
776                          Shape* output) {
777     uint32_t numInputDims = getNumberOfDimensions(input);
778     // StridedSlice op only supports 1D-4D input arrays.
779     NN_OPS_CHECK(numInputDims <= 4);
780 
781     NN_OPS_CHECK(getNumberOfDimensions(beginShape) == 1);
782     NN_OPS_CHECK(getNumberOfDimensions(endShape) == 1);
783     NN_OPS_CHECK(getNumberOfDimensions(stridesShape) == 1);
784 
785     NN_OPS_CHECK(getSizeOfDimension(beginShape, 0) == numInputDims);
786     NN_OPS_CHECK(getSizeOfDimension(endShape, 0) == numInputDims);
787     NN_OPS_CHECK(getSizeOfDimension(stridesShape, 0) == numInputDims);
788 
789     NN_OPS_CHECK(beginShape.type == OperandType::TENSOR_INT32);
790     NN_OPS_CHECK(endShape.type == OperandType::TENSOR_INT32);
791     NN_OPS_CHECK(stridesShape.type == OperandType::TENSOR_INT32);
792 
793     // Determine size of output tensor and map indices
794     std::vector<uint32_t> outDims;
795     for (int32_t idx = 0; idx < static_cast<int32_t>(numInputDims); idx++) {
796       int32_t dim = static_cast<int32_t>(getSizeOfDimension(input, idx));
797       int32_t stride = stridesData[idx];
798       // stride value has to be non-zero
799       NN_OPS_CHECK(stride != 0);
800       bool positiveStride = stride > 0;
801 
802       int32_t begin = beginMask & (1 << idx)
803               ? positiveStride ? 0 : dim - 1
804               : ClampedIndex(beginData[idx], dim, positiveStride);
805       int32_t end = endMask & (1 << idx)
806               ? positiveStride ? dim : -1
807               : ClampedIndex(endData[idx], dim, positiveStride);
808 
809       // This is valid for both positive and negative strides
810       int32_t outDim = ceil((end - begin) / static_cast<float>(stride));
811       outDim = outDim < 0 ? 0 : static_cast<uint32_t>(outDim);
812       if (!(shrinkAxisMask & (1 << idx))) {
813           outDims.push_back(outDim);
814       } else {
815           if (outDim != 1) {
816               LOG(ERROR) << "Outdim " << idx << " is " << outDim << ", expected 1";
817               NN_OPS_CHECK(outDim == 1);
818           }
819       }
820     }
821 
822     output->type = input.type;
823     output->dimensions = outDims;
824     output->offset = input.offset;
825     output->scale = input.scale;
826 
827     return true;
828 }
829 
argMinMaxPrepare(const Shape & input,int32_t axis,Shape * output)830 bool argMinMaxPrepare(const Shape& input, int32_t axis, Shape* output) {
831     NN_CHECK(handleNegativeAxis(input, &axis));
832 
833     output->type = OperandType::TENSOR_INT32;
834 
835     // Copy the input dimensions, omitting the axis dimension.
836     output->dimensions.clear();
837     output->dimensions.reserve(getNumberOfDimensions(input) - 1);
838     output->dimensions.insert(output->dimensions.end(),
839                               input.dimensions.begin(),
840                               input.dimensions.begin() + axis);
841     output->dimensions.insert(output->dimensions.end(),
842                               input.dimensions.begin() + axis + 1,
843                               input.dimensions.end());
844 
845     return true;
846 }
847 
splitPrepare(const Shape & input,int32_t axis,int32_t numOutputs,std::vector<Shape> * output)848 bool splitPrepare(const Shape& input, int32_t axis, int32_t numOutputs,
849                   std::vector<Shape>* output) {
850     NN_CHECK(handleNegativeAxis(input, &axis));
851 
852     const int32_t sizeOfAxisToSplit = input.dimensions[axis];
853     NN_OPS_CHECK(sizeOfAxisToSplit % numOutputs == 0);
854     const int32_t sliceSize = sizeOfAxisToSplit / numOutputs;
855 
856     for (int i = 0; i < numOutputs; ++i) {
857         output->at(i).type = input.type;
858         output->at(i).dimensions = input.dimensions;
859         output->at(i).dimensions[axis] = sliceSize;
860         output->at(i).offset = input.offset;
861         output->at(i).scale = input.scale;
862     }
863     return true;
864 }
865 
groupedConvPrepare(const Shape & input,const Shape & filter,const Shape & bias,int32_t padding_left,int32_t padding_right,int32_t padding_top,int32_t padding_bottom,int32_t stride_width,int32_t stride_height,int32_t numGroups,Shape * output)866 bool groupedConvPrepare(const Shape& input, const Shape& filter, const Shape& bias,
867                         int32_t padding_left, int32_t padding_right, int32_t padding_top,
868                         int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
869                         int32_t numGroups, Shape* output) {
870     if (filter.type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
871         NN_OPS_CHECK(input.type == OperandType::TENSOR_QUANT8_ASYMM);
872     } else {
873         NN_OPS_CHECK(input.type == filter.type);
874     }
875     if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
876         NN_OPS_CHECK(bias.type == OperandType::TENSOR_INT32);
877     } else {
878         NN_OPS_CHECK(input.type == bias.type);
879     }
880     NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
881     NN_OPS_CHECK(getNumberOfDimensions(filter) == 4);
882     NN_OPS_CHECK(getNumberOfDimensions(bias) == 1);
883 
884     NN_OPS_CHECK(getSizeOfDimension(filter, 0) == getSizeOfDimension(bias, 0));
885 
886     NN_OPS_CHECK(getSizeOfDimension(filter, 3) * numGroups == getSizeOfDimension(input, 3));
887     NN_OPS_CHECK(getSizeOfDimension(filter, 0) % numGroups == 0);
888 
889     uint32_t channels_out = getSizeOfDimension(filter, 0);
890     uint32_t width = getSizeOfDimension(input, 2);
891     uint32_t height = getSizeOfDimension(input, 1);
892     uint32_t filterWidth = getSizeOfDimension(filter, 2);
893     uint32_t filterHeight = getSizeOfDimension(filter, 1);
894     uint32_t batches = getSizeOfDimension(input, 0);
895 
896     NN_RET_CHECK_GT(static_cast<int32_t>(filterWidth), padding_left);
897     NN_RET_CHECK_GT(static_cast<int32_t>(filterWidth), padding_right);
898     NN_RET_CHECK_GT(static_cast<int32_t>(filterHeight), padding_top);
899     NN_RET_CHECK_GT(static_cast<int32_t>(filterHeight), padding_bottom);
900 
901     uint32_t outWidth =
902             computeOutSize(width, filterWidth, stride_width, padding_left, padding_right);
903     uint32_t outHeight =
904             computeOutSize(height, filterHeight, stride_height, padding_top, padding_bottom);
905 
906     output->type = input.type;
907     output->dimensions = {batches, outHeight, outWidth, channels_out};
908     return true;
909 }
910 
911 } // namespace nn
912 } // namespace android
913