1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include <cmath>
20 
21 #include "HalInterfaces.h"
22 #include "OperationResolver.h"
23 #include "OperationsUtils.h"
24 #include "Tracing.h"
25 
26 namespace android {
27 namespace nn {
28 namespace elementwise {
29 
// Operand indices shared by every single-input elementwise operation:
// one input tensor and one output tensor, both at index 0.
constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputTensor = 0;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
35 
36 namespace {
37 
38 using namespace hal;
39 
40 template <typename IntermediateType, typename T>
compute(IntermediateType func (IntermediateType),const T * input,const Shape & shape,T * output)41 inline bool compute(IntermediateType func(IntermediateType), const T* input, const Shape& shape,
42                     T* output) {
43     const auto size = getNumberOfElements(shape);
44     for (uint32_t i = 0; i < size; ++i) {
45         output[i] = static_cast<T>(func(static_cast<IntermediateType>(input[i])));
46     }
47     return true;
48 }
49 
execute(IOperationExecutionContext * context,float func (float))50 bool execute(IOperationExecutionContext* context, float func(float)) {
51     switch (context->getInputType(kInputTensor)) {
52         case OperandType::TENSOR_FLOAT16:
53             return compute<float, _Float16>(func, context->getInputBuffer<_Float16>(kInputTensor),
54                                             context->getInputShape(kInputTensor),
55                                             context->getOutputBuffer<_Float16>(kOutputTensor));
56         case OperandType::TENSOR_FLOAT32:
57             return compute<float, float>(func, context->getInputBuffer<float>(kInputTensor),
58                                          context->getInputShape(kInputTensor),
59                                          context->getOutputBuffer<float>(kOutputTensor));
60         default:
61             NN_RET_CHECK_FAIL() << "Unsupported tensor type for elementwise operation";
62     }
63 }
64 
65 }  // namespace
66 
executeAbs(IOperationExecutionContext * context)67 bool executeAbs(IOperationExecutionContext* context) {
68     switch (context->getInputType(kInputTensor)) {
69         case OperandType::TENSOR_FLOAT16:
70             return compute<float, _Float16>(std::abs,
71                                             context->getInputBuffer<_Float16>(kInputTensor),
72                                             context->getInputShape(kInputTensor),
73                                             context->getOutputBuffer<_Float16>(kOutputTensor));
74         case OperandType::TENSOR_FLOAT32:
75             return compute<float, float>(std::abs, context->getInputBuffer<float>(kInputTensor),
76                                          context->getInputShape(kInputTensor),
77                                          context->getOutputBuffer<float>(kOutputTensor));
78         case OperandType::TENSOR_INT32:
79             return compute<int32_t, int32_t>(std::abs,
80                                              context->getInputBuffer<int32_t>(kInputTensor),
81                                              context->getInputShape(kInputTensor),
82                                              context->getOutputBuffer<int32_t>(kOutputTensor));
83         default:
84             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ABS";
85     }
86 }
87 
validate(const IOperationValidationContext * context)88 bool validate(const IOperationValidationContext* context) {
89     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
90     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
91     OperandType inputType = context->getInputType(kInputTensor);
92     NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
93                  inputType == OperandType::TENSOR_FLOAT32)
94             << "Unsupported tensor type for elementwise operation";
95     NN_RET_CHECK(validateInputTypes(context, {inputType}));
96     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
97     return validateHalVersion(context, HalVersion::V1_2);
98 }
99 
validateAbs(const IOperationValidationContext * context)100 bool validateAbs(const IOperationValidationContext* context) {
101     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
102     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
103     OperandType inputType = context->getInputType(kInputTensor);
104     NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
105                  inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32)
106             << "Unsupported tensor type for operation ABS";
107     NN_RET_CHECK(validateInputTypes(context, {inputType}));
108     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
109     return validateHalVersion(context, (inputType == OperandType::TENSOR_INT32 ? HalVersion::V1_3
110                                                                                : HalVersion::V1_2));
111 }
112 
validateFloor(const IOperationValidationContext * context)113 bool validateFloor(const IOperationValidationContext* context) {
114     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
115     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
116 
117     OperandType inputType = context->getInputType(kInputTensor);
118     NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
119                  inputType == OperandType::TENSOR_FLOAT32)
120             << "Unsupported tensor type for operation FLOOR";
121     NN_RET_CHECK(validateInputTypes(context, {inputType}));
122     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
123 
124     const Shape& input = context->getInputShape(kInputTensor);
125     if (hasKnownRank(input)) {
126         NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
127     }
128 
129     return validateHalVersion(
130             context,
131             (inputType == OperandType::TENSOR_FLOAT16 ? HalVersion::V1_2 : HalVersion::V1_0));
132 }
133 
prepare(IOperationExecutionContext * context)134 bool prepare(IOperationExecutionContext* context) {
135     Shape input = context->getInputShape(kInputTensor);
136     Shape output = context->getOutputShape(kOutputTensor);
137     NN_RET_CHECK(SetShape(input, &output));
138     return context->setOutputShape(kOutputTensor, output);
139 }
140 
prepareFloor(IOperationExecutionContext * context)141 bool prepareFloor(IOperationExecutionContext* context) {
142     Shape input = context->getInputShape(kInputTensor);
143     Shape output = context->getOutputShape(kOutputTensor);
144     NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
145     NN_RET_CHECK(SetShape(input, &output));
146     return context->setOutputShape(kOutputTensor, output);
147 }
148 
executeExp(IOperationExecutionContext * context)149 bool executeExp(IOperationExecutionContext* context) {
150     return execute(context, std::exp);
151 }
152 
executeFloor(IOperationExecutionContext * context)153 bool executeFloor(IOperationExecutionContext* context) {
154     return execute(context, std::floor);
155 }
156 
executeLog(IOperationExecutionContext * context)157 bool executeLog(IOperationExecutionContext* context) {
158     return execute(context, std::log);
159 }
160 
executeRsqrt(IOperationExecutionContext * context)161 bool executeRsqrt(IOperationExecutionContext* context) {
162     return execute(context, [](float x) { return 1.f / std::sqrt(x); });
163 }
164 
executeSin(IOperationExecutionContext * context)165 bool executeSin(IOperationExecutionContext* context) {
166     return execute(context, std::sin);
167 }
168 
executeSqrt(IOperationExecutionContext * context)169 bool executeSqrt(IOperationExecutionContext* context) {
170     return execute(context, std::sqrt);
171 }
172 
173 }  // namespace elementwise
174 
// Register each elementwise operation with the operation resolver. ABS and
// FLOOR use dedicated validate/prepare/execute variants (extra operand types
// and a rank limit, respectively); the remaining ops share the generic
// float-only validate/prepare with a per-op execute function.
NN_REGISTER_OPERATION(ABS, "ABS", elementwise::validateAbs, elementwise::prepare,
                      elementwise::executeAbs);
NN_REGISTER_OPERATION(EXP, "EXP", elementwise::validate, elementwise::prepare,
                      elementwise::executeExp);
NN_REGISTER_OPERATION(FLOOR, "FLOOR", elementwise::validateFloor, elementwise::prepareFloor,
                      elementwise::executeFloor);
NN_REGISTER_OPERATION(LOG, "LOG", elementwise::validate, elementwise::prepare,
                      elementwise::executeLog);
NN_REGISTER_OPERATION(RSQRT, "RSQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeRsqrt);
NN_REGISTER_OPERATION(SIN, "SIN", elementwise::validate, elementwise::prepare,
                      elementwise::executeSin);
NN_REGISTER_OPERATION(SQRT, "SQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeSqrt);
189 
190 }  // namespace nn
191 }  // namespace android
192