/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <memory>
#include <mutex>
#include <vector>

#include "CpuOperationUtils.h"
#include "OperationResolver.h"
#include "Operations.h"

#include "Utils.h"
#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"

#include "Tracing.h"

namespace android {
namespace nn {
namespace conv_2d {

constexpr char kOperationName[] = "CONV_2D";

constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kFilterTensor = 1;
constexpr uint32_t kBiasTensor = 2;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

// If possible, we use this static buffer for the im2col scratch tensor.
constexpr size_t kStaticBufferSize = 1605632;
char static_scratch_buffer[kStaticBufferSize];
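// kStaticBufferSize is 1605632 bytes (~1.5 MiB). Convolutions whose im2col buffer fits are
// served from static_scratch_buffer; larger ones fall back to a heap allocation (see
// ANDROID_NN_CONV_PARAMETERS below).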

// executionMutex is used to protect concurrent access to the static_scratch_buffer
// and other non-threadsafe resources like gemmlowp::GemmContext.
// std::mutex is safe for pthreads on Android.
std::mutex executionMutex;

struct Conv2dParam {
    int32_t padding_left, padding_right;
    int32_t padding_top, padding_bottom;
    int32_t stride_width, stride_height;
    int32_t dilation_width_factor = 1, dilation_height_factor = 1;
    int32_t activation;
    bool useNchw = false;

    bool initialize(const IOperationExecutionContext* context) {
        uint32_t inCount = context->getNumInputs();
        int32_t padding_implicit = 0;
        bool useImplicitPadding = false;
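        // CONV_2D has two argument layouts after the three tensors (input, filter, bias):
        //  * Implicit padding: [padding_scheme, stride_w, stride_h, activation], optionally
        //    followed by [layout (BOOL)] and [dilation_w, dilation_h] -> 7, 8, or 10 inputs.
        //  * Explicit padding: [pad_left, pad_right, pad_top, pad_bottom, stride_w, stride_h,
        //    activation], optionally followed by [layout] and [dilation_w, dilation_h]
        //    -> 10, 11, or 13 inputs.
        // The two cases are distinguished by the input count and the type of input 7.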
        if ((inCount >= 8 && context->getInputType(7) == OperandType::BOOL) || inCount == 7) {
            padding_implicit = context->getInputValue<int32_t>(3);
            stride_width = context->getInputValue<int32_t>(4);
            stride_height = context->getInputValue<int32_t>(5);
            activation = context->getInputValue<int32_t>(6);
            if (inCount >= 8) {
                useNchw = context->getInputValue<bool>(7);
            }
            if (inCount == 10) {
                dilation_width_factor = context->getInputValue<int32_t>(8);
                dilation_height_factor = context->getInputValue<int32_t>(9);
            }
            useImplicitPadding = true;
        } else if (inCount >= 10 && context->getInputType(7) == OperandType::INT32) {
            padding_left = context->getInputValue<int32_t>(3);
            padding_right = context->getInputValue<int32_t>(4);
            padding_top = context->getInputValue<int32_t>(5);
            padding_bottom = context->getInputValue<int32_t>(6);
            stride_width = context->getInputValue<int32_t>(7);
            stride_height = context->getInputValue<int32_t>(8);
            activation = context->getInputValue<int32_t>(9);
            if (inCount >= 11) {
                useNchw = context->getInputValue<bool>(10);
            }
            if (inCount == 13) {
                dilation_width_factor = context->getInputValue<int32_t>(11);
                dilation_height_factor = context->getInputValue<int32_t>(12);
            }
        } else {
            NN_RET_CHECK_FAIL() << "Unsupported input spec for operation " << kOperationName;
        }
        if (useImplicitPadding) {
            Shape inputShape = context->getInputShape(kInputTensor);
            Shape filterShape = context->getInputShape(kFilterTensor);
            int32_t input_width = getSizeOfDimension(inputShape, useNchw ? 3 : 2);
            int32_t input_height = getSizeOfDimension(inputShape, useNchw ? 2 : 1);
            int32_t filter_width = getSizeOfDimension(filterShape, 2);
            int32_t filter_height = getSizeOfDimension(filterShape, 1);
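            // Convert the implicit padding scheme (SAME or VALID) into explicit padding amounts
            // from the input size, stride, dilation, and filter size. For example, under the
            // usual SAME rule, input_width = 224, stride_width = 2, filter_width = 3, and no
            // dilation yields padding_left = 0 and padding_right = 1.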
            calculateExplicitPadding(input_width, stride_width, dilation_width_factor, filter_width,
                                     padding_implicit, &padding_left, &padding_right);
            calculateExplicitPadding(input_height, stride_height, dilation_height_factor,
                                     filter_height, padding_implicit, &padding_top,
                                     &padding_bottom);
        }
        NN_RET_CHECK_GE(padding_left, 0);
        NN_RET_CHECK_GE(padding_right, 0);
        NN_RET_CHECK_GE(padding_top, 0);
        NN_RET_CHECK_GE(padding_bottom, 0);
        NN_RET_CHECK_GT(stride_width, 0);
        NN_RET_CHECK_GT(stride_height, 0);
        NN_RET_CHECK_GT(dilation_width_factor, 0);
        NN_RET_CHECK_GT(dilation_height_factor, 0);
        NN_RET_CHECK_GE(activation, 0);
        return true;
    }
};

#define ANDROID_NN_CONV_PARAMETERS(Type)                                          \
    uint32_t height = getSizeOfDimension(inputShape, 1);                          \
    uint32_t width = getSizeOfDimension(inputShape, 2);                           \
    uint32_t filterHeight = getSizeOfDimension(filterShape, 1);                   \
    uint32_t filterWidth = getSizeOfDimension(filterShape, 2);                    \
    uint32_t outHeight = getSizeOfDimension(outputShape, 1);                      \
    uint32_t outWidth = getSizeOfDimension(outputShape, 2);                       \
    uint32_t inDepth = getSizeOfDimension(inputShape, 3);                         \
                                                                                  \
    uint32_t paddingHeight = (uint32_t)padding_top;                               \
    uint32_t paddingWidth = (uint32_t)padding_left;                               \
                                                                                  \
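    /* im2col buffer dimensions, innermost (sizes[0]) to outermost (sizes[3]):    */ \
    /* flattened filter window (inDepth * filterHeight * filterWidth), output     */ \
    /* width, output height, and batch count.                                     */ \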
    tflite::Dims<4> im2colDim;                                                    \
    im2colDim.sizes[3] = (int)getSizeOfDimension(outputShape, 0);                 \
    im2colDim.sizes[2] = (int)getSizeOfDimension(outputShape, 1);                 \
    im2colDim.sizes[1] = (int)getSizeOfDimension(outputShape, 2);                 \
    im2colDim.sizes[0] = (int)inDepth * filterHeight * filterWidth;               \
                                                                                  \
    im2colDim.strides[0] = 1;                                                     \
    for (int i = 1; i < 4; i++) {                                                 \
        im2colDim.strides[i] = im2colDim.strides[i - 1] * im2colDim.sizes[i - 1]; \
    }                                                                             \
                                                                                  \
    Type* im2colData = nullptr;                                                   \
    uint64_t im2colByteSize = sizeof(Type);                                       \
    std::unique_ptr<Type[]> im2colGuard;                                          \
    for (int i = 0; i < 4; i++) {                                                 \
        im2colByteSize *= im2colDim.sizes[i];                                     \
    }                                                                             \
    /* http://b/77982879, tflite::optimized_ops::Conv uses int for offsets */     \
    if (im2colByteSize >= 0x7fffffff) {                                           \
        LOG(ERROR) << "Conv size is too large, not enough memory";                \
        return false;                                                             \
    }                                                                             \
    if (im2colByteSize <= kStaticBufferSize) {                                    \
        im2colData = reinterpret_cast<Type*>(static_scratch_buffer);              \
    } else {                                                                      \
        im2colData = new (std::nothrow) Type[im2colByteSize / sizeof(Type)];      \
        if (im2colData == nullptr) {                                              \
            LOG(ERROR) << "Conv size is too large, not enough memory";            \
            return false;                                                         \
        }                                                                         \
        im2colGuard.reset(im2colData);                                            \
    }

bool convNhwc(const float* inputData, const Shape& inputShape, const float* filterData,
              const Shape& filterShape, const float* biasData, const Shape& biasShape,
              int32_t padding_left, int32_t padding_right, int32_t padding_top,
              int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
              int32_t dilation_width_factor, int32_t dilation_height_factor, int32_t activation,
              float* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("convFloat32");

    ANDROID_NN_CONV_PARAMETERS(float)

    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    // Prevent concurrent executions that may access the scratch buffer.
    std::unique_lock<std::mutex> lock(executionMutex);
    NNTRACE_COMP_SWITCH("optimized_ops::Conv");
    tflite::optimized_ops::Conv(inputData, convertShapeToDims(inputShape), filterData,
                                convertShapeToDims(filterShape), biasData,
                                convertShapeToDims(biasShape), stride_width, stride_height,
                                dilation_width_factor, dilation_height_factor, paddingWidth,
                                paddingHeight, output_activation_min, output_activation_max,
                                outputData, convertShapeToDims(outputShape), im2colData, im2colDim);
    return true;
}

bool convNhwc(const uint8_t* inputData, const Shape& inputShape, const uint8_t* filterData,
              const Shape& filterShape, const int32_t* biasData, const Shape& biasShape,
              int32_t padding_left, int32_t padding_right, int32_t padding_top,
              int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
              int32_t dilation_width_factor, int32_t dilation_height_factor, int32_t activation,
              uint8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("convQuant8");

    ANDROID_NN_CONV_PARAMETERS(uint8_t)

    int32_t inputOffset = -inputShape.offset;
    int32_t filterOffset = -filterShape.offset;
    int32_t outputOffset = outputShape.offset;

    double real_multiplier = 0.0;
    int32_t output_multiplier = 0;
    int32_t output_shift = 0;
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    NN_RET_CHECK(GetQuantizedConvolutionMultipler(inputShape, filterShape, biasShape, outputShape,
                                                  &real_multiplier));
    int exponent;
    NN_RET_CHECK(QuantizeMultiplier(real_multiplier, &output_multiplier, &exponent));
    output_shift = -exponent;
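    // real_multiplier is inputScale * filterScale / outputScale; QuantizeMultiplier splits it
    // into a fixed-point multiplier and an exponent, which the kernel applies as a right shift
    // (output_shift = -exponent).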
    CalculateActivationRangeUint8(activation, outputShape, &output_activation_min,
                                  &output_activation_max);

    static gemmlowp::GemmContext gemm_context;

    // Prevent concurrent executions that may access the scratch buffer and
    // gemm_context.
    std::unique_lock<std::mutex> lock(executionMutex);
    // Allow gemmlowp to automatically decide how many threads to use.
    gemm_context.set_max_num_threads(0);

    NNTRACE_COMP_SWITCH("optimized_ops::Conv");
    tflite::optimized_ops::Conv(
            inputData, convertShapeToDims(inputShape), inputOffset, filterData,
            convertShapeToDims(filterShape), filterOffset, biasData, convertShapeToDims(biasShape),
            stride_width, stride_height, dilation_width_factor, dilation_height_factor,
            paddingWidth, paddingHeight, outputOffset, output_multiplier, output_shift,
            output_activation_min, output_activation_max, outputData,
            convertShapeToDims(outputShape), im2colData, im2colDim, &gemm_context);
    return true;
}

bool convNhwc(const _Float16* inputData, const Shape& inputShape, const _Float16* filterData,
              const Shape& filterShape, const _Float16* biasData, const Shape& biasShape,
              int32_t padding_left, int32_t padding_right, int32_t padding_top,
              int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
              int32_t dilation_width_factor, int32_t dilation_height_factor, int32_t activation,
              _Float16* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("convFloat16");

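    // Compute in float32: convert the fp16 operands up, run the float32 NHWC kernel above,
    // then convert the result back to fp16.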
    std::vector<float> inputData_float32(getNumberOfElements(inputShape));
    std::vector<float> filterData_float32(getNumberOfElements(filterShape));
    std::vector<float> biasData_float32(getNumberOfElements(biasShape));
    std::vector<float> outputData_float32(getNumberOfElements(outputShape));

    convertFloat16ToFloat32(inputData, &inputData_float32);
    convertFloat16ToFloat32(filterData, &filterData_float32);
    convertFloat16ToFloat32(biasData, &biasData_float32);

    convNhwc(inputData_float32.data(), inputShape, filterData_float32.data(), filterShape,
             biasData_float32.data(), biasShape, padding_left, padding_right, padding_top,
             padding_bottom, stride_width, stride_height, dilation_width_factor,
             dilation_height_factor, activation, outputData_float32.data(), outputShape);
    convertFloat32ToFloat16(outputData_float32, outputData);

    return true;
}

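// Layout-handling wrapper around the NHWC kernels above. When useNchw is true,
// InputWithLayout/OutputWithLayout (from CpuOperationUtils.h) stage the data through
// temporary NHWC buffers and copy the result back into NCHW order on commit().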
template <typename T_Input, typename T_Filter, typename T_Bias>
bool conv(const T_Input* inputData, const Shape& inputShape, const T_Filter* filterData,
          const Shape& filterShape, const T_Bias* biasData, const Shape& biasShape,
          int32_t padding_left, int32_t padding_right, int32_t padding_top, int32_t padding_bottom,
          int32_t stride_width, int32_t stride_height, int32_t dilation_width_factor,
          int32_t dilation_height_factor, int32_t activation, bool useNchw, T_Input* outputData,
          const Shape& outputShape) {
    InputWithLayout<T_Input> input(useNchw);
    OutputWithLayout<T_Input> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(convNhwc(input.getNhwcBuffer(), input.getNhwcShape(), filterData, filterShape,
                          biasData, biasShape, padding_left, padding_right, padding_top,
                          padding_bottom, stride_width, stride_height, dilation_width_factor,
                          dilation_height_factor, activation, output.getNhwcBuffer(),
                          output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

bool convQuant8PerChannelNhwc(const uint8_t* inputData, const Shape& inputShape,
                              const int8_t* filterData, const Shape& filterShape,
                              const float* filterScales, const int32_t* biasData,
                              const Shape& biasShape, int32_t paddingLeft, int32_t paddingRight,
                              int32_t paddingTop, int32_t paddingBottom, int32_t strideWidth,
                              int32_t strideHeight, int32_t dilationWidthFactor,
                              int32_t dilationHeightFactor, int32_t activation, uint8_t* outputData,
                              const Shape& outputShape) {
    NNTRACE_TRANS("convQuant8PerChannel");

    uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    uint32_t inputHeight = getSizeOfDimension(inputShape, 1);
    uint32_t inputWidth = getSizeOfDimension(inputShape, 2);
    uint32_t inputDepth = getSizeOfDimension(inputShape, 3);
    uint32_t filterHeight = getSizeOfDimension(filterShape, 1);
    uint32_t filterWidth = getSizeOfDimension(filterShape, 2);
    uint32_t filterDepth = getSizeOfDimension(filterShape, 3);
    uint32_t outputHeight = getSizeOfDimension(outputShape, 1);
    uint32_t outputWidth = getSizeOfDimension(outputShape, 2);
    uint32_t outputDepth = getSizeOfDimension(outputShape, 3);

    int32_t inputOffset = -inputShape.offset;
    int32_t outputOffset = outputShape.offset;

    auto realMultiplier = std::vector<double>(outputDepth, 0.0);
    auto outputMultiplier = std::vector<int32_t>(outputDepth, 0);
    auto outputShift = std::vector<int32_t>(outputDepth, 0);

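    // With per-channel quantization, each output channel d has its own effective scale,
    // filterScales[d] * inputScale, so a separate fixed-point multiplier and shift are
    // precomputed per channel.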
    for (int i = 0; i < outputDepth; ++i) {
        Shape filterChannelShape = filterShape;
        filterChannelShape.scale = filterScales[i];
        Shape biasChannelShape = biasShape;
        biasChannelShape.scale = filterScales[i] * inputShape.scale;
        NN_RET_CHECK(GetQuantizedConvolutionMultipler(
                inputShape, filterChannelShape, biasChannelShape, outputShape, &realMultiplier[i]));
        int exponent;
        NN_RET_CHECK(QuantizeMultiplier(realMultiplier[i], &outputMultiplier[i], &exponent));
        outputShift[i] = -exponent;
    }

    int32_t output_activation_min = 0, output_activation_max = 0;
    CalculateActivationRangeUint8(activation, outputShape, &output_activation_min,
                                  &output_activation_max);
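    // Reference (non-vectorized) direct convolution over NHWC data: accumulate in int32,
    // then requantize with the per-channel multiplier and clamp to the activation range.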
    const uint8_t* inputBase = inputData;
    uint8_t* outPtr = outputData;
    for (uint32_t b = 0; b < numBatches; b++) {
        for (uint32_t h = 0; h < outputHeight; h++) {
            for (uint32_t w = 0; w < outputWidth; w++) {
                const int8_t* filterBase = filterData;

                for (uint32_t d = 0; d < outputDepth; d++) {
                    int32_t wInputOrigin = static_cast<int32_t>(w) * strideWidth - paddingLeft;
                    int32_t hInputOrigin = static_cast<int32_t>(h) * strideHeight - paddingTop;
                    int32_t sum = 0;

                    for (uint32_t i = 0; i < filterHeight; i++) {
                        for (uint32_t j = 0; j < filterWidth; j++) {
                            for (uint32_t k = 0; k < filterDepth; k++) {
                                int32_t hInput = hInputOrigin +
                                                 dilationHeightFactor * static_cast<int32_t>(i);
                                int32_t wInput = wInputOrigin +
                                                 dilationWidthFactor * static_cast<int32_t>(j);
                                uint32_t dInput = k;
                                if (hInput >= 0 && hInput < static_cast<int32_t>(inputHeight) &&
                                    wInput >= 0 && wInput < static_cast<int32_t>(inputWidth)) {
                                    uint32_t filterIndex =
                                            i * filterWidth * filterDepth + j * filterDepth + k;
                                    uint32_t inputIndex = hInput * inputWidth * inputDepth +
                                                          wInput * inputDepth + dInput;
                                    sum += (static_cast<int32_t>(filterBase[filterIndex])) *
                                           (static_cast<int32_t>(inputBase[inputIndex]) +
                                            inputOffset);
                                }
                            }
                        }
                    }
                    sum += biasData[d];
                    sum = tflite::MultiplyByQuantizedMultiplier(sum, outputMultiplier[d],
                                                                -outputShift[d]);
                    sum += outputOffset;
                    sum = std::max(std::min(sum, output_activation_max), output_activation_min);
                    outPtr[d] = static_cast<uint8_t>(sum);
                    filterBase += filterHeight * filterWidth * filterDepth;
                }
                outPtr += outputDepth;
            }
        }
        inputBase += inputHeight * inputWidth * inputDepth;
    }

    return true;
}

bool convQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape,
                          const int8_t* filterData, const Shape& filterShape,
                          const float* filterScales, const int32_t* biasData,
                          const Shape& biasShape, int32_t paddingLeft, int32_t paddingRight,
                          int32_t paddingTop, int32_t paddingBottom, int32_t strideWidth,
                          int32_t strideHeight, int32_t dilationWidthFactor,
                          int32_t dilationHeightFactor, int32_t activation, bool useNchw,
                          uint8_t* outputData, const Shape& outputShape) {
    InputWithLayout<uint8_t> input(useNchw);
    OutputWithLayout<uint8_t> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(convQuant8PerChannelNhwc(
            input.getNhwcBuffer(), input.getNhwcShape(), filterData, filterShape, filterScales,
            biasData, biasShape, paddingLeft, paddingRight, paddingTop, paddingBottom, strideWidth,
            strideHeight, dilationWidthFactor, dilationHeightFactor, activation,
            output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

#undef ANDROID_NN_CONV_PARAMETERS

}  // namespace

bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputCount = context->getNumInputs();
    auto inputType = context->getInputType(kInputTensor);
    auto filterType = context->getInputType(kFilterTensor);
    std::vector<OperandType> inExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::INT32,
                           OperandType::INT32,          OperandType::INT32,
                           OperandType::INT32};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::INT32,
                           OperandType::INT32,          OperandType::INT32,
                           OperandType::INT32};
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        if (filterType == OperandType::TENSOR_QUANT8_ASYMM ||
            filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
            inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
                               filterType,
                               OperandType::TENSOR_INT32,
                               OperandType::INT32,
                               OperandType::INT32,
                               OperandType::INT32,
                               OperandType::INT32};

            if (filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
                NN_RET_CHECK_EQ(
                        context->getInputExtraParams(kFilterTensor).channelQuant().channelDim, 0)
                        << "Unsupported filter tensor channel dimension for operation "
                        << kOperationName;
            }
        } else {
            NN_RET_CHECK_FAIL() << "Unsupported filter tensor type for operation "
                                << kOperationName;
        }
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName;
    }

    // NeuralNetworks.h specifies that ANEURALNETWORKS_CONV_2D's output must
    // meet "outputScale > inputScale * filterScale" for the operand type
    // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM before API level 29. For other
    // operand types (e.g., ANEURALNETWORKS_TENSOR_FLOAT32), this constraint
    // does not apply, so by default the constraint is met.
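    // For example, inputScale = 0.5f and filterScale = 0.5f require outputScale > 0.25f on
    // HAL V1_0/V1_1; models that violate this are only accepted when validated against V1_2+.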
    bool meetsQuantizedScaleConstraintBeforeV1_2 = true;
    if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        const float inputScale = context->getInputShape(kInputTensor).scale;
        const float filterScale = context->getInputShape(kFilterTensor).scale;
        const float outputScale = context->getOutputShape(kOutputTensor).scale;
        meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * filterScale);
    }

    bool withExplicitPadding = false;
    bool withLayout = false;
    bool withDilation = false;
    if (inputCount >= 8) {
        if (context->getInputType(7) == OperandType::INT32 && inputCount >= 10) {
            std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32);
            inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(),
                                   explicitScalarTypes.end());
            withExplicitPadding = true;
        }
        int inputOffset = withExplicitPadding ? 3 : 0;
        if (inputCount >= 8 + inputOffset) {
            inExpectedTypes.push_back(OperandType::BOOL);
            withLayout = true;
        }
        NN_RET_CHECK_NE(inputCount, 9 + inputOffset)
                << "Provided only one dilation factor value, two values are required for operation "
                << kOperationName;
        if (inputCount == 10 + inputOffset) {
            inExpectedTypes.push_back(OperandType::INT32);
            inExpectedTypes.push_back(OperandType::INT32);
            withDilation = true;
        }
    }

    if (filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL || withLayout || withDilation ||
        !meetsQuantizedScaleConstraintBeforeV1_2) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    } else {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
    }
    return validateInputTypes(context, inExpectedTypes) &&
           validateOutputTypes(context, {inputType});
}

bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape filter = context->getInputShape(kFilterTensor);
    Shape bias = context->getInputShape(kBiasTensor);

    if (filter.type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
        NN_RET_CHECK(input.type == OperandType::TENSOR_QUANT8_ASYMM);
    } else {
        NN_RET_CHECK(input.type == filter.type);
    }
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
        NN_RET_CHECK(bias.type == OperandType::TENSOR_INT32);
    } else {
        NN_RET_CHECK(input.type == bias.type);
    }
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    NN_RET_CHECK_EQ(getNumberOfDimensions(filter), 4);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1);

    Conv2dParam param;
    NN_RET_CHECK(param.initialize(context));

    uint32_t batches = getSizeOfDimension(input, 0);
    uint32_t height = getSizeOfDimension(input, param.useNchw ? 2 : 1);
    uint32_t width = getSizeOfDimension(input, param.useNchw ? 3 : 2);
    uint32_t channels_in = getSizeOfDimension(input, param.useNchw ? 1 : 3);
    uint32_t channels_out = getSizeOfDimension(filter, 0);
    uint32_t filterHeight = getSizeOfDimension(filter, 1);
    uint32_t filterWidth = getSizeOfDimension(filter, 2);
    // Only batches can be zero.
    NN_RET_CHECK_EQ(channels_in, getSizeOfDimension(filter, 3));
    NN_RET_CHECK_EQ(channels_out, getSizeOfDimension(bias, 0));
    NN_RET_CHECK_GT(height, 0);
    NN_RET_CHECK_GT(width, 0);
    NN_RET_CHECK_GT(channels_in, 0);
    NN_RET_CHECK_GT(channels_out, 0);

    int32_t effectiveFilterWidth = (filterWidth - 1) * param.dilation_width_factor + 1;
    int32_t effectiveFilterHeight = (filterHeight - 1) * param.dilation_height_factor + 1;
    NN_RET_CHECK_GT(effectiveFilterWidth, param.padding_left);
    NN_RET_CHECK_GT(effectiveFilterWidth, param.padding_right);
    NN_RET_CHECK_GT(effectiveFilterHeight, param.padding_top);
    NN_RET_CHECK_GT(effectiveFilterHeight, param.padding_bottom);

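    // Standard convolution output size: out = (in + padHead + padTail - effectiveFilter) / stride + 1.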
    uint32_t outWidth =
            computeOutSize(width, filterWidth, param.stride_width, param.dilation_width_factor,
                           param.padding_left, param.padding_right);
    uint32_t outHeight =
            computeOutSize(height, filterHeight, param.stride_height, param.dilation_height_factor,
                           param.padding_top, param.padding_bottom);

    Shape output = context->getOutputShape(kOutputTensor);
    output.type = input.type;
    if (param.useNchw) {
        output.dimensions = {batches, channels_out, outHeight, outWidth};
    } else {
        output.dimensions = {batches, outHeight, outWidth, channels_out};
    }
    return context->setOutputShape(kOutputTensor, output);
}

bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    Conv2dParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT32:
            return conv(context->getInputBuffer<float>(kInputTensor),
                        context->getInputShape(kInputTensor),
                        context->getInputBuffer<float>(kFilterTensor),
                        context->getInputShape(kFilterTensor),
                        context->getInputBuffer<float>(kBiasTensor),
                        context->getInputShape(kBiasTensor), param.padding_left,
                        param.padding_right, param.padding_top, param.padding_bottom,
                        param.stride_width, param.stride_height, param.dilation_width_factor,
                        param.dilation_height_factor, param.activation, param.useNchw,
                        context->getOutputBuffer<float>(kOutputTensor),
                        context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT16:
            return conv(context->getInputBuffer<_Float16>(kInputTensor),
                        context->getInputShape(kInputTensor),
                        context->getInputBuffer<_Float16>(kFilterTensor),
                        context->getInputShape(kFilterTensor),
                        context->getInputBuffer<_Float16>(kBiasTensor),
                        context->getInputShape(kBiasTensor), param.padding_left,
                        param.padding_right, param.padding_top, param.padding_bottom,
                        param.stride_width, param.stride_height, param.dilation_width_factor,
                        param.dilation_height_factor, param.activation, param.useNchw,
                        context->getOutputBuffer<_Float16>(kOutputTensor),
                        context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            if (context->getInputType(kFilterTensor) ==
                OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
                return convQuant8PerChannel(
                        context->getInputBuffer<uint8_t>(kInputTensor),
                        context->getInputShape(kInputTensor),
                        context->getInputBuffer<int8_t>(kFilterTensor),
                        context->getInputShape(kFilterTensor),
                        context->getInputExtraParams(kFilterTensor).channelQuant().scales.data(),
                        context->getInputBuffer<int32_t>(kBiasTensor),
                        context->getInputShape(kBiasTensor), param.padding_left,
                        param.padding_right, param.padding_top, param.padding_bottom,
                        param.stride_width, param.stride_height, param.dilation_width_factor,
                        param.dilation_height_factor, param.activation, param.useNchw,
                        context->getOutputBuffer<uint8_t>(kOutputTensor),
                        context->getOutputShape(kOutputTensor));
            } else if (context->getInputType(kFilterTensor) == OperandType::TENSOR_QUANT8_ASYMM) {
                return conv(context->getInputBuffer<uint8_t>(kInputTensor),
                            context->getInputShape(kInputTensor),
                            context->getInputBuffer<uint8_t>(kFilterTensor),
                            context->getInputShape(kFilterTensor),
                            context->getInputBuffer<int32_t>(kBiasTensor),
                            context->getInputShape(kBiasTensor), param.padding_left,
                            param.padding_right, param.padding_top, param.padding_bottom,
                            param.stride_width, param.stride_height, param.dilation_width_factor,
                            param.dilation_height_factor, param.activation, param.useNchw,
                            context->getOutputBuffer<uint8_t>(kOutputTensor),
                            context->getOutputShape(kOutputTensor));
            } else {
                NN_RET_CHECK_FAIL() << "Unsupported filter type for operation " << kOperationName;
            }
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}

}  // namespace conv_2d

NN_REGISTER_OPERATION(CONV_2D, conv_2d::kOperationName, conv_2d::validate, conv_2d::prepare,
                      conv_2d::execute, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android