1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Contains all the entry points to the C Neural Networks API.
18 // We do basic validation of the operands and then call the class
19 // that implements the functionality.
20 
21 #define LOG_TAG "NeuralNetworks"
22 
23 #include "NeuralNetworks.h"
24 
25 #include "Callbacks.h"
26 #include "CompilationBuilder.h"
27 #include "ExecutionBuilder.h"
28 #include "Manager.h"
29 #include "Memory.h"
30 #include "NeuralNetworksOEM.h"
31 #include "ModelBuilder.h"
32 #include "Utils.h"
33 
34 #include <memory>
35 #include <vector>
36 
37 // Make sure the constants defined in the header files have not changed values.
38 // IMPORTANT: When adding new values, update kNumberOfDataTypes or kNumberOfDataTypesOEM
39 // in Utils.h.
40 static_assert(ANEURALNETWORKS_FLOAT32 == 0, "ANEURALNETWORKS_FLOAT32 has changed");
41 static_assert(ANEURALNETWORKS_INT32 == 1, "ANEURALNETWORKS_INT32 has changed");
42 static_assert(ANEURALNETWORKS_UINT32 == 2, "ANEURALNETWORKS_UINT32 has changed");
43 static_assert(ANEURALNETWORKS_TENSOR_FLOAT32 == 3,
44               "ANEURALNETWORKS_TENSOR_FLOAT32 has changed");
45 static_assert(ANEURALNETWORKS_TENSOR_INT32 == 4, "ANEURALNETWORKS_TENSOR_INT32 has changed");
46 static_assert(ANEURALNETWORKS_TENSOR_QUANT8_ASYMM == 5,
47               "ANEURALNETWORKS_TENSOR_QUANT8_ASYMM has changed");
48 static_assert(ANEURALNETWORKS_OEM_SCALAR == 10000, "ANEURALNETWORKS_OEM_SCALAR has changed");
49 static_assert(ANEURALNETWORKS_TENSOR_OEM_BYTE == 10001,
50               "ANEURALNETWORKS_TENSOR_OEM_BYTE has changed");
51 
52 // IMPORTANT: When adding new values, update kNumberOfOperationTypes or
53 // kNumberOfOperationTypesOEMin Utils.h.
54 static_assert(ANEURALNETWORKS_ADD == 0, "ANEURALNETWORKS_ADD has changed");
55 static_assert(ANEURALNETWORKS_AVERAGE_POOL_2D == 1,
56               "ANEURALNETWORKS_AVERAGE_POOL_2D has changed");
57 static_assert(ANEURALNETWORKS_CONCATENATION == 2, "ANEURALNETWORKS_CONCATENATION has changed");
58 static_assert(ANEURALNETWORKS_CONV_2D == 3, "ANEURALNETWORKS_CONV_2D has changed");
59 static_assert(ANEURALNETWORKS_DEPTHWISE_CONV_2D == 4,
60               "ANEURALNETWORKS_DEPTHWISE_CONV_2D has changed");
61 static_assert(ANEURALNETWORKS_DEPTH_TO_SPACE == 5,
62               "ANEURALNETWORKS_DEPTH_TO_SPACE has changed");
63 static_assert(ANEURALNETWORKS_DEQUANTIZE == 6, "ANEURALNETWORKS_DEQUANTIZE has changed");
64 static_assert(ANEURALNETWORKS_EMBEDDING_LOOKUP == 7,
65               "ANEURALNETWORKS_EMBEDDING_LOOKUP has changed");
66 static_assert(ANEURALNETWORKS_FLOOR == 8, "ANEURALNETWORKS_FLOOR has changed");
67 static_assert(ANEURALNETWORKS_FULLY_CONNECTED == 9,
68               "ANEURALNETWORKS_FULLY_CONNECTED has changed");
69 static_assert(ANEURALNETWORKS_HASHTABLE_LOOKUP == 10,
70               "ANEURALNETWORKS_HASHTABLE_LOOKUP has changed");
71 static_assert(ANEURALNETWORKS_L2_NORMALIZATION == 11,
72               "ANEURALNETWORKS_L2_NORMALIZATION has changed");
73 static_assert(ANEURALNETWORKS_L2_POOL_2D == 12, "ANEURALNETWORKS_L2_POOL has changed");
74 static_assert(ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION == 13,
75               "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION has changed");
76 static_assert(ANEURALNETWORKS_LOGISTIC == 14, "ANEURALNETWORKS_LOGISTIC has changed");
77 static_assert(ANEURALNETWORKS_LSH_PROJECTION == 15,
78               "ANEURALNETWORKS_LSH_PROJECTION has changed");
79 static_assert(ANEURALNETWORKS_LSTM == 16, "ANEURALNETWORKS_LSTM has changed");
80 static_assert(ANEURALNETWORKS_MAX_POOL_2D == 17, "ANEURALNETWORKS_MAX_POOL has changed");
81 static_assert(ANEURALNETWORKS_MUL == 18, "ANEURALNETWORKS_MUL has changed");
82 static_assert(ANEURALNETWORKS_RELU == 19, "ANEURALNETWORKS_RELU has changed");
83 static_assert(ANEURALNETWORKS_RELU1 == 20, "ANEURALNETWORKS_RELU1 has changed");
84 static_assert(ANEURALNETWORKS_RELU6 == 21, "ANEURALNETWORKS_RELU6 has changed");
85 static_assert(ANEURALNETWORKS_RESHAPE == 22, "ANEURALNETWORKS_RESHAPE has changed");
86 static_assert(ANEURALNETWORKS_RESIZE_BILINEAR == 23,
87               "ANEURALNETWORKS_RESIZE_BILINEAR has changed");
88 static_assert(ANEURALNETWORKS_RNN == 24, "ANEURALNETWORKS_RNN has changed");
89 static_assert(ANEURALNETWORKS_SOFTMAX == 25, "ANEURALNETWORKS_SOFTMAX has changed");
90 static_assert(ANEURALNETWORKS_SPACE_TO_DEPTH == 26,
91               "ANEURALNETWORKS_SPACE_TO_DEPTH has changed");
92 static_assert(ANEURALNETWORKS_SVDF == 27, "ANEURALNETWORKS_SVDF has changed");
93 static_assert(ANEURALNETWORKS_TANH == 28, "ANEURALNETWORKS_TANH has changed");
94 
95 static_assert(ANEURALNETWORKS_BATCH_TO_SPACE_ND == 29, "ANEURALNETWORKS_BATCH_TO_SPACE_ND has changed");
96 static_assert(ANEURALNETWORKS_DIV == 30, "ANEURALNETWORKS_DIV has changed");
97 static_assert(ANEURALNETWORKS_MEAN == 31, "ANEURALNETWORKS_MEAN has changed");
98 static_assert(ANEURALNETWORKS_PAD == 32, "ANEURALNETWORKS_PAD has changed");
99 static_assert(ANEURALNETWORKS_SPACE_TO_BATCH_ND == 33, "ANEURALNETWORKS_SPACE_TO_BATCH_ND has changed");
100 static_assert(ANEURALNETWORKS_SQUEEZE == 34, "ANEURALNETWORKS_SQUEEZE has changed");
101 static_assert(ANEURALNETWORKS_STRIDED_SLICE == 35, "ANEURALNETWORKS_STRIDED_SLICE has changed");
102 static_assert(ANEURALNETWORKS_SUB == 36, "ANEURALNETWORKS_TANH has changed");
103 static_assert(ANEURALNETWORKS_TRANSPOSE == 37, "ANEURALNETWORKS_TRANSPOSE has changed");
104 
105 static_assert(ANEURALNETWORKS_OEM_OPERATION == 10000,
106               "ANEURALNETWORKS_OEM_OPERATION has changed");
107 
108 static_assert(ANEURALNETWORKS_FUSED_NONE == 0, "ANEURALNETWORKS_FUSED_NONE has changed");
109 static_assert(ANEURALNETWORKS_FUSED_RELU == 1, "ANEURALNETWORKS_FUSED_RELU has changed");
110 static_assert(ANEURALNETWORKS_FUSED_RELU1 == 2, "ANEURALNETWORKS_FUSED_RELU1 has changed");
111 static_assert(ANEURALNETWORKS_FUSED_RELU6 == 3, "ANEURALNETWORKS_FUSED_RELU6 has changed");
112 
113 static_assert(ANEURALNETWORKS_PREFER_LOW_POWER == 0,
114               "ANEURALNETWORKS_PREFER_LOW_POWER has changed");
115 static_assert(ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER == 1,
116               "ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER has changed");
117 static_assert(ANEURALNETWORKS_PREFER_SUSTAINED_SPEED == 2,
118               "ANEURALNETWORKS_PREFER_SUSTAINED_SPEED has changed");
119 
120 static_assert(ANEURALNETWORKS_NO_ERROR == 0, "ANEURALNETWORKS_NO_ERROR has changed");
121 static_assert(ANEURALNETWORKS_OUT_OF_MEMORY == 1, "ANEURALNETWORKS_OUT_OF_MEMORY has changed");
122 static_assert(ANEURALNETWORKS_INCOMPLETE == 2, "ANEURALNETWORKS_INCOMPLETE has changed");
123 static_assert(ANEURALNETWORKS_UNEXPECTED_NULL == 3,
124               "ANEURALNETWORKS_UNEXPECTED_NULL has changed");
125 static_assert(ANEURALNETWORKS_BAD_DATA == 4, "ANEURALNETWORKS_BAD_DATA has changed");
126 static_assert(ANEURALNETWORKS_OP_FAILED == 5, "ANEURALNETWORKS_OP_FAILED has changed");
127 static_assert(ANEURALNETWORKS_BAD_STATE == 6, "ANEURALNETWORKS_BAD_STATE has changed");
128 
129 static_assert(ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES == 128,
130               "ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES has changed");
131 
132 // Make sure that the constants are compatible with the values defined in
133 // hardware/interfaces/neuralnetworks/1.0/types.hal.
134 static_assert(static_cast<int32_t>(OperandType::OEM) == ANEURALNETWORKS_OEM_SCALAR,
135               "OEM != ANEURALNETWORKS_OEM");
136 static_assert(static_cast<int32_t>(OperandType::FLOAT32) == ANEURALNETWORKS_FLOAT32,
137               "FLOAT32 != ANEURALNETWORKS_FLOAT32");
138 static_assert(static_cast<int32_t>(OperandType::INT32) == ANEURALNETWORKS_INT32,
139               "INT32 != ANEURALNETWORKS_INT32");
140 static_assert(static_cast<int32_t>(OperandType::UINT32) == ANEURALNETWORKS_UINT32,
141               "UINT32 != ANEURALNETWORKS_UINT32");
142 static_assert(static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) == ANEURALNETWORKS_TENSOR_OEM_BYTE,
143               "TENSOR_OEM_BYTE != ANEURALNETWORKS_TENSOR_OEM_BYTE");
144 static_assert(static_cast<int32_t>(OperandType::TENSOR_FLOAT32) == ANEURALNETWORKS_TENSOR_FLOAT32,
145               "TENSOR_FLOAT32 != ANEURALNETWORKS_TENSOR_FLOAT32");
146 static_assert(static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) ==
147                           ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
148               "TENSOR_QUANT8_ASYMM != ANEURALNETWORKS_TENSOR_QUANT8_ASYMM");
149 
150 static_assert(static_cast<int32_t>(OperationType::ADD) == ANEURALNETWORKS_ADD,
151               "OperationType::ADD != ANEURALNETWORKS_ADD");
152 static_assert(static_cast<int32_t>(OperationType::AVERAGE_POOL_2D) ==
153                           ANEURALNETWORKS_AVERAGE_POOL_2D,
154               "OperationType::AVERAGE_POOL_2D != ANEURALNETWORKS_AVERAGE_POOL_2D");
155 static_assert(static_cast<int32_t>(OperationType::CONV_2D) == ANEURALNETWORKS_CONV_2D,
156               "OperationType::CONV_2D != ANEURALNETWORKS_CONV_2D");
157 static_assert(static_cast<int32_t>(OperationType::DEPTHWISE_CONV_2D) ==
158                           ANEURALNETWORKS_DEPTHWISE_CONV_2D,
159               "OperationType::DEPTHWISE_CONV_2D != ANEURALNETWORKS_DEPTHWISE_CONV_2D");
160 static_assert(static_cast<int32_t>(OperationType::DEPTH_TO_SPACE) ==
161                           ANEURALNETWORKS_DEPTH_TO_SPACE,
162               "OperationType::DEPTH_TO_SPACE != ANEURALNETWORKS_DEPTH_TO_SPACE");
163 static_assert(static_cast<int32_t>(OperationType::DEQUANTIZE) == ANEURALNETWORKS_DEQUANTIZE,
164               "OperationType::DEQUANTIZE != ANEURALNETWORKS_DEQUANTIZE");
165 static_assert(static_cast<int32_t>(OperationType::EMBEDDING_LOOKUP) ==
166                           ANEURALNETWORKS_EMBEDDING_LOOKUP,
167               "OperationType::EMBEDDING_LOOKUP != ANEURALNETWORKS_EMBEDDING_LOOKUP");
168 static_assert(static_cast<int32_t>(OperationType::FLOOR) == ANEURALNETWORKS_FLOOR,
169               "OperationType::FLOOR != ANEURALNETWORKS_FLOOR");
170 static_assert(static_cast<int32_t>(OperationType::FULLY_CONNECTED) ==
171                           ANEURALNETWORKS_FULLY_CONNECTED,
172               "OperationType::FULLY_CONNECTED != ANEURALNETWORKS_FULLY_CONNECTED");
173 static_assert(static_cast<int32_t>(OperationType::HASHTABLE_LOOKUP) ==
174                           ANEURALNETWORKS_HASHTABLE_LOOKUP,
175               "OperationType::HASHTABLE_LOOKUP != ANEURALNETWORKS_HASHTABLE_LOOKUP");
176 static_assert(static_cast<int32_t>(OperationType::L2_NORMALIZATION) ==
177                           ANEURALNETWORKS_L2_NORMALIZATION,
178               "OperationType::L2_NORMALIZATION != ANEURALNETWORKS_L2_NORMALIZATION");
179 static_assert(static_cast<int32_t>(OperationType::L2_POOL_2D) == ANEURALNETWORKS_L2_POOL_2D,
180               "OperationType::L2_POOL_2D != ANEURALNETWORKS_L2_POOL_2D");
181 static_assert(static_cast<int32_t>(OperationType::LOCAL_RESPONSE_NORMALIZATION) ==
182                           ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION,
183               "OperationType::LOCAL_RESPONSE_NORMALIZATION != "
184               "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION");
185 static_assert(static_cast<int32_t>(OperationType::LOGISTIC) == ANEURALNETWORKS_LOGISTIC,
186               "OperationType::LOGISTIC != ANEURALNETWORKS_LOGISTIC");
187 static_assert(static_cast<int32_t>(OperationType::LSH_PROJECTION) ==
188                           ANEURALNETWORKS_LSH_PROJECTION,
189               "OperationType::LSH_PROJECTION != ANEURALNETWORKS_LSH_PROJECTION");
190 static_assert(static_cast<int32_t>(OperationType::LSTM) == ANEURALNETWORKS_LSTM,
191               "OperationType::LSTM != ANEURALNETWORKS_LSTM");
192 static_assert(static_cast<int32_t>(OperationType::MAX_POOL_2D) == ANEURALNETWORKS_MAX_POOL_2D,
193               "OperationType::MAX_POOL_2D != ANEURALNETWORKS_MAX_POOL_2D");
194 static_assert(static_cast<int32_t>(OperationType::MUL) == ANEURALNETWORKS_MUL,
195               "OperationType::MUL != ANEURALNETWORKS_MUL");
196 static_assert(static_cast<int32_t>(OperationType::RELU) == ANEURALNETWORKS_RELU,
197               "OperationType::RELU != ANEURALNETWORKS_RELU");
198 static_assert(static_cast<int32_t>(OperationType::RELU1) == ANEURALNETWORKS_RELU1,
199               "OperationType::RELU1 != ANEURALNETWORKS_RELU1");
200 static_assert(static_cast<int32_t>(OperationType::RELU6) == ANEURALNETWORKS_RELU6,
201               "OperationType::RELU6 != ANEURALNETWORKS_RELU6");
202 static_assert(static_cast<int32_t>(OperationType::RESHAPE) == ANEURALNETWORKS_RESHAPE,
203               "OperationType::RESHAPE != ANEURALNETWORKS_RESHAPE");
204 static_assert(static_cast<int32_t>(OperationType::RESIZE_BILINEAR) ==
205                           ANEURALNETWORKS_RESIZE_BILINEAR,
206               "OperationType::RESIZE_BILINEAR != ANEURALNETWORKS_RESIZE_BILINEAR");
207 static_assert(static_cast<int32_t>(OperationType::RNN) == ANEURALNETWORKS_RNN,
208               "OperationType::RNN != ANEURALNETWORKS_RNN");
209 static_assert(static_cast<int32_t>(OperationType::SOFTMAX) == ANEURALNETWORKS_SOFTMAX,
210               "OperationType::SOFTMAX != ANEURALNETWORKS_SOFTMAX");
211 static_assert(static_cast<int32_t>(OperationType::SPACE_TO_DEPTH) ==
212                           ANEURALNETWORKS_SPACE_TO_DEPTH,
213               "OperationType::SPACE_TO_DEPTH != ANEURALNETWORKS_SPACE_TO_DEPTH");
214 static_assert(static_cast<int32_t>(OperationType::SVDF) == ANEURALNETWORKS_SVDF,
215               "OperationType::SVDF != ANEURALNETWORKS_SVDF");
216 static_assert(static_cast<int32_t>(OperationType::TANH) == ANEURALNETWORKS_TANH,
217               "OperationType::TANH != ANEURALNETWORKS_TANH");
218 
219 static_assert(static_cast<int32_t>(OperationType::BATCH_TO_SPACE_ND) == ANEURALNETWORKS_BATCH_TO_SPACE_ND,
220               "OperationType::BATCH_TO_SPACE_ND != ANEURALNETWORKS_BATCH_TO_SPACE_ND");
221 static_assert(static_cast<int32_t>(OperationType::DIV) == ANEURALNETWORKS_DIV,
222               "OperationType::DIV != ANEURALNETWORKS_DIV");
223 static_assert(static_cast<int32_t>(OperationType::MEAN) == ANEURALNETWORKS_MEAN,
224               "OperationType::MEAN != ANEURALNETWORKS_MEAN");
225 static_assert(static_cast<int32_t>(OperationType::PAD) == ANEURALNETWORKS_PAD,
226               "OperationType::PAD != ANEURALNETWORKS_PAD");
227 static_assert(static_cast<int32_t>(OperationType::SPACE_TO_BATCH_ND) ==
228                           ANEURALNETWORKS_SPACE_TO_BATCH_ND,
229               "OperationType::SPACE_TO_BATCH_ND != ANEURALNETWORKS_SPACE_TO_BATCH_ND");
230 static_assert(static_cast<int32_t>(OperationType::SQUEEZE) == ANEURALNETWORKS_SQUEEZE,
231               "OperationType::SQUEEZE != ANEURALNETWORKS_SQUEEZE");
232 static_assert(static_cast<int32_t>(OperationType::STRIDED_SLICE) ==
233                           ANEURALNETWORKS_STRIDED_SLICE,
234               "OperationType::STRIDED_SLICE != ANEURALNETWORKS_STRIDED_SLICE");
235 static_assert(static_cast<int32_t>(OperationType::SUB) == ANEURALNETWORKS_SUB,
236               "OperationType::SUB != ANEURALNETWORKS_SUB");
237 static_assert(static_cast<int32_t>(OperationType::TRANSPOSE) == ANEURALNETWORKS_TRANSPOSE,
238               "OperationType::TRANSPOSE != ANEURALNETWORKS_TRANSPOSE");
239 
240 static_assert(static_cast<int32_t>(FusedActivationFunc::NONE) == ANEURALNETWORKS_FUSED_NONE,
241               "FusedActivationFunc::NONE != ANEURALNETWORKS_FUSED_NONE");
242 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU) == ANEURALNETWORKS_FUSED_RELU,
243               "FusedActivationFunc::RELU != ANEURALNETWORKS_FUSED_RELU");
244 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU1) == ANEURALNETWORKS_FUSED_RELU1,
245               "FusedActivationFunc::RELU1 != ANEURALNETWORKS_FUSED_RELU1");
246 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU6) == ANEURALNETWORKS_FUSED_RELU6,
247               "FusedActivationFunc::RELU6 != ANEURALNETWORKS_FUSED_RELU6");
248 
249 using android::sp;
250 using namespace android::nn;
251 
ANeuralNetworksMemory_createFromFd(size_t size,int prot,int fd,size_t offset,ANeuralNetworksMemory ** memory)252 int ANeuralNetworksMemory_createFromFd(size_t size, int prot, int fd, size_t offset,
253                                        ANeuralNetworksMemory** memory) {
254     *memory = nullptr;
255     std::unique_ptr<MemoryFd> m = std::make_unique<MemoryFd>();
256     if (m == nullptr) {
257         return ANEURALNETWORKS_OUT_OF_MEMORY;
258     }
259     int n = m->set(size, prot, fd, offset);
260     if (n != ANEURALNETWORKS_NO_ERROR) {
261         return n;
262     }
263     *memory = reinterpret_cast<ANeuralNetworksMemory*>(m.release());
264     return ANEURALNETWORKS_NO_ERROR;
265 }
266 
ANeuralNetworksMemory_free(ANeuralNetworksMemory * memory)267 void ANeuralNetworksMemory_free(ANeuralNetworksMemory* memory) {
268     // No validation.  Free of nullptr is valid.
269     Memory* m = reinterpret_cast<Memory*>(memory);
270     delete m;
271 }
272 
ANeuralNetworksModel_create(ANeuralNetworksModel ** model)273 int ANeuralNetworksModel_create(ANeuralNetworksModel** model) {
274     initVLogMask();
275     if (!model) {
276         LOG(ERROR) << "ANeuralNetworksModel_create passed a nullptr";
277         return ANEURALNETWORKS_UNEXPECTED_NULL;
278     }
279     ModelBuilder* m = new (std::nothrow) ModelBuilder();
280     if (m == nullptr) {
281         *model = nullptr;
282         return ANEURALNETWORKS_OUT_OF_MEMORY;
283     }
284     *model = reinterpret_cast<ANeuralNetworksModel*>(m);
285     return ANEURALNETWORKS_NO_ERROR;
286 }
287 
ANeuralNetworksModel_free(ANeuralNetworksModel * model)288 void ANeuralNetworksModel_free(ANeuralNetworksModel* model) {
289     // No validation.  Free of nullptr is valid.
290     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
291     delete m;
292 }
293 
ANeuralNetworksModel_finish(ANeuralNetworksModel * model)294 int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) {
295     if (!model) {
296         LOG(ERROR) << "ANeuralNetworksModel_finish passed a nullptr";
297         return ANEURALNETWORKS_UNEXPECTED_NULL;
298     }
299     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
300     return m->finish();
301 }
302 
ANeuralNetworksModel_addOperand(ANeuralNetworksModel * model,const ANeuralNetworksOperandType * type)303 int ANeuralNetworksModel_addOperand(ANeuralNetworksModel* model,
304                                     const ANeuralNetworksOperandType* type) {
305     if (!model || !type) {
306         LOG(ERROR) << "ANeuralNetworksModel_addOperand passed a nullptr";
307         return ANEURALNETWORKS_UNEXPECTED_NULL;
308     }
309     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
310     return m->addOperand(*type);
311 }
312 
ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel * model,int32_t index,const void * buffer,size_t length)313 int ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model, int32_t index,
314                                          const void* buffer, size_t length) {
315     if (!model || (!buffer && length != 0)) {
316         LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr";
317         return ANEURALNETWORKS_UNEXPECTED_NULL;
318     }
319     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
320     return m->setOperandValue(index, buffer, length);
321 }
322 
ANeuralNetworksModel_setOperandValueFromMemory(ANeuralNetworksModel * model,int32_t index,const ANeuralNetworksMemory * memory,size_t offset,size_t length)323 int ANeuralNetworksModel_setOperandValueFromMemory(ANeuralNetworksModel* model, int32_t index,
324                                                    const ANeuralNetworksMemory* memory,
325                                                    size_t offset, size_t length) {
326     if (!model || !memory) {
327         LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr";
328         return ANEURALNETWORKS_UNEXPECTED_NULL;
329     }
330     const Memory* mem = reinterpret_cast<const Memory*>(memory);
331     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
332     return m->setOperandValueFromMemory(index, mem, offset, length);
333 }
334 
ANeuralNetworksModel_addOperation(ANeuralNetworksModel * model,ANeuralNetworksOperationType type,uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs)335 int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model,
336                                       ANeuralNetworksOperationType type, uint32_t inputCount,
337                                       const uint32_t* inputs, uint32_t outputCount,
338                                       const uint32_t* outputs) {
339     if (!model || !inputs || !outputs) {
340         LOG(ERROR) << "ANeuralNetworksModel_addOperation passed a nullptr";
341         return ANEURALNETWORKS_UNEXPECTED_NULL;
342     }
343     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
344     return m->addOperation(type, inputCount, inputs, outputCount, outputs);
345 }
346 
ANeuralNetworksModel_identifyInputsAndOutputs(ANeuralNetworksModel * model,uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs)347 int ANeuralNetworksModel_identifyInputsAndOutputs(ANeuralNetworksModel* model, uint32_t inputCount,
348                                                   const uint32_t* inputs, uint32_t outputCount,
349                                                   const uint32_t* outputs) {
350     if (!model || !inputs || !outputs) {
351         LOG(ERROR) << ("ANeuralNetworksModel_identifyInputsAndOutputs passed a nullptr");
352         return ANEURALNETWORKS_UNEXPECTED_NULL;
353     }
354     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
355     return m->identifyInputsAndOutputs(inputCount, inputs, outputCount, outputs);
356 }
357 
ANeuralNetworksModel_relaxComputationFloat32toFloat16(ANeuralNetworksModel * model,bool allow)358 int ANeuralNetworksModel_relaxComputationFloat32toFloat16(ANeuralNetworksModel* model,
359                                                           bool allow) {
360     if (!model) {
361         LOG(ERROR) << ("ANeuralNetworksModel_relaxComputationFloat32toFloat16 passed a nullptr");
362         return ANEURALNETWORKS_UNEXPECTED_NULL;
363     }
364     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
365     return m->relaxComputationFloat32toFloat16(allow);
366 }
367 
ANeuralNetworksCompilation_create(ANeuralNetworksModel * model,ANeuralNetworksCompilation ** compilation)368 int ANeuralNetworksCompilation_create(ANeuralNetworksModel* model,
369                                       ANeuralNetworksCompilation** compilation) {
370     if (!model || !compilation) {
371         LOG(ERROR) << "ANeuralNetworksCompilation_create passed a nullptr";
372         return ANEURALNETWORKS_UNEXPECTED_NULL;
373     }
374 
375     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
376     CompilationBuilder* c = nullptr;
377     int result = m->createCompilation(&c);
378     *compilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
379     return result;
380 }
381 
ANeuralNetworksCompilation_free(ANeuralNetworksCompilation * compilation)382 void ANeuralNetworksCompilation_free(ANeuralNetworksCompilation* compilation) {
383     // No validation.  Free of nullptr is valid.
384     // TODO specification says that a compilation-in-flight can be deleted
385     CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
386     delete c;
387 }
388 
ANeuralNetworksCompilation_setPreference(ANeuralNetworksCompilation * compilation,int32_t preference)389 int ANeuralNetworksCompilation_setPreference(ANeuralNetworksCompilation* compilation,
390                                              int32_t preference) {
391     if (!compilation) {
392         LOG(ERROR) << "ANeuralNetworksCompilation_setPreference passed a nullptr";
393         return ANEURALNETWORKS_UNEXPECTED_NULL;
394     }
395     CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
396     return c->setPreference(preference);
397 }
398 
ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation * compilation)399 int ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation* compilation) {
400     if (!compilation) {
401         LOG(ERROR) << "ANeuralNetworksCompilation_finish passed a nullptr";
402         return ANEURALNETWORKS_UNEXPECTED_NULL;
403     }
404     CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
405     return c->finish();
406 }
407 
ANeuralNetworksExecution_create(ANeuralNetworksCompilation * compilation,ANeuralNetworksExecution ** execution)408 int ANeuralNetworksExecution_create(ANeuralNetworksCompilation* compilation,
409                                     ANeuralNetworksExecution** execution) {
410     if (!compilation || !execution) {
411         LOG(ERROR) << "ANeuralNetworksExecution_create passed a nullptr";
412         return ANEURALNETWORKS_UNEXPECTED_NULL;
413     }
414 
415     CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
416     ExecutionBuilder* r = nullptr;
417     int result = c->createExecution(&r);
418     *execution = reinterpret_cast<ANeuralNetworksExecution*>(r);
419     return result;
420 }
421 
ANeuralNetworksExecution_free(ANeuralNetworksExecution * execution)422 void ANeuralNetworksExecution_free(ANeuralNetworksExecution* execution) {
423     // TODO specification says that an execution-in-flight can be deleted
424     // No validation.  Free of nullptr is valid.
425     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
426     delete r;
427 }
428 
ANeuralNetworksExecution_setInput(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const void * buffer,size_t length)429 int ANeuralNetworksExecution_setInput(ANeuralNetworksExecution* execution, int32_t index,
430                                       const ANeuralNetworksOperandType* type, const void* buffer,
431                                       size_t length) {
432     if (!execution || (!buffer && length != 0)) {
433         LOG(ERROR) << "ANeuralNetworksExecution_setInput passed a nullptr";
434         return ANEURALNETWORKS_UNEXPECTED_NULL;
435     }
436     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
437     return r->setInput(index, type, buffer, length);
438 }
439 
ANeuralNetworksExecution_setInputFromMemory(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const ANeuralNetworksMemory * memory,size_t offset,size_t length)440 int ANeuralNetworksExecution_setInputFromMemory(ANeuralNetworksExecution* execution, int32_t index,
441                                                 const ANeuralNetworksOperandType* type,
442                                                 const ANeuralNetworksMemory* memory, size_t offset,
443                                                 size_t length) {
444     if (!execution || !memory) {
445         LOG(ERROR) << "ANeuralNetworksExecution_setInputFromMemory passed a nullptr";
446         return ANEURALNETWORKS_UNEXPECTED_NULL;
447     }
448 
449     const Memory* m = reinterpret_cast<const Memory*>(memory);
450     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
451     return r->setInputFromMemory(index, type, m, offset, length);
452 }
453 
ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,void * buffer,size_t length)454 int ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution* execution, int32_t index,
455                                        const ANeuralNetworksOperandType* type, void* buffer,
456                                        size_t length) {
457     if (!execution || (!buffer && length != 0)) {
458         LOG(ERROR) << "ANeuralNetworksExecution_setOutput passed a nullptr";
459         return ANEURALNETWORKS_UNEXPECTED_NULL;
460     }
461     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
462     return r->setOutput(index, type, buffer, length);
463 }
464 
ANeuralNetworksExecution_setOutputFromMemory(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const ANeuralNetworksMemory * memory,size_t offset,size_t length)465 int ANeuralNetworksExecution_setOutputFromMemory(ANeuralNetworksExecution* execution, int32_t index,
466                                                  const ANeuralNetworksOperandType* type,
467                                                  const ANeuralNetworksMemory* memory, size_t offset,
468                                                  size_t length) {
469     if (!execution || !memory) {
470         LOG(ERROR) << "ANeuralNetworksExecution_setOutputFromMemory passed a nullptr";
471         return ANEURALNETWORKS_UNEXPECTED_NULL;
472     }
473 
474     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
475     const Memory* m = reinterpret_cast<const Memory*>(memory);
476     return r->setOutputFromMemory(index, type, m, offset, length);
477 }
478 
ANeuralNetworksExecution_startCompute(ANeuralNetworksExecution * execution,ANeuralNetworksEvent ** event)479 int ANeuralNetworksExecution_startCompute(ANeuralNetworksExecution* execution,
480                                           ANeuralNetworksEvent** event) {
481     if (!execution || !event) {
482         LOG(ERROR) << "ANeuralNetworksExecution_startCompute passed a nullptr";
483         return ANEURALNETWORKS_UNEXPECTED_NULL;
484     }
485     // TODO validate the rest
486 
487     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
488 
489     // Dynamically allocate an sp to wrap an ExecutionCallback, seen in the NN
490     // API as an abstract event object. The sp<ExecutionCallback> object is
491     // returned when the execution has been successfully launched, otherwise a
492     // nullptr is returned. The sp is used for ref-counting purposes. Without
493     // it, the HIDL service could attempt to communicate with a dead callback
494     // object.
495     std::unique_ptr<sp<ExecutionCallback>> e = std::make_unique<sp<ExecutionCallback>>();
496     *event = nullptr;
497 
498     int n = r->startCompute(e.get());
499     if (n != ANEURALNETWORKS_NO_ERROR) {
500         return n;
501     }
502     *event = reinterpret_cast<ANeuralNetworksEvent*>(e.release());
503     return ANEURALNETWORKS_NO_ERROR;
504 }
505 
ANeuralNetworksEvent_wait(ANeuralNetworksEvent * event)506 int ANeuralNetworksEvent_wait(ANeuralNetworksEvent* event) {
507     if (event == nullptr) {
508         LOG(ERROR) << "ANeuralNetworksEvent_wait passed a nullptr";
509         return ANEURALNETWORKS_UNEXPECTED_NULL;
510     }
511 
512     sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event);
513     (*e)->wait();
514     return convertErrorStatusToResultCode((*e)->getStatus());
515 }
516 
ANeuralNetworksEvent_free(ANeuralNetworksEvent * event)517 void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) {
518     // No validation.  Free of nullptr is valid.
519     if (event) {
520         sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event);
521         (*e)->wait();
522         delete e;
523     }
524 }
525