1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // Contains all the entry points to the C Neural Networks API.
18 // We do basic validation of the operands and then call the class
19 // that implements the functionality.
20
21 #define LOG_TAG "NeuralNetworks"
22
#include "NeuralNetworks.h"

#include "Callbacks.h"
#include "CompilationBuilder.h"
#include "ExecutionBuilder.h"
#include "Manager.h"
#include "Memory.h"
#include "NeuralNetworksOEM.h"
#include "ModelBuilder.h"
#include "Utils.h"

#include <memory>
#include <new>
#include <vector>
36
37 // Make sure the constants defined in the header files have not changed values.
38 // IMPORTANT: When adding new values, update kNumberOfDataTypes or kNumberOfDataTypesOEM
39 // in Utils.h.
40 static_assert(ANEURALNETWORKS_FLOAT32 == 0, "ANEURALNETWORKS_FLOAT32 has changed");
41 static_assert(ANEURALNETWORKS_INT32 == 1, "ANEURALNETWORKS_INT32 has changed");
42 static_assert(ANEURALNETWORKS_UINT32 == 2, "ANEURALNETWORKS_UINT32 has changed");
43 static_assert(ANEURALNETWORKS_TENSOR_FLOAT32 == 3,
44 "ANEURALNETWORKS_TENSOR_FLOAT32 has changed");
45 static_assert(ANEURALNETWORKS_TENSOR_INT32 == 4, "ANEURALNETWORKS_TENSOR_INT32 has changed");
46 static_assert(ANEURALNETWORKS_TENSOR_QUANT8_ASYMM == 5,
47 "ANEURALNETWORKS_TENSOR_QUANT8_ASYMM has changed");
48 static_assert(ANEURALNETWORKS_OEM_SCALAR == 10000, "ANEURALNETWORKS_OEM_SCALAR has changed");
49 static_assert(ANEURALNETWORKS_TENSOR_OEM_BYTE == 10001,
50 "ANEURALNETWORKS_TENSOR_OEM_BYTE has changed");
51
52 // IMPORTANT: When adding new values, update kNumberOfOperationTypes or
53 // kNumberOfOperationTypesOEMin Utils.h.
54 static_assert(ANEURALNETWORKS_ADD == 0, "ANEURALNETWORKS_ADD has changed");
55 static_assert(ANEURALNETWORKS_AVERAGE_POOL_2D == 1,
56 "ANEURALNETWORKS_AVERAGE_POOL_2D has changed");
57 static_assert(ANEURALNETWORKS_CONCATENATION == 2, "ANEURALNETWORKS_CONCATENATION has changed");
58 static_assert(ANEURALNETWORKS_CONV_2D == 3, "ANEURALNETWORKS_CONV_2D has changed");
59 static_assert(ANEURALNETWORKS_DEPTHWISE_CONV_2D == 4,
60 "ANEURALNETWORKS_DEPTHWISE_CONV_2D has changed");
61 static_assert(ANEURALNETWORKS_DEPTH_TO_SPACE == 5,
62 "ANEURALNETWORKS_DEPTH_TO_SPACE has changed");
63 static_assert(ANEURALNETWORKS_DEQUANTIZE == 6, "ANEURALNETWORKS_DEQUANTIZE has changed");
64 static_assert(ANEURALNETWORKS_EMBEDDING_LOOKUP == 7,
65 "ANEURALNETWORKS_EMBEDDING_LOOKUP has changed");
66 static_assert(ANEURALNETWORKS_FLOOR == 8, "ANEURALNETWORKS_FLOOR has changed");
67 static_assert(ANEURALNETWORKS_FULLY_CONNECTED == 9,
68 "ANEURALNETWORKS_FULLY_CONNECTED has changed");
69 static_assert(ANEURALNETWORKS_HASHTABLE_LOOKUP == 10,
70 "ANEURALNETWORKS_HASHTABLE_LOOKUP has changed");
71 static_assert(ANEURALNETWORKS_L2_NORMALIZATION == 11,
72 "ANEURALNETWORKS_L2_NORMALIZATION has changed");
73 static_assert(ANEURALNETWORKS_L2_POOL_2D == 12, "ANEURALNETWORKS_L2_POOL has changed");
74 static_assert(ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION == 13,
75 "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION has changed");
76 static_assert(ANEURALNETWORKS_LOGISTIC == 14, "ANEURALNETWORKS_LOGISTIC has changed");
77 static_assert(ANEURALNETWORKS_LSH_PROJECTION == 15,
78 "ANEURALNETWORKS_LSH_PROJECTION has changed");
79 static_assert(ANEURALNETWORKS_LSTM == 16, "ANEURALNETWORKS_LSTM has changed");
80 static_assert(ANEURALNETWORKS_MAX_POOL_2D == 17, "ANEURALNETWORKS_MAX_POOL has changed");
81 static_assert(ANEURALNETWORKS_MUL == 18, "ANEURALNETWORKS_MUL has changed");
82 static_assert(ANEURALNETWORKS_RELU == 19, "ANEURALNETWORKS_RELU has changed");
83 static_assert(ANEURALNETWORKS_RELU1 == 20, "ANEURALNETWORKS_RELU1 has changed");
84 static_assert(ANEURALNETWORKS_RELU6 == 21, "ANEURALNETWORKS_RELU6 has changed");
85 static_assert(ANEURALNETWORKS_RESHAPE == 22, "ANEURALNETWORKS_RESHAPE has changed");
86 static_assert(ANEURALNETWORKS_RESIZE_BILINEAR == 23,
87 "ANEURALNETWORKS_RESIZE_BILINEAR has changed");
88 static_assert(ANEURALNETWORKS_RNN == 24, "ANEURALNETWORKS_RNN has changed");
89 static_assert(ANEURALNETWORKS_SOFTMAX == 25, "ANEURALNETWORKS_SOFTMAX has changed");
90 static_assert(ANEURALNETWORKS_SPACE_TO_DEPTH == 26,
91 "ANEURALNETWORKS_SPACE_TO_DEPTH has changed");
92 static_assert(ANEURALNETWORKS_SVDF == 27, "ANEURALNETWORKS_SVDF has changed");
93 static_assert(ANEURALNETWORKS_TANH == 28, "ANEURALNETWORKS_TANH has changed");
94
95 static_assert(ANEURALNETWORKS_BATCH_TO_SPACE_ND == 29, "ANEURALNETWORKS_BATCH_TO_SPACE_ND has changed");
96 static_assert(ANEURALNETWORKS_DIV == 30, "ANEURALNETWORKS_DIV has changed");
97 static_assert(ANEURALNETWORKS_MEAN == 31, "ANEURALNETWORKS_MEAN has changed");
98 static_assert(ANEURALNETWORKS_PAD == 32, "ANEURALNETWORKS_PAD has changed");
99 static_assert(ANEURALNETWORKS_SPACE_TO_BATCH_ND == 33, "ANEURALNETWORKS_SPACE_TO_BATCH_ND has changed");
100 static_assert(ANEURALNETWORKS_SQUEEZE == 34, "ANEURALNETWORKS_SQUEEZE has changed");
101 static_assert(ANEURALNETWORKS_STRIDED_SLICE == 35, "ANEURALNETWORKS_STRIDED_SLICE has changed");
102 static_assert(ANEURALNETWORKS_SUB == 36, "ANEURALNETWORKS_TANH has changed");
103 static_assert(ANEURALNETWORKS_TRANSPOSE == 37, "ANEURALNETWORKS_TRANSPOSE has changed");
104
105 static_assert(ANEURALNETWORKS_OEM_OPERATION == 10000,
106 "ANEURALNETWORKS_OEM_OPERATION has changed");
107
108 static_assert(ANEURALNETWORKS_FUSED_NONE == 0, "ANEURALNETWORKS_FUSED_NONE has changed");
109 static_assert(ANEURALNETWORKS_FUSED_RELU == 1, "ANEURALNETWORKS_FUSED_RELU has changed");
110 static_assert(ANEURALNETWORKS_FUSED_RELU1 == 2, "ANEURALNETWORKS_FUSED_RELU1 has changed");
111 static_assert(ANEURALNETWORKS_FUSED_RELU6 == 3, "ANEURALNETWORKS_FUSED_RELU6 has changed");
112
113 static_assert(ANEURALNETWORKS_PREFER_LOW_POWER == 0,
114 "ANEURALNETWORKS_PREFER_LOW_POWER has changed");
115 static_assert(ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER == 1,
116 "ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER has changed");
117 static_assert(ANEURALNETWORKS_PREFER_SUSTAINED_SPEED == 2,
118 "ANEURALNETWORKS_PREFER_SUSTAINED_SPEED has changed");
119
120 static_assert(ANEURALNETWORKS_NO_ERROR == 0, "ANEURALNETWORKS_NO_ERROR has changed");
121 static_assert(ANEURALNETWORKS_OUT_OF_MEMORY == 1, "ANEURALNETWORKS_OUT_OF_MEMORY has changed");
122 static_assert(ANEURALNETWORKS_INCOMPLETE == 2, "ANEURALNETWORKS_INCOMPLETE has changed");
123 static_assert(ANEURALNETWORKS_UNEXPECTED_NULL == 3,
124 "ANEURALNETWORKS_UNEXPECTED_NULL has changed");
125 static_assert(ANEURALNETWORKS_BAD_DATA == 4, "ANEURALNETWORKS_BAD_DATA has changed");
126 static_assert(ANEURALNETWORKS_OP_FAILED == 5, "ANEURALNETWORKS_OP_FAILED has changed");
127 static_assert(ANEURALNETWORKS_BAD_STATE == 6, "ANEURALNETWORKS_BAD_STATE has changed");
128
129 static_assert(ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES == 128,
130 "ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES has changed");
131
132 // Make sure that the constants are compatible with the values defined in
133 // hardware/interfaces/neuralnetworks/1.0/types.hal.
134 static_assert(static_cast<int32_t>(OperandType::OEM) == ANEURALNETWORKS_OEM_SCALAR,
135 "OEM != ANEURALNETWORKS_OEM");
136 static_assert(static_cast<int32_t>(OperandType::FLOAT32) == ANEURALNETWORKS_FLOAT32,
137 "FLOAT32 != ANEURALNETWORKS_FLOAT32");
138 static_assert(static_cast<int32_t>(OperandType::INT32) == ANEURALNETWORKS_INT32,
139 "INT32 != ANEURALNETWORKS_INT32");
140 static_assert(static_cast<int32_t>(OperandType::UINT32) == ANEURALNETWORKS_UINT32,
141 "UINT32 != ANEURALNETWORKS_UINT32");
142 static_assert(static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) == ANEURALNETWORKS_TENSOR_OEM_BYTE,
143 "TENSOR_OEM_BYTE != ANEURALNETWORKS_TENSOR_OEM_BYTE");
144 static_assert(static_cast<int32_t>(OperandType::TENSOR_FLOAT32) == ANEURALNETWORKS_TENSOR_FLOAT32,
145 "TENSOR_FLOAT32 != ANEURALNETWORKS_TENSOR_FLOAT32");
146 static_assert(static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) ==
147 ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
148 "TENSOR_QUANT8_ASYMM != ANEURALNETWORKS_TENSOR_QUANT8_ASYMM");
149
150 static_assert(static_cast<int32_t>(OperationType::ADD) == ANEURALNETWORKS_ADD,
151 "OperationType::ADD != ANEURALNETWORKS_ADD");
152 static_assert(static_cast<int32_t>(OperationType::AVERAGE_POOL_2D) ==
153 ANEURALNETWORKS_AVERAGE_POOL_2D,
154 "OperationType::AVERAGE_POOL_2D != ANEURALNETWORKS_AVERAGE_POOL_2D");
155 static_assert(static_cast<int32_t>(OperationType::CONV_2D) == ANEURALNETWORKS_CONV_2D,
156 "OperationType::CONV_2D != ANEURALNETWORKS_CONV_2D");
157 static_assert(static_cast<int32_t>(OperationType::DEPTHWISE_CONV_2D) ==
158 ANEURALNETWORKS_DEPTHWISE_CONV_2D,
159 "OperationType::DEPTHWISE_CONV_2D != ANEURALNETWORKS_DEPTHWISE_CONV_2D");
160 static_assert(static_cast<int32_t>(OperationType::DEPTH_TO_SPACE) ==
161 ANEURALNETWORKS_DEPTH_TO_SPACE,
162 "OperationType::DEPTH_TO_SPACE != ANEURALNETWORKS_DEPTH_TO_SPACE");
163 static_assert(static_cast<int32_t>(OperationType::DEQUANTIZE) == ANEURALNETWORKS_DEQUANTIZE,
164 "OperationType::DEQUANTIZE != ANEURALNETWORKS_DEQUANTIZE");
165 static_assert(static_cast<int32_t>(OperationType::EMBEDDING_LOOKUP) ==
166 ANEURALNETWORKS_EMBEDDING_LOOKUP,
167 "OperationType::EMBEDDING_LOOKUP != ANEURALNETWORKS_EMBEDDING_LOOKUP");
168 static_assert(static_cast<int32_t>(OperationType::FLOOR) == ANEURALNETWORKS_FLOOR,
169 "OperationType::FLOOR != ANEURALNETWORKS_FLOOR");
170 static_assert(static_cast<int32_t>(OperationType::FULLY_CONNECTED) ==
171 ANEURALNETWORKS_FULLY_CONNECTED,
172 "OperationType::FULLY_CONNECTED != ANEURALNETWORKS_FULLY_CONNECTED");
173 static_assert(static_cast<int32_t>(OperationType::HASHTABLE_LOOKUP) ==
174 ANEURALNETWORKS_HASHTABLE_LOOKUP,
175 "OperationType::HASHTABLE_LOOKUP != ANEURALNETWORKS_HASHTABLE_LOOKUP");
176 static_assert(static_cast<int32_t>(OperationType::L2_NORMALIZATION) ==
177 ANEURALNETWORKS_L2_NORMALIZATION,
178 "OperationType::L2_NORMALIZATION != ANEURALNETWORKS_L2_NORMALIZATION");
179 static_assert(static_cast<int32_t>(OperationType::L2_POOL_2D) == ANEURALNETWORKS_L2_POOL_2D,
180 "OperationType::L2_POOL_2D != ANEURALNETWORKS_L2_POOL_2D");
181 static_assert(static_cast<int32_t>(OperationType::LOCAL_RESPONSE_NORMALIZATION) ==
182 ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION,
183 "OperationType::LOCAL_RESPONSE_NORMALIZATION != "
184 "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION");
185 static_assert(static_cast<int32_t>(OperationType::LOGISTIC) == ANEURALNETWORKS_LOGISTIC,
186 "OperationType::LOGISTIC != ANEURALNETWORKS_LOGISTIC");
187 static_assert(static_cast<int32_t>(OperationType::LSH_PROJECTION) ==
188 ANEURALNETWORKS_LSH_PROJECTION,
189 "OperationType::LSH_PROJECTION != ANEURALNETWORKS_LSH_PROJECTION");
190 static_assert(static_cast<int32_t>(OperationType::LSTM) == ANEURALNETWORKS_LSTM,
191 "OperationType::LSTM != ANEURALNETWORKS_LSTM");
192 static_assert(static_cast<int32_t>(OperationType::MAX_POOL_2D) == ANEURALNETWORKS_MAX_POOL_2D,
193 "OperationType::MAX_POOL_2D != ANEURALNETWORKS_MAX_POOL_2D");
194 static_assert(static_cast<int32_t>(OperationType::MUL) == ANEURALNETWORKS_MUL,
195 "OperationType::MUL != ANEURALNETWORKS_MUL");
196 static_assert(static_cast<int32_t>(OperationType::RELU) == ANEURALNETWORKS_RELU,
197 "OperationType::RELU != ANEURALNETWORKS_RELU");
198 static_assert(static_cast<int32_t>(OperationType::RELU1) == ANEURALNETWORKS_RELU1,
199 "OperationType::RELU1 != ANEURALNETWORKS_RELU1");
200 static_assert(static_cast<int32_t>(OperationType::RELU6) == ANEURALNETWORKS_RELU6,
201 "OperationType::RELU6 != ANEURALNETWORKS_RELU6");
202 static_assert(static_cast<int32_t>(OperationType::RESHAPE) == ANEURALNETWORKS_RESHAPE,
203 "OperationType::RESHAPE != ANEURALNETWORKS_RESHAPE");
204 static_assert(static_cast<int32_t>(OperationType::RESIZE_BILINEAR) ==
205 ANEURALNETWORKS_RESIZE_BILINEAR,
206 "OperationType::RESIZE_BILINEAR != ANEURALNETWORKS_RESIZE_BILINEAR");
207 static_assert(static_cast<int32_t>(OperationType::RNN) == ANEURALNETWORKS_RNN,
208 "OperationType::RNN != ANEURALNETWORKS_RNN");
209 static_assert(static_cast<int32_t>(OperationType::SOFTMAX) == ANEURALNETWORKS_SOFTMAX,
210 "OperationType::SOFTMAX != ANEURALNETWORKS_SOFTMAX");
211 static_assert(static_cast<int32_t>(OperationType::SPACE_TO_DEPTH) ==
212 ANEURALNETWORKS_SPACE_TO_DEPTH,
213 "OperationType::SPACE_TO_DEPTH != ANEURALNETWORKS_SPACE_TO_DEPTH");
214 static_assert(static_cast<int32_t>(OperationType::SVDF) == ANEURALNETWORKS_SVDF,
215 "OperationType::SVDF != ANEURALNETWORKS_SVDF");
216 static_assert(static_cast<int32_t>(OperationType::TANH) == ANEURALNETWORKS_TANH,
217 "OperationType::TANH != ANEURALNETWORKS_TANH");
218
219 static_assert(static_cast<int32_t>(OperationType::BATCH_TO_SPACE_ND) == ANEURALNETWORKS_BATCH_TO_SPACE_ND,
220 "OperationType::BATCH_TO_SPACE_ND != ANEURALNETWORKS_BATCH_TO_SPACE_ND");
221 static_assert(static_cast<int32_t>(OperationType::DIV) == ANEURALNETWORKS_DIV,
222 "OperationType::DIV != ANEURALNETWORKS_DIV");
223 static_assert(static_cast<int32_t>(OperationType::MEAN) == ANEURALNETWORKS_MEAN,
224 "OperationType::MEAN != ANEURALNETWORKS_MEAN");
225 static_assert(static_cast<int32_t>(OperationType::PAD) == ANEURALNETWORKS_PAD,
226 "OperationType::PAD != ANEURALNETWORKS_PAD");
227 static_assert(static_cast<int32_t>(OperationType::SPACE_TO_BATCH_ND) ==
228 ANEURALNETWORKS_SPACE_TO_BATCH_ND,
229 "OperationType::SPACE_TO_BATCH_ND != ANEURALNETWORKS_SPACE_TO_BATCH_ND");
230 static_assert(static_cast<int32_t>(OperationType::SQUEEZE) == ANEURALNETWORKS_SQUEEZE,
231 "OperationType::SQUEEZE != ANEURALNETWORKS_SQUEEZE");
232 static_assert(static_cast<int32_t>(OperationType::STRIDED_SLICE) ==
233 ANEURALNETWORKS_STRIDED_SLICE,
234 "OperationType::STRIDED_SLICE != ANEURALNETWORKS_STRIDED_SLICE");
235 static_assert(static_cast<int32_t>(OperationType::SUB) == ANEURALNETWORKS_SUB,
236 "OperationType::SUB != ANEURALNETWORKS_SUB");
237 static_assert(static_cast<int32_t>(OperationType::TRANSPOSE) == ANEURALNETWORKS_TRANSPOSE,
238 "OperationType::TRANSPOSE != ANEURALNETWORKS_TRANSPOSE");
239
240 static_assert(static_cast<int32_t>(FusedActivationFunc::NONE) == ANEURALNETWORKS_FUSED_NONE,
241 "FusedActivationFunc::NONE != ANEURALNETWORKS_FUSED_NONE");
242 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU) == ANEURALNETWORKS_FUSED_RELU,
243 "FusedActivationFunc::RELU != ANEURALNETWORKS_FUSED_RELU");
244 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU1) == ANEURALNETWORKS_FUSED_RELU1,
245 "FusedActivationFunc::RELU1 != ANEURALNETWORKS_FUSED_RELU1");
246 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU6) == ANEURALNETWORKS_FUSED_RELU6,
247 "FusedActivationFunc::RELU6 != ANEURALNETWORKS_FUSED_RELU6");
248
249 using android::sp;
250 using namespace android::nn;
251
ANeuralNetworksMemory_createFromFd(size_t size,int prot,int fd,size_t offset,ANeuralNetworksMemory ** memory)252 int ANeuralNetworksMemory_createFromFd(size_t size, int prot, int fd, size_t offset,
253 ANeuralNetworksMemory** memory) {
254 *memory = nullptr;
255 std::unique_ptr<MemoryFd> m = std::make_unique<MemoryFd>();
256 if (m == nullptr) {
257 return ANEURALNETWORKS_OUT_OF_MEMORY;
258 }
259 int n = m->set(size, prot, fd, offset);
260 if (n != ANEURALNETWORKS_NO_ERROR) {
261 return n;
262 }
263 *memory = reinterpret_cast<ANeuralNetworksMemory*>(m.release());
264 return ANEURALNETWORKS_NO_ERROR;
265 }
266
ANeuralNetworksMemory_free(ANeuralNetworksMemory * memory)267 void ANeuralNetworksMemory_free(ANeuralNetworksMemory* memory) {
268 // No validation. Free of nullptr is valid.
269 Memory* m = reinterpret_cast<Memory*>(memory);
270 delete m;
271 }
272
ANeuralNetworksModel_create(ANeuralNetworksModel ** model)273 int ANeuralNetworksModel_create(ANeuralNetworksModel** model) {
274 initVLogMask();
275 if (!model) {
276 LOG(ERROR) << "ANeuralNetworksModel_create passed a nullptr";
277 return ANEURALNETWORKS_UNEXPECTED_NULL;
278 }
279 ModelBuilder* m = new (std::nothrow) ModelBuilder();
280 if (m == nullptr) {
281 *model = nullptr;
282 return ANEURALNETWORKS_OUT_OF_MEMORY;
283 }
284 *model = reinterpret_cast<ANeuralNetworksModel*>(m);
285 return ANEURALNETWORKS_NO_ERROR;
286 }
287
ANeuralNetworksModel_free(ANeuralNetworksModel * model)288 void ANeuralNetworksModel_free(ANeuralNetworksModel* model) {
289 // No validation. Free of nullptr is valid.
290 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
291 delete m;
292 }
293
ANeuralNetworksModel_finish(ANeuralNetworksModel * model)294 int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) {
295 if (!model) {
296 LOG(ERROR) << "ANeuralNetworksModel_finish passed a nullptr";
297 return ANEURALNETWORKS_UNEXPECTED_NULL;
298 }
299 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
300 return m->finish();
301 }
302
ANeuralNetworksModel_addOperand(ANeuralNetworksModel * model,const ANeuralNetworksOperandType * type)303 int ANeuralNetworksModel_addOperand(ANeuralNetworksModel* model,
304 const ANeuralNetworksOperandType* type) {
305 if (!model || !type) {
306 LOG(ERROR) << "ANeuralNetworksModel_addOperand passed a nullptr";
307 return ANEURALNETWORKS_UNEXPECTED_NULL;
308 }
309 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
310 return m->addOperand(*type);
311 }
312
ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel * model,int32_t index,const void * buffer,size_t length)313 int ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model, int32_t index,
314 const void* buffer, size_t length) {
315 if (!model || (!buffer && length != 0)) {
316 LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr";
317 return ANEURALNETWORKS_UNEXPECTED_NULL;
318 }
319 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
320 return m->setOperandValue(index, buffer, length);
321 }
322
ANeuralNetworksModel_setOperandValueFromMemory(ANeuralNetworksModel * model,int32_t index,const ANeuralNetworksMemory * memory,size_t offset,size_t length)323 int ANeuralNetworksModel_setOperandValueFromMemory(ANeuralNetworksModel* model, int32_t index,
324 const ANeuralNetworksMemory* memory,
325 size_t offset, size_t length) {
326 if (!model || !memory) {
327 LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr";
328 return ANEURALNETWORKS_UNEXPECTED_NULL;
329 }
330 const Memory* mem = reinterpret_cast<const Memory*>(memory);
331 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
332 return m->setOperandValueFromMemory(index, mem, offset, length);
333 }
334
ANeuralNetworksModel_addOperation(ANeuralNetworksModel * model,ANeuralNetworksOperationType type,uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs)335 int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model,
336 ANeuralNetworksOperationType type, uint32_t inputCount,
337 const uint32_t* inputs, uint32_t outputCount,
338 const uint32_t* outputs) {
339 if (!model || !inputs || !outputs) {
340 LOG(ERROR) << "ANeuralNetworksModel_addOperation passed a nullptr";
341 return ANEURALNETWORKS_UNEXPECTED_NULL;
342 }
343 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
344 return m->addOperation(type, inputCount, inputs, outputCount, outputs);
345 }
346
ANeuralNetworksModel_identifyInputsAndOutputs(ANeuralNetworksModel * model,uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs)347 int ANeuralNetworksModel_identifyInputsAndOutputs(ANeuralNetworksModel* model, uint32_t inputCount,
348 const uint32_t* inputs, uint32_t outputCount,
349 const uint32_t* outputs) {
350 if (!model || !inputs || !outputs) {
351 LOG(ERROR) << ("ANeuralNetworksModel_identifyInputsAndOutputs passed a nullptr");
352 return ANEURALNETWORKS_UNEXPECTED_NULL;
353 }
354 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
355 return m->identifyInputsAndOutputs(inputCount, inputs, outputCount, outputs);
356 }
357
ANeuralNetworksModel_relaxComputationFloat32toFloat16(ANeuralNetworksModel * model,bool allow)358 int ANeuralNetworksModel_relaxComputationFloat32toFloat16(ANeuralNetworksModel* model,
359 bool allow) {
360 if (!model) {
361 LOG(ERROR) << ("ANeuralNetworksModel_relaxComputationFloat32toFloat16 passed a nullptr");
362 return ANEURALNETWORKS_UNEXPECTED_NULL;
363 }
364 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
365 return m->relaxComputationFloat32toFloat16(allow);
366 }
367
ANeuralNetworksCompilation_create(ANeuralNetworksModel * model,ANeuralNetworksCompilation ** compilation)368 int ANeuralNetworksCompilation_create(ANeuralNetworksModel* model,
369 ANeuralNetworksCompilation** compilation) {
370 if (!model || !compilation) {
371 LOG(ERROR) << "ANeuralNetworksCompilation_create passed a nullptr";
372 return ANEURALNETWORKS_UNEXPECTED_NULL;
373 }
374
375 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
376 CompilationBuilder* c = nullptr;
377 int result = m->createCompilation(&c);
378 *compilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
379 return result;
380 }
381
ANeuralNetworksCompilation_free(ANeuralNetworksCompilation * compilation)382 void ANeuralNetworksCompilation_free(ANeuralNetworksCompilation* compilation) {
383 // No validation. Free of nullptr is valid.
384 // TODO specification says that a compilation-in-flight can be deleted
385 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
386 delete c;
387 }
388
ANeuralNetworksCompilation_setPreference(ANeuralNetworksCompilation * compilation,int32_t preference)389 int ANeuralNetworksCompilation_setPreference(ANeuralNetworksCompilation* compilation,
390 int32_t preference) {
391 if (!compilation) {
392 LOG(ERROR) << "ANeuralNetworksCompilation_setPreference passed a nullptr";
393 return ANEURALNETWORKS_UNEXPECTED_NULL;
394 }
395 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
396 return c->setPreference(preference);
397 }
398
ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation * compilation)399 int ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation* compilation) {
400 if (!compilation) {
401 LOG(ERROR) << "ANeuralNetworksCompilation_finish passed a nullptr";
402 return ANEURALNETWORKS_UNEXPECTED_NULL;
403 }
404 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
405 return c->finish();
406 }
407
ANeuralNetworksExecution_create(ANeuralNetworksCompilation * compilation,ANeuralNetworksExecution ** execution)408 int ANeuralNetworksExecution_create(ANeuralNetworksCompilation* compilation,
409 ANeuralNetworksExecution** execution) {
410 if (!compilation || !execution) {
411 LOG(ERROR) << "ANeuralNetworksExecution_create passed a nullptr";
412 return ANEURALNETWORKS_UNEXPECTED_NULL;
413 }
414
415 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
416 ExecutionBuilder* r = nullptr;
417 int result = c->createExecution(&r);
418 *execution = reinterpret_cast<ANeuralNetworksExecution*>(r);
419 return result;
420 }
421
ANeuralNetworksExecution_free(ANeuralNetworksExecution * execution)422 void ANeuralNetworksExecution_free(ANeuralNetworksExecution* execution) {
423 // TODO specification says that an execution-in-flight can be deleted
424 // No validation. Free of nullptr is valid.
425 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
426 delete r;
427 }
428
ANeuralNetworksExecution_setInput(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const void * buffer,size_t length)429 int ANeuralNetworksExecution_setInput(ANeuralNetworksExecution* execution, int32_t index,
430 const ANeuralNetworksOperandType* type, const void* buffer,
431 size_t length) {
432 if (!execution || (!buffer && length != 0)) {
433 LOG(ERROR) << "ANeuralNetworksExecution_setInput passed a nullptr";
434 return ANEURALNETWORKS_UNEXPECTED_NULL;
435 }
436 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
437 return r->setInput(index, type, buffer, length);
438 }
439
ANeuralNetworksExecution_setInputFromMemory(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const ANeuralNetworksMemory * memory,size_t offset,size_t length)440 int ANeuralNetworksExecution_setInputFromMemory(ANeuralNetworksExecution* execution, int32_t index,
441 const ANeuralNetworksOperandType* type,
442 const ANeuralNetworksMemory* memory, size_t offset,
443 size_t length) {
444 if (!execution || !memory) {
445 LOG(ERROR) << "ANeuralNetworksExecution_setInputFromMemory passed a nullptr";
446 return ANEURALNETWORKS_UNEXPECTED_NULL;
447 }
448
449 const Memory* m = reinterpret_cast<const Memory*>(memory);
450 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
451 return r->setInputFromMemory(index, type, m, offset, length);
452 }
453
ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,void * buffer,size_t length)454 int ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution* execution, int32_t index,
455 const ANeuralNetworksOperandType* type, void* buffer,
456 size_t length) {
457 if (!execution || (!buffer && length != 0)) {
458 LOG(ERROR) << "ANeuralNetworksExecution_setOutput passed a nullptr";
459 return ANEURALNETWORKS_UNEXPECTED_NULL;
460 }
461 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
462 return r->setOutput(index, type, buffer, length);
463 }
464
ANeuralNetworksExecution_setOutputFromMemory(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const ANeuralNetworksMemory * memory,size_t offset,size_t length)465 int ANeuralNetworksExecution_setOutputFromMemory(ANeuralNetworksExecution* execution, int32_t index,
466 const ANeuralNetworksOperandType* type,
467 const ANeuralNetworksMemory* memory, size_t offset,
468 size_t length) {
469 if (!execution || !memory) {
470 LOG(ERROR) << "ANeuralNetworksExecution_setOutputFromMemory passed a nullptr";
471 return ANEURALNETWORKS_UNEXPECTED_NULL;
472 }
473
474 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
475 const Memory* m = reinterpret_cast<const Memory*>(memory);
476 return r->setOutputFromMemory(index, type, m, offset, length);
477 }
478
ANeuralNetworksExecution_startCompute(ANeuralNetworksExecution * execution,ANeuralNetworksEvent ** event)479 int ANeuralNetworksExecution_startCompute(ANeuralNetworksExecution* execution,
480 ANeuralNetworksEvent** event) {
481 if (!execution || !event) {
482 LOG(ERROR) << "ANeuralNetworksExecution_startCompute passed a nullptr";
483 return ANEURALNETWORKS_UNEXPECTED_NULL;
484 }
485 // TODO validate the rest
486
487 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
488
489 // Dynamically allocate an sp to wrap an ExecutionCallback, seen in the NN
490 // API as an abstract event object. The sp<ExecutionCallback> object is
491 // returned when the execution has been successfully launched, otherwise a
492 // nullptr is returned. The sp is used for ref-counting purposes. Without
493 // it, the HIDL service could attempt to communicate with a dead callback
494 // object.
495 std::unique_ptr<sp<ExecutionCallback>> e = std::make_unique<sp<ExecutionCallback>>();
496 *event = nullptr;
497
498 int n = r->startCompute(e.get());
499 if (n != ANEURALNETWORKS_NO_ERROR) {
500 return n;
501 }
502 *event = reinterpret_cast<ANeuralNetworksEvent*>(e.release());
503 return ANEURALNETWORKS_NO_ERROR;
504 }
505
ANeuralNetworksEvent_wait(ANeuralNetworksEvent * event)506 int ANeuralNetworksEvent_wait(ANeuralNetworksEvent* event) {
507 if (event == nullptr) {
508 LOG(ERROR) << "ANeuralNetworksEvent_wait passed a nullptr";
509 return ANEURALNETWORKS_UNEXPECTED_NULL;
510 }
511
512 sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event);
513 (*e)->wait();
514 return convertErrorStatusToResultCode((*e)->getStatus());
515 }
516
ANeuralNetworksEvent_free(ANeuralNetworksEvent * event)517 void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) {
518 // No validation. Free of nullptr is valid.
519 if (event) {
520 sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event);
521 (*e)->wait();
522 delete e;
523 }
524 }
525