1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "FibonacciDriver"
18 
#include "FibonacciDriver.h"

#include <cstring>
#include <limits>
#include <vector>

#include "FibonacciExtension.h"
#include "HalInterfaces.h"
#include "NeuralNetworksExtensions.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Utils.h"
#include "ValidateHal.h"
30 
31 namespace android {
32 namespace nn {
33 namespace sample_driver {
34 namespace {
35 
36 using namespace hal;
37 
38 const uint8_t kLowBitsType = static_cast<uint8_t>(ExtensionTypeEncoding::LOW_BITS_TYPE);
39 const uint32_t kTypeWithinExtensionMask = (1 << kLowBitsType) - 1;
40 
41 namespace fibonacci_op {
42 
// Name under which the operation is registered with the operation resolver.
constexpr char kOperationName[] = "EXAMPLE_FIBONACCI";

// Input operands: only n, the number of Fibonacci numbers to produce.
constexpr uint32_t kInputN = 0;
constexpr uint32_t kNumInputs = kInputN + 1;

// Output operands: only the tensor that receives the sequence.
constexpr uint32_t kOutputTensor = 0;
constexpr uint32_t kNumOutputs = kOutputTensor + 1;
50 
getFibonacciExtensionPrefix(const Model & model,uint16_t * prefix)51 bool getFibonacciExtensionPrefix(const Model& model, uint16_t* prefix) {
52     NN_RET_CHECK_EQ(model.extensionNameToPrefix.size(), 1u);  // Assumes no other extensions in use.
53     NN_RET_CHECK_EQ(model.extensionNameToPrefix[0].name, EXAMPLE_FIBONACCI_EXTENSION_NAME);
54     *prefix = model.extensionNameToPrefix[0].prefix;
55     return true;
56 }
57 
isFibonacciOperation(const Operation & operation,const Model & model)58 bool isFibonacciOperation(const Operation& operation, const Model& model) {
59     int32_t operationType = static_cast<int32_t>(operation.type);
60     uint16_t prefix;
61     NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix));
62     NN_RET_CHECK_EQ(operationType, (prefix << kLowBitsType) | EXAMPLE_FIBONACCI);
63     return true;
64 }
65 
validate(const Operation & operation,const Model & model)66 bool validate(const Operation& operation, const Model& model) {
67     NN_RET_CHECK(isFibonacciOperation(operation, model));
68     NN_RET_CHECK_EQ(operation.inputs.size(), kNumInputs);
69     NN_RET_CHECK_EQ(operation.outputs.size(), kNumOutputs);
70     int32_t inputType = static_cast<int32_t>(model.main.operands[operation.inputs[0]].type);
71     int32_t outputType = static_cast<int32_t>(model.main.operands[operation.outputs[0]].type);
72     uint16_t prefix;
73     NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix));
74     NN_RET_CHECK(inputType == ((prefix << kLowBitsType) | EXAMPLE_INT64) ||
75                  inputType == ANEURALNETWORKS_TENSOR_FLOAT32);
76     NN_RET_CHECK(outputType == ((prefix << kLowBitsType) | EXAMPLE_TENSOR_QUANT64_ASYMM) ||
77                  outputType == ANEURALNETWORKS_TENSOR_FLOAT32);
78     return true;
79 }
80 
prepare(IOperationExecutionContext * context)81 bool prepare(IOperationExecutionContext* context) {
82     int64_t n;
83     if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) {
84         n = static_cast<int64_t>(context->getInputValue<float>(kInputN));
85     } else {
86         n = context->getInputValue<int64_t>(kInputN);
87     }
88     NN_RET_CHECK_GE(n, 1);
89     Shape output = context->getOutputShape(kOutputTensor);
90     output.dimensions = {static_cast<uint32_t>(n)};
91     return context->setOutputShape(kOutputTensor, output);
92 }
93 
// Writes the first n Fibonacci numbers (1, 1, 2, 3, 5, ...) into output[0..n) and then
// quantizes each element in place as value / outputScale + outputZeroPoint.
//
// n               number of elements to produce; nothing is written when n <= 0.
// outputScale     quantization scale (1.0 for unquantized float output).
// outputZeroPoint quantization zero point (0 for unquantized float output).
// output          buffer with room for at least n elements.
// Returns true.
//
// n is int64_t so the caller's 64-bit element count is not silently narrowed. The sequence is
// accumulated directly in OutputT; for unsigned integer outputs the sums wrap once they exceed
// the type's range.
template <typename ScaleT, typename ZeroPointT, typename OutputT>
bool compute(int64_t n, ScaleT outputScale, ZeroPointT outputZeroPoint, OutputT* output) {
    // Compute the Fibonacci numbers.
    if (n >= 1) {
        output[0] = 1;
    }
    if (n >= 2) {
        output[1] = 1;
    }
    for (int64_t i = 2; i < n; ++i) {
        output[i] = output[i - 1] + output[i - 2];
    }

    // Quantize output.
    for (int64_t i = 0; i < n; ++i) {
        output[i] = output[i] / outputScale + outputZeroPoint;
    }

    return true;
}
116 
execute(IOperationExecutionContext * context)117 bool execute(IOperationExecutionContext* context) {
118     int64_t n;
119     if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) {
120         n = static_cast<int64_t>(context->getInputValue<float>(kInputN));
121     } else {
122         n = context->getInputValue<int64_t>(kInputN);
123     }
124     if (context->getOutputType(kOutputTensor) == OperandType::TENSOR_FLOAT32) {
125         float* output = context->getOutputBuffer<float>(kOutputTensor);
126         return compute(n, /*scale=*/1.0, /*zeroPoint=*/0, output);
127     } else {
128         uint64_t* output = context->getOutputBuffer<uint64_t>(kOutputTensor);
129         Shape outputShape = context->getOutputShape(kOutputTensor);
130         auto outputQuant = reinterpret_cast<const ExampleQuant64AsymmParams*>(
131                 outputShape.extraParams.extension().data());
132         return compute(n, outputQuant->scale, outputQuant->zeroPoint, output);
133     }
134 }
135 
136 }  // namespace fibonacci_op
137 }  // namespace
138 
findOperation(OperationType operationType) const139 const OperationRegistration* FibonacciOperationResolver::findOperation(
140         OperationType operationType) const {
141     // .validate is omitted because it's not used by the extension driver.
142     static OperationRegistration operationRegistration(operationType, fibonacci_op::kOperationName,
143                                                        nullptr, fibonacci_op::prepare,
144                                                        fibonacci_op::execute, {});
145     uint16_t prefix = static_cast<int32_t>(operationType) >> kLowBitsType;
146     uint16_t typeWithinExtension = static_cast<int32_t>(operationType) & kTypeWithinExtensionMask;
147     // Assumes no other extensions in use.
148     return prefix != 0 && typeWithinExtension == EXAMPLE_FIBONACCI ? &operationRegistration
149                                                                    : nullptr;
150 }
151 
getSupportedExtensions(getSupportedExtensions_cb cb)152 Return<void> FibonacciDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
153     cb(V1_0::ErrorStatus::NONE,
154        {
155                {
156                        .name = EXAMPLE_FIBONACCI_EXTENSION_NAME,
157                        .operandTypes =
158                                {
159                                        {
160                                                .type = EXAMPLE_INT64,
161                                                .isTensor = false,
162                                                .byteSize = 8,
163                                        },
164                                        {
165                                                .type = EXAMPLE_TENSOR_QUANT64_ASYMM,
166                                                .isTensor = true,
167                                                .byteSize = 8,
168                                        },
169                                },
170                },
171        });
172     return Void();
173 }
174 
getCapabilities_1_3(getCapabilities_1_3_cb cb)175 Return<void> FibonacciDriver::getCapabilities_1_3(getCapabilities_1_3_cb cb) {
176     android::nn::initVLogMask();
177     VLOG(DRIVER) << "getCapabilities()";
178     static const PerformanceInfo kPerf = {.execTime = 1.0f, .powerUsage = 1.0f};
179     Capabilities capabilities = {
180             .relaxedFloat32toFloat16PerformanceScalar = kPerf,
181             .relaxedFloat32toFloat16PerformanceTensor = kPerf,
182             .operandPerformance = nonExtensionOperandPerformance<HalVersion::V1_3>(kPerf),
183             .ifPerformance = kPerf,
184             .whilePerformance = kPerf};
185     cb(V1_3::ErrorStatus::NONE, capabilities);
186     return Void();
187 }
188 
getSupportedOperations_1_3(const V1_3::Model & model,getSupportedOperations_1_3_cb cb)189 Return<void> FibonacciDriver::getSupportedOperations_1_3(const V1_3::Model& model,
190                                                          getSupportedOperations_1_3_cb cb) {
191     VLOG(DRIVER) << "getSupportedOperations()";
192     if (!validateModel(model)) {
193         cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {});
194         return Void();
195     }
196     const size_t count = model.main.operations.size();
197     std::vector<bool> supported(count);
198     for (size_t i = 0; i < count; ++i) {
199         const Operation& operation = model.main.operations[i];
200         if (fibonacci_op::isFibonacciOperation(operation, model)) {
201             if (!fibonacci_op::validate(operation, model)) {
202                 cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {});
203                 return Void();
204             }
205             supported[i] = true;
206         }
207     }
208     cb(V1_3::ErrorStatus::NONE, supported);
209     return Void();
210 }
211 
212 }  // namespace sample_driver
213 }  // namespace nn
214 }  // namespace android
215