1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include <cmath>
20
21 #include "HalInterfaces.h"
22 #include "OperationResolver.h"
23 #include "OperationsUtils.h"
24 #include "Tracing.h"
25
26 namespace android {
27 namespace nn {
28 namespace elementwise {
29
// Number of inputs every operation in this file is validated against, and the
// index of that single input tensor.
constexpr uint32_t kNumInputs{1};
constexpr uint32_t kInputTensor{0};

// Number of outputs every operation in this file is validated against, and the
// index of that single output tensor.
constexpr uint32_t kNumOutputs{1};
constexpr uint32_t kOutputTensor{0};
35
36 namespace {
37
38 using namespace hal;
39
40 template <typename IntermediateType, typename T>
compute(IntermediateType func (IntermediateType),const T * input,const Shape & shape,T * output)41 inline bool compute(IntermediateType func(IntermediateType), const T* input, const Shape& shape,
42 T* output) {
43 const auto size = getNumberOfElements(shape);
44 for (uint32_t i = 0; i < size; ++i) {
45 output[i] = static_cast<T>(func(static_cast<IntermediateType>(input[i])));
46 }
47 return true;
48 }
49
execute(IOperationExecutionContext * context,float func (float))50 bool execute(IOperationExecutionContext* context, float func(float)) {
51 switch (context->getInputType(kInputTensor)) {
52 case OperandType::TENSOR_FLOAT16:
53 return compute<float, _Float16>(func, context->getInputBuffer<_Float16>(kInputTensor),
54 context->getInputShape(kInputTensor),
55 context->getOutputBuffer<_Float16>(kOutputTensor));
56 case OperandType::TENSOR_FLOAT32:
57 return compute<float, float>(func, context->getInputBuffer<float>(kInputTensor),
58 context->getInputShape(kInputTensor),
59 context->getOutputBuffer<float>(kOutputTensor));
60 default:
61 NN_RET_CHECK_FAIL() << "Unsupported tensor type for elementwise operation";
62 }
63 }
64
65 } // namespace
66
executeAbs(IOperationExecutionContext * context)67 bool executeAbs(IOperationExecutionContext* context) {
68 switch (context->getInputType(kInputTensor)) {
69 case OperandType::TENSOR_FLOAT16:
70 return compute<float, _Float16>(std::abs,
71 context->getInputBuffer<_Float16>(kInputTensor),
72 context->getInputShape(kInputTensor),
73 context->getOutputBuffer<_Float16>(kOutputTensor));
74 case OperandType::TENSOR_FLOAT32:
75 return compute<float, float>(std::abs, context->getInputBuffer<float>(kInputTensor),
76 context->getInputShape(kInputTensor),
77 context->getOutputBuffer<float>(kOutputTensor));
78 case OperandType::TENSOR_INT32:
79 return compute<int32_t, int32_t>(std::abs,
80 context->getInputBuffer<int32_t>(kInputTensor),
81 context->getInputShape(kInputTensor),
82 context->getOutputBuffer<int32_t>(kOutputTensor));
83 default:
84 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ABS";
85 }
86 }
87
validate(const IOperationValidationContext * context)88 bool validate(const IOperationValidationContext* context) {
89 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
90 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
91 OperandType inputType = context->getInputType(kInputTensor);
92 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
93 inputType == OperandType::TENSOR_FLOAT32)
94 << "Unsupported tensor type for elementwise operation";
95 NN_RET_CHECK(validateInputTypes(context, {inputType}));
96 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
97 return validateHalVersion(context, HalVersion::V1_2);
98 }
99
validateAbs(const IOperationValidationContext * context)100 bool validateAbs(const IOperationValidationContext* context) {
101 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
102 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
103 OperandType inputType = context->getInputType(kInputTensor);
104 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
105 inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32)
106 << "Unsupported tensor type for operation ABS";
107 NN_RET_CHECK(validateInputTypes(context, {inputType}));
108 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
109 return validateHalVersion(context, (inputType == OperandType::TENSOR_INT32 ? HalVersion::V1_3
110 : HalVersion::V1_2));
111 }
112
validateFloor(const IOperationValidationContext * context)113 bool validateFloor(const IOperationValidationContext* context) {
114 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
115 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
116
117 OperandType inputType = context->getInputType(kInputTensor);
118 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
119 inputType == OperandType::TENSOR_FLOAT32)
120 << "Unsupported tensor type for operation FLOOR";
121 NN_RET_CHECK(validateInputTypes(context, {inputType}));
122 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
123
124 const Shape& input = context->getInputShape(kInputTensor);
125 if (hasKnownRank(input)) {
126 NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
127 }
128
129 return validateHalVersion(
130 context,
131 (inputType == OperandType::TENSOR_FLOAT16 ? HalVersion::V1_2 : HalVersion::V1_0));
132 }
133
prepare(IOperationExecutionContext * context)134 bool prepare(IOperationExecutionContext* context) {
135 Shape input = context->getInputShape(kInputTensor);
136 Shape output = context->getOutputShape(kOutputTensor);
137 NN_RET_CHECK(SetShape(input, &output));
138 return context->setOutputShape(kOutputTensor, output);
139 }
140
prepareFloor(IOperationExecutionContext * context)141 bool prepareFloor(IOperationExecutionContext* context) {
142 Shape input = context->getInputShape(kInputTensor);
143 Shape output = context->getOutputShape(kOutputTensor);
144 NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
145 NN_RET_CHECK(SetShape(input, &output));
146 return context->setOutputShape(kOutputTensor, output);
147 }
148
executeExp(IOperationExecutionContext * context)149 bool executeExp(IOperationExecutionContext* context) {
150 return execute(context, std::exp);
151 }
152
executeFloor(IOperationExecutionContext * context)153 bool executeFloor(IOperationExecutionContext* context) {
154 return execute(context, std::floor);
155 }
156
executeLog(IOperationExecutionContext * context)157 bool executeLog(IOperationExecutionContext* context) {
158 return execute(context, std::log);
159 }
160
executeRsqrt(IOperationExecutionContext * context)161 bool executeRsqrt(IOperationExecutionContext* context) {
162 return execute(context, [](float x) { return 1.f / std::sqrt(x); });
163 }
164
executeSin(IOperationExecutionContext * context)165 bool executeSin(IOperationExecutionContext* context) {
166 return execute(context, std::sin);
167 }
168
executeSqrt(IOperationExecutionContext * context)169 bool executeSqrt(IOperationExecutionContext* context) {
170 return execute(context, std::sqrt);
171 }
172
173 } // namespace elementwise
174
// Registration of the elementwise operations defined above. Each entry names
// the operation together with its validate, prepare and execute functions.
//
// ABS has its own validator and executor because it also accepts TENSOR_INT32.
NN_REGISTER_OPERATION(ABS, "ABS", elementwise::validateAbs, elementwise::prepare,
                      elementwise::executeAbs);
NN_REGISTER_OPERATION(EXP, "EXP", elementwise::validate, elementwise::prepare,
                      elementwise::executeExp);
// FLOOR has its own validator and prepare step because of its rank limit of 4.
NN_REGISTER_OPERATION(FLOOR, "FLOOR", elementwise::validateFloor, elementwise::prepareFloor,
                      elementwise::executeFloor);
NN_REGISTER_OPERATION(LOG, "LOG", elementwise::validate, elementwise::prepare,
                      elementwise::executeLog);
NN_REGISTER_OPERATION(RSQRT, "RSQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeRsqrt);
NN_REGISTER_OPERATION(SIN, "SIN", elementwise::validate, elementwise::prepare,
                      elementwise::executeSin);
NN_REGISTER_OPERATION(SQRT, "SQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeSqrt);
189
190 } // namespace nn
191 } // namespace android
192