/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"

#include "HalInterfaces.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"

namespace android {
namespace nn {
namespace reduce {

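// Operand indices shared by all reductions: a data tensor, a 1-D tensor of
// axes to reduce over, and a scalar flag controlling whether reduced
// dimensions are kept with size 1.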
constexpr uint32_t kNumInputs = 3;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kInputAxes = 1;
constexpr uint32_t kInputKeepDims = 2;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

// Values from
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format#IEEE_754_half-precision_binary_floating-point_format:_binary16
constexpr _Float16 kFloat16Max = 65504;
constexpr _Float16 kFloat16Lowest = -kFloat16Max;
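// Spelled out as literals rather than via std::numeric_limits, presumably
// because that template may not be specialized for _Float16 on all toolchains.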

namespace {

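// Shared executor for every reduction: folds the selected axes of the input
// with the supplied binary |func|, starting from the value |init|, by
// delegating to the TF Lite reference implementation of ReduceGeneric.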
template <typename T>
inline bool compute(IOperationExecutionContext* context, T init, T func(T, T)) {
    const Shape inputShape = context->getInputShape(kInputTensor);
    const Shape axesShape = context->getInputShape(kInputAxes);
    const Shape outputShape = context->getOutputShape(kOutputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    const uint32_t numAxes = getNumberOfElements(axesShape);
    std::vector<int> tempIndex(inputShape.dimensions.size());
    std::vector<int> tempAxes(numAxes);
    return tflite::reference_ops::ReduceGeneric<T>(
            context->getInputBuffer<T>(kInputTensor),
            reinterpret_cast<const int32_t*>(inputShape.dimensions.data()), inputRank,
            context->getOutputBuffer<T>(kOutputTensor),
            reinterpret_cast<const int32_t*>(outputShape.dimensions.data()),
            outputShape.dimensions.size(), context->getInputBuffer<int32_t>(kInputAxes), numAxes,
            context->getInputValue<bool8>(kInputKeepDims), tempIndex.data(), tempAxes.data(), init,
            func);
}

}  // namespace
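
// Each validate* function checks the operand counts, the tensor types allowed
// for its reduction family, and that the HAL version is at least 1.2.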
bool validateProdSum(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32)
            << "Unsupported tensor type for REDUCE_PROD or REDUCE_SUM";
    NN_RET_CHECK(
            validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}

bool validateMaxMin(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32 ||
                 inputType == OperandType::TENSOR_QUANT8_ASYMM)
            << "Unsupported tensor type for REDUCE_MAX or REDUCE_MIN";
    NN_RET_CHECK(
            validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}

bool validateLogical(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_BOOL8)
            << "Unsupported tensor type for REDUCE_ANY or REDUCE_ALL";
    NN_RET_CHECK(
            validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}

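// Computes the output shape: reduced dimensions are dropped, or kept with
// size 1 when keep_dims is set. For example, an input of shape [2, 3, 4]
// reduced over axes {1} yields [2, 4], or [2, 1, 4] with keep_dims.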
bool prepare(IOperationExecutionContext* context) {
    Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);

    std::vector<bool> shouldReduce(inputRank);
    const int32_t* axes = context->getInputBuffer<int32_t>(kInputAxes);
    Shape axesShape = context->getInputShape(kInputAxes);
    NN_RET_CHECK_EQ(getNumberOfDimensions(axesShape), 1u);
    const uint32_t numAxes = getNumberOfElements(axesShape);
    for (uint32_t i = 0; i < numAxes; ++i) {
        int32_t axis = axes[i];
        // Axes may be given as negative values; map them into [0, inputRank).
        NN_RET_CHECK(handleNegativeAxis(inputRank, &axis));
        shouldReduce[axis] = true;
    }

    // The output inherits the input's type and quantization parameters; only
    // the dimensions change.
    Shape outputShape = inputShape;
    outputShape.dimensions.clear();
    bool keepDims = context->getInputValue<bool8>(kInputKeepDims);
    for (uint32_t axis = 0; axis < inputRank; ++axis) {
        if (shouldReduce[axis]) {
            if (keepDims) {
                outputShape.dimensions.push_back(1);
            }
        } else {
            outputShape.dimensions.push_back(getSizeOfDimension(inputShape, axis));
        }
    }

    return context->setOutputShape(kOutputTensor, outputShape);
}

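// Each executor dispatches on the input tensor type and starts the fold at
// the operation's identity element: 1 for product, 0 for sum, the type's
// lowest value for max, its largest value for min, false for any, and true
// for all.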
bool executeProd(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, 1, [](_Float16 a, _Float16 b) { return a * b; });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, 1, [](float a, float b) { return a * b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_PROD";
    }
}

bool executeSum(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, 0, [](_Float16 a, _Float16 b) { return a + b; });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, 0, [](float a, float b) { return a + b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_SUM";
    }
}

bool executeMax(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, kFloat16Lowest,
                                     [](_Float16 a, _Float16 b) { return std::max(a, b); });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, std::numeric_limits<float>::lowest(),
                                  [](float a, float b) { return std::max(a, b); });
        case OperandType::TENSOR_QUANT8_ASYMM:
            return compute<uint8_t>(context, std::numeric_limits<uint8_t>::lowest(),
                                    [](uint8_t a, uint8_t b) { return std::max(a, b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MAX";
    }
}

bool executeMin(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, kFloat16Max,
                                     [](_Float16 a, _Float16 b) { return std::min(a, b); });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, std::numeric_limits<float>::max(),
                                  [](float a, float b) { return std::min(a, b); });
        case OperandType::TENSOR_QUANT8_ASYMM:
            return compute<uint8_t>(context, std::numeric_limits<uint8_t>::max(),
                                    [](uint8_t a, uint8_t b) { return std::min(a, b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MIN";
    }
}

bool executeAny(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_BOOL8:
            return compute<bool8>(context, false,
                                  [](bool8 a, bool8 b) { return static_cast<bool8>(a || b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ANY";
    }
}

bool executeAll(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_BOOL8:
            return compute<bool8>(context, true,
                                  [](bool8 a, bool8 b) { return static_cast<bool8>(a && b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ALL";
    }
}

}  // namespace reduce

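// Registration binds each operation code to its validation, shape-preparation,
// and execution entry points in the operation resolver.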
NN_REGISTER_OPERATION(REDUCE_PROD, "REDUCE_PROD", reduce::validateProdSum, reduce::prepare,
                      reduce::executeProd);
NN_REGISTER_OPERATION(REDUCE_SUM, "REDUCE_SUM", reduce::validateProdSum, reduce::prepare,
                      reduce::executeSum);
NN_REGISTER_OPERATION(REDUCE_MAX, "REDUCE_MAX", reduce::validateMaxMin, reduce::prepare,
                      reduce::executeMax);
NN_REGISTER_OPERATION(REDUCE_MIN, "REDUCE_MIN", reduce::validateMaxMin, reduce::prepare,
                      reduce::executeMin);
NN_REGISTER_OPERATION(REDUCE_ANY, "REDUCE_ANY", reduce::validateLogical, reduce::prepare,
                      reduce::executeAny);
NN_REGISTER_OPERATION(REDUCE_ALL, "REDUCE_ALL", reduce::validateLogical, reduce::prepare,
                      reduce::executeAll);

}  // namespace nn
}  // namespace android