/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Contains the implementation of the operations. #define LOG_TAG "Operations" #include "Broadcast.h" #include #include #include #include "IndexedShapeWrapper.h" #include "OperationResolver.h" #include "Tracing.h" #include "nnapi/Types.h" #include "nnapi/Validation.h" #ifdef NN_INCLUDE_CPU_IMPLEMENTATION #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wunused-parameter" #pragma clang diagnostic ignored "-Wsign-compare" #pragma clang diagnostic ignored "-Winvalid-partial-specialization" #include #include #include #include #include #include #pragma clang diagnostic pop #include "CpuOperationUtils.h" #endif // NN_INCLUDE_CPU_IMPLEMENTATION namespace android { namespace nn { namespace broadcast { #ifdef NN_INCLUDE_CPU_IMPLEMENTATION namespace { #define ANDROID_NN_MACRO_DISPATCH(macro) \ switch (activation) { \ case static_cast(FusedActivationFunc::NONE): \ macro(kNone); \ break; \ case static_cast(FusedActivationFunc::RELU): \ macro(kRelu); \ break; \ case static_cast(FusedActivationFunc::RELU1): \ macro(kRelu1); \ break; \ case static_cast(FusedActivationFunc::RELU6): \ macro(kRelu6); \ break; \ default: \ LOG(ERROR) << "Unsupported fused activation function type"; \ return false; \ } using binaryFunctionFloat32 = std::function; bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2, int32_t activation, _Float16* out, const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) { std::vector in1_float32(getNumberOfElements(shape1)); convertFloat16ToFloat32(in1, &in1_float32); std::vector in2_float32(getNumberOfElements(shape2)); convertFloat16ToFloat32(in2, &in2_float32); std::vector out_float32(getNumberOfElements(shapeOut)); operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation, out_float32.data(), shapeOut); convertFloat32ToFloat16(out_float32, out); return true; } bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2, int32_t activation, float* out, const Shape& shapeOut) { NNTRACE_TRANS("addFloat32"); bool needBroadcast = !SameShape(shape1, shape2); if (needBroadcast) { NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd"); #define ANDROID_NN_BROADCAST_ADD(activation) \ tflite::optimized_ops::BroadcastAdd( \ in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \ convertShapeToDims(shapeOut)) ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD) #undef ANDROID_NN_BROADCAST_ADD } else { NNTRACE_COMP_SWITCH("optimized_ops::Add"); #define ANDROID_NN_ADD(activation) \ tflite::optimized_ops::Add( \ in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \ convertShapeToDims(shapeOut)) ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD) #undef ANDROID_NN_ADD } return true; } bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2, int32_t activation, _Float16* out, const Shape& shapeOut) { NNTRACE_TRANS("addFloat16"); return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32); } template bool addQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2, int32_t activation, T* out, const Shape& shapeOut) { NNTRACE_TRANS("addQuant8"); const bool needBroadcast = !SameShape(shape1, shape2); const int32_t input1_offset = -shape1.offset; const int32_t input2_offset = -shape2.offset; const int32_t output_offset = shapeOut.offset; const int left_shift = 20; const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale); const double real_input1_multiplier = shape1.scale / twice_max_input_scale; const double real_input2_multiplier = shape2.scale / twice_max_input_scale; const double real_output_multiplier = twice_max_input_scale / ((1 << left_shift) * shapeOut.scale); int32_t input1_multiplier; int32_t input1_shift; NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier, &input1_shift)); int32_t input2_multiplier; int32_t input2_shift; NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier, &input2_shift)); int32_t output_multiplier; int32_t output_shift; NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier, &output_shift)); int32_t output_activation_min; int32_t output_activation_max; constexpr bool isSignedOp = std::is_same::value; if constexpr (isSignedOp) { CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min, &output_activation_max); } else { CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min, &output_activation_max); } tflite::ArithmeticParams op_params; op_params.left_shift = left_shift; op_params.input1_offset = input1_offset; op_params.input1_multiplier = input1_multiplier; op_params.input1_shift = input1_shift; op_params.input2_offset = input2_offset; op_params.input2_multiplier = input2_multiplier; op_params.input2_shift = input2_shift; op_params.output_offset = output_offset; op_params.output_multiplier = output_multiplier; op_params.output_shift = output_shift; tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params); if (needBroadcast) { if constexpr (isSignedOp) { NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow"); tflite::reference_integer_ops::BroadcastAdd4DSlow( op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2), in2, convertShapeToTflshape(shapeOut), out); } else { NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow"); tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2), in2, convertShapeToTflshape(shapeOut), out); } } else { if constexpr (isSignedOp) { NNTRACE_COMP_SWITCH("optimized_integer_ops::Add"); tflite::optimized_integer_ops::Add(op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2), in2, convertShapeToTflshape(shapeOut), out); } else { NNTRACE_COMP_SWITCH("optimized_ops::Add"); tflite::optimized_ops::Add(op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2), in2, convertShapeToTflshape(shapeOut), out); } } return true; } bool executeInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData, const Shape& bShape, int32_t activation, int32_t* outputData, const Shape& outputShape, int32_t func(int32_t, int32_t)) { NN_RET_CHECK_EQ(static_cast(activation), FusedActivationFunc::NONE); IndexedShapeWrapper aShapeIndexed(aShape); IndexedShapeWrapper bShapeIndexed(bShape); IndexedShapeWrapper outputShapeIndexed(outputShape); std::vector curIndex(outputShape.dimensions.size(), 0); bool lastIndex = false; do { uint32_t outputFlatIndex; NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex)); uint32_t aFlatIndex; NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex)); uint32_t bFlatIndex; NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex)); outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]); NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex)); } while (!lastIndex); return true; } bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2, int32_t activation, float* out, const Shape& shapeOut) { NNTRACE_TRANS("mulFloat32"); bool needBroadcast = !SameShape(shape1, shape2); if (needBroadcast) { NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul"); #define ANDROID_NN_BROADCAST_MUL(activation) \ tflite::optimized_ops::BroadcastMul( \ in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \ convertShapeToDims(shapeOut)) ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL) #undef ANDROID_NN_BROADCAST_MUL } else { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max); NNTRACE_COMP_SWITCH("optimized_ops::Mul"); tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut)); } return true; } bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2, int32_t activation, _Float16* out, const Shape& shapeOut) { NNTRACE_TRANS("mulFloat16"); return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32); } template bool mulQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2, int32_t activation, T* out, const Shape& shapeOut) { NNTRACE_TRANS("mulQuant8"); const int32_t input1_offset = -shape1.offset; const int32_t input2_offset = -shape2.offset; const int32_t output_offset = shapeOut.offset; const double input_product_scale = shape1.scale * shape2.scale; const double real_multiplier = input_product_scale / shapeOut.scale; int32 output_multiplier; int output_shift; NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier, &output_shift)); constexpr bool isSignedOp = std::is_same::value; int32_t output_activation_min; int32_t output_activation_max; if constexpr (isSignedOp) { CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min, &output_activation_max); } else { CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min, &output_activation_max); } tflite::ArithmeticParams op_params; op_params.input1_offset = input1_offset; op_params.input2_offset = input2_offset; op_params.output_offset = output_offset; op_params.output_multiplier = output_multiplier; op_params.output_shift = output_shift; tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params); if constexpr (isSignedOp) { NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastMul4DSlow"); tflite::reference_integer_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2), in2, convertShapeToTflshape(shapeOut), out); } else { NNTRACE_COMP_SWITCH("reference_ops::BroadcastMul4DSlow"); tflite::reference_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2), in2, convertShapeToTflshape(shapeOut), out); } return true; } bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2, int32_t activation, float* out, const Shape& shapeOut) { NNTRACE_TRANS("subFloat32"); NNTRACE_COMP_SWITCH("optimized_ops::Sub"); tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, convertShapeToDims(shapeOut)); // TFLite does not apply activation to broadcast sub. float output_activation_min, output_activation_max; CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max); uint32_t numOutputElements = getNumberOfElements(shapeOut); for (uint32_t i = 0; i < numOutputElements; i++) { out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max); } return true; } bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2, int32_t activation, _Float16* out, const Shape& shapeOut) { NNTRACE_TRANS("subFloat16"); return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32); } template bool subQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2, int32_t activation, T* out, const Shape& shapeOut) { NNTRACE_TRANS("subQuant8"); const int32_t input1_offset = -shape1.offset; const int32_t input2_offset = -shape2.offset; const int32_t output_offset = shapeOut.offset; const int left_shift = 20; const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale); const double real_input1_multiplier = shape1.scale / twice_max_input_scale; const double real_input2_multiplier = shape2.scale / twice_max_input_scale; const double real_output_multiplier = twice_max_input_scale / ((1 << left_shift) * shapeOut.scale); int32_t input1_multiplier; int32_t input1_shift; NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier, &input1_shift)); int32_t input2_multiplier; int32_t input2_shift; NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier, &input2_shift)); // Negate multiplier of the second input, so that we can use Add kernels. input2_multiplier *= -1; int32_t output_multiplier; int32_t output_shift; NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier, &output_shift)); constexpr bool isSignedOp = std::is_same::value; int32_t output_activation_min; int32_t output_activation_max; if constexpr (isSignedOp) { CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min, &output_activation_max); } else { CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min, &output_activation_max); } tflite::ArithmeticParams op_params; op_params.left_shift = left_shift; op_params.input1_offset = input1_offset; op_params.input1_multiplier = input1_multiplier; op_params.input1_shift = input1_shift; op_params.input2_offset = input2_offset; op_params.input2_multiplier = input2_multiplier; op_params.input2_shift = input2_shift; op_params.output_offset = output_offset; op_params.output_multiplier = output_multiplier; op_params.output_shift = output_shift; tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params); // We are using tflite::optimized_ops::BroadcastAdd unconditionally here // because tflite::optimized_ops::Add fails to pass some of the // sub_quantized_different_scales tests. if constexpr (isSignedOp) { NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow"); tflite::reference_integer_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2), in2, convertShapeToTflshape(shapeOut), out); } else { NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow"); tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2), in2, convertShapeToTflshape(shapeOut), out); } return true; } bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2, int32_t activation, float* out, const Shape& shapeOut) { NNTRACE_TRANS("divFloat32"); float output_activation_min, output_activation_max; CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max); bool needBroadcast = !SameShape(shape1, shape2); if (needBroadcast) { NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv"); tflite::optimized_ops::BroadcastDiv( in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut)); } else { NNTRACE_COMP_SWITCH("optimized_ops::Div"); tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut)); } return true; } bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2, int32_t activation, _Float16* out, const Shape& shapeOut) { NNTRACE_TRANS("divFloat16"); return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32); } } // namespace bool prepare(IOperationExecutionContext* context) { Shape input1 = context->getInputShape(kInputTensor1); Shape input2 = context->getInputShape(kInputTensor2); Shape output = context->getOutputShape(kOutputTensor); NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4u); NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4u); NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output)); return context->setOutputShape(kOutputTensor, output); } bool executeAdd(IOperationExecutionContext* context) { // Bypass execution in the case of zero-sized input. if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; switch (context->getInputType(kInputTensor1)) { case OperandType::TENSOR_FLOAT16: return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer<_Float16>(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer<_Float16>(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_FLOAT32: return addFloat32(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM: return addQuant8(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: return addQuant8(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_INT32: return executeInt32(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor), [](int32_t a, int32_t b) { return a + b; }); default: NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD"; } } bool executeMul(IOperationExecutionContext* context) { // Bypass execution in the case of zero-sized input. if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; switch (context->getInputType(kInputTensor1)) { case OperandType::TENSOR_FLOAT16: return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer<_Float16>(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer<_Float16>(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_FLOAT32: return mulFloat32(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM: return mulQuant8(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: return mulQuant8(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_INT32: return executeInt32(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor), [](int32_t a, int32_t b) { return a * b; }); default: NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL"; } } bool executeSub(IOperationExecutionContext* context) { // Bypass execution in the case of zero-sized input. if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; switch (context->getInputType(kInputTensor1)) { case OperandType::TENSOR_FLOAT16: return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer<_Float16>(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer<_Float16>(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_FLOAT32: return subFloat32(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM: return subQuant8(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: return subQuant8(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_INT32: return executeInt32(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor), [](int32_t a, int32_t b) { return a - b; }); default: NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB"; } } bool executeDiv(IOperationExecutionContext* context) { // Bypass execution in the case of zero-sized input. if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; switch (context->getInputType(kInputTensor1)) { case OperandType::TENSOR_FLOAT16: return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer<_Float16>(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer<_Float16>(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_FLOAT32: return divFloat32(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor)); case OperandType::TENSOR_INT32: return executeInt32(context->getInputBuffer(kInputTensor1), context->getInputShape(kInputTensor1), context->getInputBuffer(kInputTensor2), context->getInputShape(kInputTensor2), context->getInputValue(kActivationScalar), context->getOutputBuffer(kOutputTensor), context->getOutputShape(kOutputTensor), [](int32_t a, int32_t b) { // In NNAPI, DIV by zero is undefined, but should not crash. if (b == 0) return 0; int32_t result = a / b; if (a % b != 0 && ((a < 0) != (b < 0))) { // Implement "floor division". --result; } return result; }); default: NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV"; } } #endif // NN_INCLUDE_CPU_IMPLEMENTATION } // namespace broadcast NN_REGISTER_OPERATION_DEFAULT_VALIDATION(ADD, broadcast::prepare, broadcast::executeAdd, .allowZeroSizedInput = true); NN_REGISTER_OPERATION_DEFAULT_VALIDATION(MUL, broadcast::prepare, broadcast::executeMul, .allowZeroSizedInput = true); NN_REGISTER_OPERATION_DEFAULT_VALIDATION(SUB, broadcast::prepare, broadcast::executeSub, .allowZeroSizedInput = true); NN_REGISTER_OPERATION_DEFAULT_VALIDATION(DIV, broadcast::prepare, broadcast::executeDiv, .allowZeroSizedInput = true); } // namespace nn } // namespace android