/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the operations.

#define LOG_TAG "Operations"

#include "Operations.h"
#include "CpuOperationUtils.h"

#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"

namespace android {
namespace nn {
#define ANDROID_NN_MACRO_DISPATCH(macro)                                    \
    switch (activation) {                                                   \
        case (int32_t) FusedActivationFunc::NONE:                           \
            macro(kNone);                                                   \
            break;                                                          \
        case (int32_t) FusedActivationFunc::RELU:                           \
            macro(kRelu);                                                   \
            break;                                                          \
        case (int32_t) FusedActivationFunc::RELU1:                          \
            macro(kRelu1);                                                  \
            break;                                                          \
        case (int32_t) FusedActivationFunc::RELU6:                          \
            macro(kRelu6);                                                  \
            break;                                                          \
        default:                                                            \
            LOG(ERROR) << "Unsupported fused activation function type";     \
            return false;                                                   \
    }

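// Element-wise float addition with an optional fused activation. If the input
// shapes differ, the broadcasting kernel is selected through the activation
// dispatch macro above; otherwise the activation is applied as a min/max
// clamp computed by CalculateActivationRangeFloat.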
bool addFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        #define ANDROID_NN_BROADCAST_ADD(activation)                                              \
            tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
                    in1, convertShapeToDims(shape1),                                              \
                    in2, convertShapeToDims(shape2),                                              \
                    out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
        #undef ANDROID_NN_BROADCAST_ADD
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min,
                                      &output_activation_max);

        tflite::optimized_ops::Add(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }

    return true;
}

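// Quantized (uint8) addition. A quantized value q represents the real value
// scale * (q - zeroPoint); the offsets are negated below so they can simply
// be added to the raw quantized values. Both inputs are rescaled onto a
// common scale of 2 * max(scale1, scale2), summed, and the result is rescaled
// to the output scale. The factor of two keeps each real input multiplier at
// or below 0.5, within the range QuantizeMultiplierSmallerThanOne accepts.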
bool addQuant8(const uint8_t* in1, const Shape& shape1,
               const uint8_t* in2, const Shape& shape2,
               int32_t activation,
               uint8_t* out, const Shape& shapeOut) {
    bool needBroadcast = !SameShape(shape1, shape2);

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
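
    // Shifting the raw inputs left by 20 bits before the fixed-point
    // multiplies below preserves precision that rounding the real multipliers
    // to 32-bit fixed point would otherwise lose; the output multiplier
    // divides the shift back out via the (1 << left_shift) factor.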
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale /
            ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier,
                                          &input1_multiplier, &input1_shift)) {
        return false;
    }
    int32_t input2_multiplier;
    int32_t input2_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier,
                                          &input2_multiplier, &input2_shift)) {
        return false;
    }
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier,
                                          &output_multiplier, &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut,
                                  &output_activation_min,
                                  &output_activation_max);

    if (needBroadcast) {
        tflite::optimized_ops::BroadcastAdd(
                left_shift,
                in1, convertShapeToDims(shape1),
                input1_offset, input1_multiplier, input1_shift,
                in2, convertShapeToDims(shape2),
                input2_offset, input2_multiplier, input2_shift,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    } else {
        #define ANDROID_NN_NORMAL_ADD(activation)                                        \
            tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>( \
                    left_shift,                                                          \
                    in1, convertShapeToDims(shape1),                                     \
                    input1_offset, input1_multiplier, input1_shift,                      \
                    in2, convertShapeToDims(shape2),                                     \
                    input2_offset, input2_multiplier, input2_shift,                      \
                    output_offset, output_multiplier, output_shift,                      \
                    output_activation_min, output_activation_max,                        \
                    out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_NORMAL_ADD)
        #undef ANDROID_NN_NORMAL_ADD
    }

    return true;
}

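// Element-wise float multiplication; mirrors addFloat32, using the dispatch
// macro for the broadcast path and a runtime min/max clamp otherwise.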
bool mulFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        #define ANDROID_NN_BROADCAST_MUL(activation)                                              \
            tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
                    in1, convertShapeToDims(shape1),                                              \
                    in2, convertShapeToDims(shape2),                                              \
                    out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
        #undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min,
                                      &output_activation_max);

        tflite::optimized_ops::Mul(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }

    return true;
}

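// Quantized (uint8) multiplication. The real product is
// scale1 * scale2 * (q1 - zp1) * (q2 - zp2), so a single fixed-point
// multiplier of (scale1 * scale2) / outputScale maps the integer product onto
// the output scale. QuantizeMultiplierSmallerThanOne requires that ratio to
// be less than one and fails otherwise.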
bool mulQuant8(const uint8_t* in1, const Shape& shape1,
               const uint8_t* in2, const Shape& shape2,
               int32_t activation,
               uint8_t* out, const Shape& shapeOut) {
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
                                          &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut,
                                  &output_activation_min,
                                  &output_activation_max);

    // The broadcasting kernel also handles the same-shape case, so it is used
    // unconditionally here.
    tflite::optimized_ops::BroadcastMul(
                in1, convertShapeToDims(shape1), input1_offset,
                in2, convertShapeToDims(shape2), input2_offset,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));

    return true;
}

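// Floors each element of a float tensor. Input and output have the same
// shape, so a single Dims object describes both.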
bool floorFloat32(const float* inputData,
                  float* outputData,
                  const Shape& shape) {
    tflite::Dims<4> dim = convertShapeToDims(shape);
    tflite::optimized_ops::Floor(inputData, dim, outputData, dim);
    return true;
}

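// Converts a quantized uint8 tensor to float32 using the offset (zero point)
// and scale stored in its Shape: output[i] = scale * (input[i] - offset).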
bool dequantizeQuant8ToFloat32(const uint8_t* inputData,
                               float* outputData,
                               const Shape& shape) {
    tflite::Dims<4> dim = convertShapeToDims(shape);
    tflite::optimized_ops::Dequantize(inputData, dim,
                                      shape.offset, shape.scale,
                                      outputData, dim);
    return true;
}

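// Element-wise float subtraction. The Sub kernels take the activation range
// as runtime min/max bounds on both paths, so no macro dispatch is needed
// here.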
bool subFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min,
                                  &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        tflite::optimized_ops::BroadcastSub(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    } else {
        tflite::optimized_ops::Sub(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }
    return true;
}

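// Element-wise float division, structured identically to subFloat32.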
bool divFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min,
                                  &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    } else {
        tflite::optimized_ops::Div(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }
    return true;
}

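// Computes the mean over the dimensions listed in `axis`, operating on raw
// byte buffers so a single entry point can serve both float32 and quantized
// uint8 tensors; the element type is recovered from inputShape.type.
// scratchBuffer holds a multi-dimensional index while iterating, and
// resolvedAxis the normalized axis list used by the reference kernel.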
bool meanGeneric(const uint8_t* inputData, const Shape& inputShape,
                 const int32_t* axis, const Shape& axisShape, bool keepDims,
                 uint8_t* outputData, const Shape& outputShape) {
    // Creates a temp index to iterate through input data.
    int32_t* scratchBuffer = new int32_t[getNumberOfDimensions(inputShape)];

    // Creates a temp tensor to store resolved axis given input data.
    int32_t axisSize = static_cast<int32_t>(getSizeOfDimension(axisShape, 0));
    int32_t* resolvedAxis = new int32_t[axisSize];

    bool result = true;
    if (inputShape.type == OperandType::TENSOR_FLOAT32) {
        tflite::reference_ops::Mean<float>(
                const_cast<float*>(reinterpret_cast<const float*>(inputData)),
                reinterpret_cast<const int*>(inputShape.dimensions.data()),
                getNumberOfDimensions(inputShape),
                reinterpret_cast<float*>(outputData),
                reinterpret_cast<const int*>(outputShape.dimensions.data()),
                getNumberOfDimensions(outputShape),
                axis, axisSize, keepDims, scratchBuffer, resolvedAxis);
    } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
        tflite::reference_ops::Mean<uint8_t>(
                const_cast<uint8_t*>(inputData),
                reinterpret_cast<const int*>(inputShape.dimensions.data()),
                getNumberOfDimensions(inputShape),
                outputData,
                reinterpret_cast<const int*>(outputShape.dimensions.data()),
                getNumberOfDimensions(outputShape),
                axis, axisSize, keepDims, scratchBuffer, resolvedAxis);
    } else {
        LOG(ERROR) << "Unsupported data type";
        result = false;
    }
    delete[] scratchBuffer;
    delete[] resolvedAxis;
    return result;
}
} // namespace nn
} // namespace android