1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
18 #define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
19 
#include "NeuralNetworksExtensions.h"
#include "NeuralNetworksWrapper.h"

#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
24 
25 namespace android {
26 namespace nn {
27 namespace extension_wrapper {
28 
29 using wrapper::SymmPerChannelQuantParams;
30 using wrapper::Type;
31 
/**
 * Extension-specific operand parameters, held as an owned buffer of raw bytes.
 *
 * Both constructors are intentionally implicit so that either a byte vector or a
 * plain parameter struct can be supplied wherever ExtensionOperandParams is expected.
 */
struct ExtensionOperandParams {
    // The serialized parameter bytes.
    std::vector<uint8_t> data;

    // Adopts an already-serialized byte buffer.
    ExtensionOperandParams(std::vector<uint8_t> data) : data(std::move(data)) {}

    // Captures the object representation of |data|: exactly sizeof(T) bytes, copied
    // from the object's address. T must be trivially copyable.
    //
    // std::addressof is used instead of unary '&' so that a parameter type that
    // overloads operator& is still read from its real address.
    template <typename T>
    ExtensionOperandParams(const T& data)
        : ExtensionOperandParams(std::vector<uint8_t>(
                  reinterpret_cast<const uint8_t*>(std::addressof(data)),
                  reinterpret_cast<const uint8_t*>(std::addressof(data)) + sizeof(T))) {
        static_assert(std::is_trivially_copyable<T>::value, "data must be trivially copyable");
    }
};
45 
46 struct OperandType {
47     using ExtraParams =
48             std::variant<std::monostate, SymmPerChannelQuantParams, ExtensionOperandParams>;
49 
50     ANeuralNetworksOperandType operandType;
51     std::vector<uint32_t> dimensions;
52     ExtraParams extraParams;
53 
OperandTypeOperandType54     OperandType(const OperandType& other)
55         : operandType(other.operandType),
56           dimensions(other.dimensions),
57           extraParams(other.extraParams) {
58         operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr;
59     }
60 
61     OperandType& operator=(const OperandType& other) {
62         if (this != &other) {
63             operandType = other.operandType;
64             dimensions = other.dimensions;
65             extraParams = other.extraParams;
66             operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr;
67         }
68         return *this;
69     }
70 
71     OperandType(Type type, std::vector<uint32_t> d, float scale = 0.0f, int32_t zeroPoint = 0,
72                 ExtraParams&& extraParams = std::monostate())
dimensionsOperandType73         : dimensions(std::move(d)), extraParams(std::move(extraParams)) {
74         operandType = {
75                 .type = static_cast<int32_t>(type),
76                 .dimensionCount = static_cast<uint32_t>(dimensions.size()),
77                 .dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr,
78                 .scale = scale,
79                 .zeroPoint = zeroPoint,
80         };
81     }
82 
OperandTypeOperandType83     OperandType(Type type, std::vector<uint32_t> dimensions, float scale, int32_t zeroPoint,
84                 SymmPerChannelQuantParams&& channelQuant)
85         : OperandType(type, dimensions, scale, zeroPoint, ExtraParams(std::move(channelQuant))) {}
86 
OperandTypeOperandType87     OperandType(Type type, std::vector<uint32_t> dimensions, ExtraParams&& extraParams)
88         : OperandType(type, dimensions, 0.0f, 0, std::move(extraParams)) {}
89 };
90 
91 class Model : public wrapper::Model {
92    public:
93     using wrapper::Model::Model;  // Inherit constructors.
94 
getExtensionOperandType(const char * extensionName,uint16_t typeWithinExtension)95     int32_t getExtensionOperandType(const char* extensionName, uint16_t typeWithinExtension) {
96         int32_t result;
97         if (ANeuralNetworksModel_getExtensionOperandType(mModel, extensionName, typeWithinExtension,
98                                                          &result) != ANEURALNETWORKS_NO_ERROR) {
99             mValid = false;
100         }
101         return result;
102     }
103 
getExtensionOperationType(const char * extensionName,uint16_t typeWithinExtension)104     ANeuralNetworksOperationType getExtensionOperationType(const char* extensionName,
105                                                            uint16_t typeWithinExtension) {
106         ANeuralNetworksOperationType result;
107         if (ANeuralNetworksModel_getExtensionOperationType(mModel, extensionName,
108                                                            typeWithinExtension,
109                                                            &result) != ANEURALNETWORKS_NO_ERROR) {
110             mValid = false;
111         }
112         return result;
113     }
114 
addOperand(const OperandType * type)115     uint32_t addOperand(const OperandType* type) {
116         if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
117             ANEURALNETWORKS_NO_ERROR) {
118             mValid = false;
119         }
120         if (std::holds_alternative<SymmPerChannelQuantParams>(type->extraParams)) {
121             const auto& channelQuant = std::get<SymmPerChannelQuantParams>(type->extraParams);
122             if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
123                         mModel, mNextOperandId, &channelQuant.params) != ANEURALNETWORKS_NO_ERROR) {
124                 mValid = false;
125             }
126         } else if (std::holds_alternative<ExtensionOperandParams>(type->extraParams)) {
127             const auto& extension = std::get<ExtensionOperandParams>(type->extraParams);
128             if (ANeuralNetworksModel_setOperandExtensionData(
129                         mModel, mNextOperandId, extension.data.data(), extension.data.size()) !=
130                 ANEURALNETWORKS_NO_ERROR) {
131                 mValid = false;
132             }
133         }
134         return mNextOperandId++;
135     }
136 };
137 
138 }  // namespace extension_wrapper
139 
140 namespace wrapper {
141 
// Aliases that expose the extension-aware types under the wrapper namespace,
// next to the wrapper types they build on.
using ExtensionModel = extension_wrapper::Model;
using ExtensionOperandType = extension_wrapper::OperandType;
using ExtensionOperandParams = extension_wrapper::ExtensionOperandParams;
145 
146 }  // namespace wrapper
147 }  // namespace nn
148 }  // namespace android
149 
150 #endif  //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
151