1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20 
21 #include "HalInterfaces.h"
22 #include "OperationResolver.h"
23 #include "OperationsUtils.h"
24 #include "Tracing.h"
25 
26 namespace android {
27 namespace nn {
28 namespace reduce {
29 
30 constexpr uint32_t kNumInputs = 3;
31 constexpr uint32_t kInputTensor = 0;
32 constexpr uint32_t kInputAxes = 1;
33 constexpr uint32_t kInputKeepDims = 2;
34 
35 constexpr uint32_t kNumOutputs = 1;
36 constexpr uint32_t kOutputTensor = 0;
37 
38 // Values from
39 // https://en.wikipedia.org/wiki/Half-precision_floating-point_format#IEEE_754_half-precision_binary_floating-point_format:_binary16
40 constexpr _Float16 kFloat16Max = 65504;
41 constexpr _Float16 kFloat16Lowest = -kFloat16Max;
42 
43 namespace {
44 
45 template <typename T>
compute(IOperationExecutionContext * context,T init,T func (T,T))46 inline bool compute(IOperationExecutionContext* context, T init, T func(T, T)) {
47     const Shape inputShape = context->getInputShape(kInputTensor);
48     const Shape axesShape = context->getInputShape(kInputAxes);
49     const Shape outputShape = context->getOutputShape(kOutputTensor);
50     const uint32_t inputRank = getNumberOfDimensions(inputShape);
51     const uint32_t numAxes = getNumberOfElements(axesShape);
52     std::vector<int> tempIndex(inputShape.dimensions.size());
53     std::vector<int> tempAxes(numAxes);
54     return tflite::reference_ops::ReduceGeneric<T>(
55             context->getInputBuffer<T>(kInputTensor),
56             reinterpret_cast<const int32_t*>(inputShape.dimensions.data()), inputRank,
57             context->getOutputBuffer<T>(kOutputTensor),
58             reinterpret_cast<const int32_t*>(outputShape.dimensions.data()),
59             outputShape.dimensions.size(), context->getInputBuffer<int32_t>(kInputAxes), numAxes,
60             context->getInputValue<bool8>(kInputKeepDims), tempIndex.data(), tempAxes.data(), init,
61             func);
62 }
63 
64 }  // namespace
65 
validateProdSum(const IOperationValidationContext * context)66 bool validateProdSum(const IOperationValidationContext* context) {
67     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
68     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
69     OperandType inputType = context->getInputType(kInputTensor);
70     NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
71                  inputType == OperandType::TENSOR_FLOAT32)
72             << "Unsupported tensor type for REDUCE_PROD or REDUCE_SUM";
73     NN_RET_CHECK(
74             validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
75     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
76     return validateHalVersion(context, HalVersion::V1_2);
77 }
78 
validateMaxMin(const IOperationValidationContext * context)79 bool validateMaxMin(const IOperationValidationContext* context) {
80     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
81     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
82     OperandType inputType = context->getInputType(kInputTensor);
83     NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
84                  inputType == OperandType::TENSOR_FLOAT32 ||
85                  inputType == OperandType::TENSOR_QUANT8_ASYMM)
86             << "Unsupported tensor type for REDUCE_MAX or REDUCE_MIN";
87     NN_RET_CHECK(
88             validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
89     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
90     return validateHalVersion(context, HalVersion::V1_2);
91 }
92 
validateLogical(const IOperationValidationContext * context)93 bool validateLogical(const IOperationValidationContext* context) {
94     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
95     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
96     OperandType inputType = context->getInputType(kInputTensor);
97     NN_RET_CHECK(inputType == OperandType::TENSOR_BOOL8)
98             << "Unsupported tensor type for REDUCE_ANY or REDUCE_ALL";
99     NN_RET_CHECK(
100             validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
101     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
102     return validateHalVersion(context, HalVersion::V1_2);
103 }
104 
prepare(IOperationExecutionContext * context)105 bool prepare(IOperationExecutionContext* context) {
106     Shape inputShape = context->getInputShape(kInputTensor);
107     const uint32_t inputRank = getNumberOfDimensions(inputShape);
108 
109     std::vector<bool> shouldReduce(inputRank);
110     const int32_t* axes = context->getInputBuffer<int32_t>(kInputAxes);
111     Shape axesShape = context->getInputShape(kInputAxes);
112     NN_RET_CHECK_EQ(getNumberOfDimensions(axesShape), 1u);
113     const uint32_t numAxes = getNumberOfElements(axesShape);
114     for (uint32_t i = 0; i < numAxes; ++i) {
115         int32_t axis = axes[i];
116         NN_RET_CHECK(handleNegativeAxis(inputRank, &axis));
117         shouldReduce[axis] = true;
118     }
119 
120     // Input and output must have the same quantization parameters, etc.
121     Shape outputShape = inputShape;
122     outputShape.dimensions.clear();
123     bool keepDims = context->getInputValue<bool8>(kInputKeepDims);
124     for (uint32_t axis = 0; axis < inputRank; ++axis) {
125         if (shouldReduce[axis]) {
126             if (keepDims) {
127                 outputShape.dimensions.push_back(1);
128             }
129         } else {
130             outputShape.dimensions.push_back(getSizeOfDimension(inputShape, axis));
131         }
132     }
133 
134     return context->setOutputShape(kOutputTensor, outputShape);
135 }
136 
executeProd(IOperationExecutionContext * context)137 bool executeProd(IOperationExecutionContext* context) {
138     switch (context->getInputType(kInputTensor)) {
139         case OperandType::TENSOR_FLOAT16:
140             return compute<_Float16>(context, 1, [](_Float16 a, _Float16 b) { return a * b; });
141         case OperandType::TENSOR_FLOAT32:
142             return compute<float>(context, 1, [](float a, float b) { return a * b; });
143         default:
144             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_PROD";
145     }
146 }
147 
executeSum(IOperationExecutionContext * context)148 bool executeSum(IOperationExecutionContext* context) {
149     switch (context->getInputType(kInputTensor)) {
150         case OperandType::TENSOR_FLOAT16:
151             return compute<_Float16>(context, 0, [](_Float16 a, _Float16 b) { return a + b; });
152         case OperandType::TENSOR_FLOAT32:
153             return compute<float>(context, 0, [](float a, float b) { return a + b; });
154         default:
155             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_SUM";
156     }
157 }
158 
executeMax(IOperationExecutionContext * context)159 bool executeMax(IOperationExecutionContext* context) {
160     switch (context->getInputType(kInputTensor)) {
161         case OperandType::TENSOR_FLOAT16:
162             return compute<_Float16>(context, kFloat16Lowest,
163                                      [](_Float16 a, _Float16 b) { return std::max(a, b); });
164         case OperandType::TENSOR_FLOAT32:
165             return compute<float>(context, std::numeric_limits<float>::lowest(),
166                                   [](float a, float b) { return std::max(a, b); });
167         case OperandType::TENSOR_QUANT8_ASYMM:
168             return compute<uint8_t>(context, std::numeric_limits<uint8_t>::lowest(),
169                                     [](uint8_t a, uint8_t b) { return std::max(a, b); });
170         default:
171             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MAX";
172     }
173 }
174 
executeMin(IOperationExecutionContext * context)175 bool executeMin(IOperationExecutionContext* context) {
176     switch (context->getInputType(kInputTensor)) {
177         case OperandType::TENSOR_FLOAT16:
178             return compute<_Float16>(context, kFloat16Max,
179                                      [](_Float16 a, _Float16 b) { return std::min(a, b); });
180         case OperandType::TENSOR_FLOAT32:
181             return compute<float>(context, std::numeric_limits<float>::max(),
182                                   [](float a, float b) { return std::min(a, b); });
183         case OperandType::TENSOR_QUANT8_ASYMM:
184             return compute<uint8_t>(context, std::numeric_limits<uint8_t>::max(),
185                                     [](uint8_t a, uint8_t b) { return std::min(a, b); });
186         default:
187             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MIN";
188     }
189 }
190 
executeAny(IOperationExecutionContext * context)191 bool executeAny(IOperationExecutionContext* context) {
192     switch (context->getInputType(kInputTensor)) {
193         case OperandType::TENSOR_BOOL8:
194             return compute<bool8>(context, false,
195                                   [](bool8 a, bool8 b) { return static_cast<bool8>(a || b); });
196         default:
197             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ANY";
198     }
199 }
200 
executeAll(IOperationExecutionContext * context)201 bool executeAll(IOperationExecutionContext* context) {
202     switch (context->getInputType(kInputTensor)) {
203         case OperandType::TENSOR_BOOL8:
204             return compute<bool8>(context, true,
205                                   [](bool8 a, bool8 b) { return static_cast<bool8>(a && b); });
206         default:
207             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ALL";
208     }
209 }
210 
211 }  // namespace reduce
212 
213 NN_REGISTER_OPERATION(REDUCE_PROD, "REDUCE_PROD", reduce::validateProdSum, reduce::prepare,
214                       reduce::executeProd);
215 NN_REGISTER_OPERATION(REDUCE_SUM, "REDUCE_SUM", reduce::validateProdSum, reduce::prepare,
216                       reduce::executeSum);
217 NN_REGISTER_OPERATION(REDUCE_MAX, "REDUCE_MAX", reduce::validateMaxMin, reduce::prepare,
218                       reduce::executeMax);
219 NN_REGISTER_OPERATION(REDUCE_MIN, "REDUCE_MIN", reduce::validateMaxMin, reduce::prepare,
220                       reduce::executeMin);
221 NN_REGISTER_OPERATION(REDUCE_ANY, "REDUCE_ANY", reduce::validateLogical, reduce::prepare,
222                       reduce::executeAny);
223 NN_REGISTER_OPERATION(REDUCE_ALL, "REDUCE_ALL", reduce::validateLogical, reduce::prepare,
224                       reduce::executeAll);
225 
226 }  // namespace nn
227 }  // namespace android
228