/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "PRelu.h"

#include <algorithm>
#include <functional>
#include <vector>

#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "OperationsExecutionUtils.h"
#include "Tracing.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wsign-compare"
#pragma clang diagnostic ignored "-Winvalid-partial-specialization"
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#pragma clang diagnostic pop
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace prelu {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
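// Applies |func| element-wise over two tensors, resolving NumPy-style broadcasting:
// every output coordinate is mapped back to a flat index into each input through
// IndexedShapeWrapper, so the smaller operand is logically tiled to the output shape.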
template <typename T>
inline bool eval(const std::function<T(const T&, const T&)>& func, const T* aData,
                 const Shape& aShape, const T* bData, const Shape& bShape, T* outputData,
                 const Shape& outputShape) {
    IndexedShapeWrapper aShapeIndexed(aShape);
    IndexedShapeWrapper bShapeIndexed(bShape);
    IndexedShapeWrapper outputShapeIndexed(outputShape);
    std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
    bool lastIndex = false;
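    // Walk all output coordinates in sequence; nextIndexInplace() advances curIndex
    // and reports via lastIndex once the final coordinate has been consumed.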
    do {
        uint32_t outputFlatIndex;
        NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
        uint32_t aFlatIndex;
        NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
        uint32_t bFlatIndex;
        NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));

        outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);

        NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
    } while (!lastIndex);
    return true;
}

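// Quantized PRelu. Positive inputs only need requantization from the input scale to
// the output scale; negative inputs are multiplied by alpha in the integer domain,
// so their result carries the combined input*alpha scale before requantization.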
template <typename T>
bool evalQuant8(const T* aData, const Shape& aShape, const T* bData, const Shape& bShape,
                T* outputData, const Shape& outputShape) {
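    // Fold each operand's zero point into an additive offset: the input and alpha
    // zero points are subtracted before arithmetic, the output zero point added back.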
    const int32_t input_offset = -aShape.offset;
    const int32_t alpha_offset = -bShape.offset;
    const int32_t output_offset = outputShape.offset;
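    // Real-valued rescaling factors for the two branches, converted below into
    // fixed-point multiplier/shift pairs usable in integer-only arithmetic.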
    const double input_product_scale = aShape.scale * bShape.scale;
    const double real_multiplier_pos = aShape.scale / outputShape.scale;
    const double real_multiplier_neg = input_product_scale / outputShape.scale;
    int32_t output_multiplier_pos, output_shift_pos;
    int32_t output_multiplier_neg, output_shift_neg;
    tflite::QuantizeMultiplier(real_multiplier_pos, &output_multiplier_pos, &output_shift_pos);
    tflite::QuantizeMultiplier(real_multiplier_neg, &output_multiplier_neg, &output_shift_neg);
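    // Per-element kernel: requantize non-negative values directly; for negative
    // values, multiply by the alpha integer value first, then requantize, and
    // finally saturate the result into T's representable range.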
    return eval<T>(
            [&](const T& val1, const T& val2) -> T {
                const int32_t input = input_offset + static_cast<int32_t>(val1);
                int32_t output_val;
                if (input >= 0) {
                    output_val =
                            output_offset + tflite::MultiplyByQuantizedMultiplier(
                                                    input, output_multiplier_pos, output_shift_pos);
                } else {
                    const int32_t alpha = alpha_offset + static_cast<int32_t>(val2);
                    output_val = output_offset +
                                 tflite::MultiplyByQuantizedMultiplier(
                                         input * alpha, output_multiplier_neg, output_shift_neg);
                }
                return saturateCast<T>(output_val);
            },
            aData, aShape, bData, bShape, outputData, outputShape);
}

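// Checks that input and alpha agree in type, then derives the output shape by
// broadcasting the two input shapes against each other.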
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape alpha = context->getInputShape(kAlphaTensor);
    NN_RET_CHECK(input.type == alpha.type);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK(calculateBroadcastedShape(input, alpha, &output));
    return context->setOutputShape(kOutputTensor, output);
}

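// Dispatches to the kernel matching the input tensor's operand type.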
bool execute(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return eval<_Float16>(
                    [](const _Float16& val1, const _Float16& val2) -> _Float16 {
                        return val1 >= 0.0f ? val1 : val1 * val2;
                    },
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kAlphaTensor),
                    context->getInputShape(kAlphaTensor),
                    context->getOutputBuffer<_Float16>(kOutputTensor),
                    context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return eval<float>(
                    [](const float& val1, const float& val2) -> float {
                        return val1 >= 0.0f ? val1 : val1 * val2;
                    },
                    context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kAlphaTensor),
                    context->getInputShape(kAlphaTensor),
                    context->getOutputBuffer<float>(kOutputTensor),
                    context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM: {
            return evalQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<uint8_t>(kAlphaTensor),
                              context->getInputShape(kAlphaTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        }
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: {
            return evalQuant8(context->getInputBuffer<int8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<int8_t>(kAlphaTensor),
                              context->getInputShape(kAlphaTensor),
                              context->getOutputBuffer<int8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        }
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace prelu

NN_REGISTER_OPERATION_DEFAULT_VALIDATION(PRELU, prelu::prepare, prelu::execute);

}  // namespace nn
}  // namespace android