/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the operations.

#define LOG_TAG "Operations"

#include "Operations.h"
#include "CpuOperationUtils.h"

#include <algorithm>  // std::max

#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"

namespace android {
namespace nn {

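// Expands `macro` with the tflite::FusedActivationFunctionType enumerator that
// matches the runtime `activation` value. The optimized kernels used below take
// the fused activation as a template parameter, so this switch is how a runtime
// value is mapped to the right instantiation; unknown values fail the operation.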
#define ANDROID_NN_MACRO_DISPATCH(macro)                                  \
    switch (activation) {                                                 \
        case (int32_t) FusedActivationFunc::NONE:                         \
            macro(kNone);                                                 \
            break;                                                        \
        case (int32_t) FusedActivationFunc::RELU:                         \
            macro(kRelu);                                                 \
            break;                                                        \
        case (int32_t) FusedActivationFunc::RELU1:                        \
            macro(kRelu1);                                                \
            break;                                                        \
        case (int32_t) FusedActivationFunc::RELU6:                        \
            macro(kRelu6);                                                \
            break;                                                        \
        default:                                                          \
            LOG(ERROR) << "Unsupported fused activation function type";   \
            return false;                                                 \
    }

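// Elementwise float addition with a fused activation. The broadcast kernel
// takes the activation as a template parameter, hence the dispatch macro; the
// same-shape kernel instead takes the clamping bounds as runtime arguments.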
bool addFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        #define ANDROID_NN_BROADCAST_ADD(activation)                                              \
            tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
                    in1, convertShapeToDims(shape1),                                               \
                    in2, convertShapeToDims(shape2),                                               \
                    out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
        #undef ANDROID_NN_BROADCAST_ADD
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min,
                                      &output_activation_max);

        tflite::optimized_ops::Add(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }

    return true;
}

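// Quantized (uint8) addition, following the standard TFLite quantized-add
// scheme. Real values relate to quantized ones by
// real = scale * (quantized - offset). Both inputs are rescaled to a common
// real scale -- twice the larger input scale, so each rescaled operand uses at
// most half the range and their sum cannot overflow it -- after shifting left
// by `left_shift` bits to preserve precision through the fixed-point
// multiplies. The sum is then rescaled to the output scale. Each real
// multiplier is < 1 and is encoded as a 32-bit fixed-point multiplier plus a
// right shift by QuantizeMultiplierSmallerThanOne.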
bool addQuant8(const uint8_t* in1, const Shape& shape1,
               const uint8_t* in2, const Shape& shape2,
               int32_t activation,
               uint8_t* out, const Shape& shapeOut) {
    bool needBroadcast = !SameShape(shape1, shape2);

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale /
            ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier,
                                          &input1_multiplier, &input1_shift)) {
        return false;
    }
    int32_t input2_multiplier;
    int32_t input2_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier,
                                          &input2_multiplier, &input2_shift)) {
        return false;
    }
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier,
                                          &output_multiplier, &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut,
                                  &output_activation_min,
                                  &output_activation_max);

    if (needBroadcast) {
        tflite::optimized_ops::BroadcastAdd(
                left_shift,
                in1, convertShapeToDims(shape1),
                input1_offset, input1_multiplier, input1_shift,
                in2, convertShapeToDims(shape2),
                input2_offset, input2_multiplier, input2_shift,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    } else {
        #define ANDROID_NN_NORMAL_ADD(activation)                                        \
            tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>( \
                    left_shift,                                                          \
                    in1, convertShapeToDims(shape1),                                     \
                    input1_offset, input1_multiplier, input1_shift,                      \
                    in2, convertShapeToDims(shape2),                                     \
                    input2_offset, input2_multiplier, input2_shift,                      \
                    output_offset, output_multiplier, output_shift,                      \
                    output_activation_min, output_activation_max,                        \
                    out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_NORMAL_ADD)
        #undef ANDROID_NN_NORMAL_ADD
    }

    return true;
}

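// Elementwise float multiplication; structured like addFloat32, with the
// broadcast kernel dispatched on the fused activation.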
bool mulFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        #define ANDROID_NN_BROADCAST_MUL(activation)                                              \
            tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
                    in1, convertShapeToDims(shape1),                                               \
                    in2, convertShapeToDims(shape2),                                               \
                    out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
        #undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min,
                                      &output_activation_max);

        tflite::optimized_ops::Mul(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }

    return true;
}

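// Quantized (uint8) multiplication. The product of two real values carries
// scale shape1.scale * shape2.scale, so a single multiplier
// input_product_scale / shapeOut.scale (expected to be < 1) rescales the raw
// integer product to the output scale.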
bool mulQuant8(const uint8_t* in1, const Shape& shape1,
               const uint8_t* in2, const Shape& shape2,
               int32_t activation,
               uint8_t* out, const Shape& shapeOut) {
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
                                          &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut,
                                  &output_activation_min,
                                  &output_activation_max);

    // The broadcast kernel also handles the same-shape case, so it is used
    // unconditionally here.
    tflite::optimized_ops::BroadcastMul(
            in1, convertShapeToDims(shape1), input1_offset,
            in2, convertShapeToDims(shape2), input2_offset,
            output_offset, output_multiplier, output_shift,
            output_activation_min, output_activation_max,
            out, convertShapeToDims(shapeOut));

    return true;
}

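// Elementwise floor. Input and output share the same shape, so one Dims
// descriptor is reused for both.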
bool floorFloat32(const float* inputData,
                  float* outputData,
                  const Shape& shape) {
    tflite::Dims<4> dim = convertShapeToDims(shape);
    tflite::optimized_ops::Floor(inputData, dim, outputData, dim);
    return true;
}

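// Converts quantized values back to floats:
// output = (input - shape.offset) * shape.scale.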
bool dequantizeQuant8ToFloat32(const uint8_t* inputData,
                               float* outputData,
                               const Shape& shape) {
    tflite::Dims<4> dim = convertShapeToDims(shape);
    tflite::optimized_ops::Dequantize(inputData, dim,
                                      shape.offset, shape.scale,
                                      outputData, dim);
    return true;
}

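// Elementwise float subtraction. Both the broadcast and same-shape kernels
// take the activation clamping bounds as runtime arguments, so no dispatch
// macro is needed here.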
bool subFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min,
                                  &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        tflite::optimized_ops::BroadcastSub(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    } else {
        tflite::optimized_ops::Sub(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }
    return true;
}

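// Elementwise float division; mirrors subFloat32.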
bool divFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min,
                                  &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    } else {
        tflite::optimized_ops::Div(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }
    return true;
}

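// Computes the mean along the given axes, for float or asymmetric-quant8
// tensors. scratchBuffer holds a multi-dimensional index used to walk the
// input; resolvedAxis receives the axis list after the kernel resolves it.
// Both are freed on every path before returning.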
bool meanGeneric(const uint8_t* inputData, const Shape& inputShape,
                 const int32_t* axis, const Shape& axisShape, bool keepDims,
                 uint8_t* outputData, const Shape& outputShape) {
    // Creates a temp index to iterate through input data.
    int32_t* scratchBuffer = new int32_t[getNumberOfDimensions(inputShape)];

    // Creates a temp tensor to store resolved axis given input data.
    int32_t axisSize = static_cast<int32_t>(getSizeOfDimension(axisShape, 0));
    int32_t* resolvedAxis = new int32_t[axisSize];

    bool result = true;
    if (inputShape.type == OperandType::TENSOR_FLOAT32) {
        tflite::reference_ops::Mean<float>(
                const_cast<float*>(reinterpret_cast<const float*>(inputData)),
                reinterpret_cast<const int*>(inputShape.dimensions.data()),
                getNumberOfDimensions(inputShape),
                reinterpret_cast<float*>(outputData),
                reinterpret_cast<const int*>(outputShape.dimensions.data()),
                getNumberOfDimensions(outputShape),
                axis, axisSize, keepDims, scratchBuffer, resolvedAxis);
    } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
        tflite::reference_ops::Mean<uint8_t>(
                const_cast<uint8_t*>(inputData),
                reinterpret_cast<const int*>(inputShape.dimensions.data()),
                getNumberOfDimensions(inputShape),
                outputData,
                reinterpret_cast<const int*>(outputShape.dimensions.data()),
                getNumberOfDimensions(outputShape),
                axis, axisSize, keepDims, scratchBuffer, resolvedAxis);
    } else {
        LOG(ERROR) << "Unsupported data type";
        result = false;
    }
    delete[] scratchBuffer;
    delete[] resolvedAxis;
    return result;
}
} // namespace nn
} // namespace android