1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Provides C++ classes to more easily use the Neural Networks API.
18 // TODO(b/117845862): this should be auto generated from NeuralNetworksWrapper.h.
19 
20 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H
21 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H
22 
#include <math.h>

#include <algorithm>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "NeuralNetworks.h"
#include "NeuralNetworksWrapper.h"
#include "NeuralNetworksWrapperExtensions.h"
35 
36 #ifndef __NNAPI_FL5_MIN_ANDROID_API__
37 #define __NNAPI_FL5_MIN_ANDROID_API__ __ANDROID_API_FUTURE__
38 #endif
39 
40 namespace android {
41 namespace nn {
42 namespace test_wrapper {
43 
44 using wrapper::Event;
45 using wrapper::ExecutePreference;
46 using wrapper::ExecutePriority;
47 using wrapper::ExtensionModel;
48 using wrapper::ExtensionOperandParams;
49 using wrapper::ExtensionOperandType;
50 using wrapper::OperandType;
51 using wrapper::Result;
52 using wrapper::SymmPerChannelQuantParams;
53 using wrapper::Type;
54 
55 class Memory {
56    public:
57     // Takes ownership of a ANeuralNetworksMemory
Memory(ANeuralNetworksMemory * memory)58     Memory(ANeuralNetworksMemory* memory) : mMemory(memory) {}
59 
Memory(size_t size,int protect,int fd,size_t offset)60     Memory(size_t size, int protect, int fd, size_t offset) {
61         mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) ==
62                  ANEURALNETWORKS_NO_ERROR;
63     }
64 
65 #ifdef __ANDROID__
Memory(AHardwareBuffer * buffer)66     Memory(AHardwareBuffer* buffer) {
67         mValid = ANeuralNetworksMemory_createFromAHardwareBuffer(buffer, &mMemory) ==
68                  ANEURALNETWORKS_NO_ERROR;
69     }
70 #endif  // __ANDROID__
71 
~Memory()72     virtual ~Memory() { ANeuralNetworksMemory_free(mMemory); }
73 
74     // Disallow copy semantics to ensure the runtime object can only be freed
75     // once. Copy semantics could be enabled if some sort of reference counting
76     // or deep-copy system for runtime objects is added later.
77     Memory(const Memory&) = delete;
78     Memory& operator=(const Memory&) = delete;
79 
80     // Move semantics to remove access to the runtime object from the wrapper
81     // object that is being moved. This ensures the runtime object will be
82     // freed only once.
Memory(Memory && other)83     Memory(Memory&& other) { *this = std::move(other); }
84     Memory& operator=(Memory&& other) {
85         if (this != &other) {
86             ANeuralNetworksMemory_free(mMemory);
87             mMemory = other.mMemory;
88             mValid = other.mValid;
89             other.mMemory = nullptr;
90             other.mValid = false;
91         }
92         return *this;
93     }
94 
get()95     ANeuralNetworksMemory* get() const { return mMemory; }
isValid()96     bool isValid() const { return mValid; }
97 
98    private:
99     ANeuralNetworksMemory* mMemory = nullptr;
100     bool mValid = true;
101 };
102 
103 class Model {
104    public:
Model()105     Model() {
106         // TODO handle the value returned by this call
107         ANeuralNetworksModel_create(&mModel);
108     }
~Model()109     ~Model() { ANeuralNetworksModel_free(mModel); }
110 
111     // Disallow copy semantics to ensure the runtime object can only be freed
112     // once. Copy semantics could be enabled if some sort of reference counting
113     // or deep-copy system for runtime objects is added later.
114     Model(const Model&) = delete;
115     Model& operator=(const Model&) = delete;
116 
117     // Move semantics to remove access to the runtime object from the wrapper
118     // object that is being moved. This ensures the runtime object will be
119     // freed only once.
Model(Model && other)120     Model(Model&& other) { *this = std::move(other); }
121     Model& operator=(Model&& other) {
122         if (this != &other) {
123             ANeuralNetworksModel_free(mModel);
124             mModel = other.mModel;
125             mNextOperandId = other.mNextOperandId;
126             mValid = other.mValid;
127             mRelaxed = other.mRelaxed;
128             mFinished = other.mFinished;
129             other.mModel = nullptr;
130             other.mNextOperandId = 0;
131             other.mValid = false;
132             other.mRelaxed = false;
133             other.mFinished = false;
134         }
135         return *this;
136     }
137 
finish()138     Result finish() {
139         if (mValid) {
140             auto result = static_cast<Result>(ANeuralNetworksModel_finish(mModel));
141             if (result != Result::NO_ERROR) {
142                 mValid = false;
143             }
144             mFinished = true;
145             return result;
146         } else {
147             return Result::BAD_STATE;
148         }
149     }
150 
addOperand(const OperandType * type)151     uint32_t addOperand(const OperandType* type) {
152         if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
153             ANEURALNETWORKS_NO_ERROR) {
154             mValid = false;
155         }
156         if (type->channelQuant) {
157             if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
158                         mModel, mNextOperandId, &type->channelQuant.value().params) !=
159                 ANEURALNETWORKS_NO_ERROR) {
160                 mValid = false;
161             }
162         }
163         return mNextOperandId++;
164     }
165 
166     template <typename T>
addConstantOperand(const OperandType * type,const T & value)167     uint32_t addConstantOperand(const OperandType* type, const T& value) {
168         static_assert(sizeof(T) <= ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES,
169                       "Values larger than ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES "
170                       "not supported");
171         uint32_t index = addOperand(type);
172         setOperandValue(index, &value);
173         return index;
174     }
175 
addModelOperand(const Model * value)176     uint32_t addModelOperand(const Model* value) {
177         OperandType operandType(Type::MODEL, {});
178         uint32_t operand = addOperand(&operandType);
179         setOperandValueFromModel(operand, value);
180         return operand;
181     }
182 
setOperandValue(uint32_t index,const void * buffer,size_t length)183     void setOperandValue(uint32_t index, const void* buffer, size_t length) {
184         if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
185             ANEURALNETWORKS_NO_ERROR) {
186             mValid = false;
187         }
188     }
189 
190     template <typename T>
setOperandValue(uint32_t index,const T * value)191     void setOperandValue(uint32_t index, const T* value) {
192         static_assert(!std::is_pointer<T>(), "No operand may have a pointer as its value");
193         return setOperandValue(index, value, sizeof(T));
194     }
195 
setOperandValueFromMemory(uint32_t index,const Memory * memory,uint32_t offset,size_t length)196     void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
197                                    size_t length) {
198         if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
199                                                            length) != ANEURALNETWORKS_NO_ERROR) {
200             mValid = false;
201         }
202     }
203 
setOperandValueFromModel(uint32_t index,const Model * value)204     void setOperandValueFromModel(uint32_t index, const Model* value) {
205         if (ANeuralNetworksModel_setOperandValueFromModel(mModel, index, value->mModel) !=
206             ANEURALNETWORKS_NO_ERROR) {
207             mValid = false;
208         }
209     }
210 
addOperation(ANeuralNetworksOperationType type,const std::vector<uint32_t> & inputs,const std::vector<uint32_t> & outputs)211     void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
212                       const std::vector<uint32_t>& outputs) {
213         if (ANeuralNetworksModel_addOperation(mModel, type, static_cast<uint32_t>(inputs.size()),
214                                               inputs.data(), static_cast<uint32_t>(outputs.size()),
215                                               outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
216             mValid = false;
217         }
218     }
identifyInputsAndOutputs(const std::vector<uint32_t> & inputs,const std::vector<uint32_t> & outputs)219     void identifyInputsAndOutputs(const std::vector<uint32_t>& inputs,
220                                   const std::vector<uint32_t>& outputs) {
221         if (ANeuralNetworksModel_identifyInputsAndOutputs(
222                     mModel, static_cast<uint32_t>(inputs.size()), inputs.data(),
223                     static_cast<uint32_t>(outputs.size()),
224                     outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
225             mValid = false;
226         }
227     }
228 
relaxComputationFloat32toFloat16(bool isRelax)229     void relaxComputationFloat32toFloat16(bool isRelax) {
230         if (ANeuralNetworksModel_relaxComputationFloat32toFloat16(mModel, isRelax) ==
231             ANEURALNETWORKS_NO_ERROR) {
232             mRelaxed = isRelax;
233         }
234     }
235 
getHandle()236     ANeuralNetworksModel* getHandle() const { return mModel; }
isValid()237     bool isValid() const { return mValid; }
isRelaxed()238     bool isRelaxed() const { return mRelaxed; }
isFinished()239     bool isFinished() const { return mFinished; }
240 
241    protected:
242     ANeuralNetworksModel* mModel = nullptr;
243     // We keep track of the operand ID as a convenience to the caller.
244     uint32_t mNextOperandId = 0;
245     bool mValid = true;
246     bool mRelaxed = false;
247     bool mFinished = false;
248 };
249 
250 class Compilation {
251    public:
252     // On success, createForDevice(s) will return Result::NO_ERROR and the created compilation;
253     // otherwise, it will return the error code and Compilation object wrapping a nullptr handle.
createForDevice(const Model * model,const ANeuralNetworksDevice * device)254     static std::pair<Result, Compilation> createForDevice(const Model* model,
255                                                           const ANeuralNetworksDevice* device) {
256         return createForDevices(model, {device});
257     }
createForDevices(const Model * model,const std::vector<const ANeuralNetworksDevice * > & devices)258     static std::pair<Result, Compilation> createForDevices(
259             const Model* model, const std::vector<const ANeuralNetworksDevice*>& devices) {
260         ANeuralNetworksCompilation* compilation = nullptr;
261         const Result result = static_cast<Result>(ANeuralNetworksCompilation_createForDevices(
262                 model->getHandle(), devices.empty() ? nullptr : devices.data(), devices.size(),
263                 &compilation));
264         return {result, Compilation(compilation)};
265     }
266 
Compilation(const Model * model)267     Compilation(const Model* model) {
268         int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
269         if (result != 0) {
270             // TODO Handle the error
271         }
272     }
273 
Compilation()274     Compilation() {}
275 
~Compilation()276     ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }
277 
278     // Disallow copy semantics to ensure the runtime object can only be freed
279     // once. Copy semantics could be enabled if some sort of reference counting
280     // or deep-copy system for runtime objects is added later.
281     Compilation(const Compilation&) = delete;
282     Compilation& operator=(const Compilation&) = delete;
283 
284     // Move semantics to remove access to the runtime object from the wrapper
285     // object that is being moved. This ensures the runtime object will be
286     // freed only once.
Compilation(Compilation && other)287     Compilation(Compilation&& other) { *this = std::move(other); }
288     Compilation& operator=(Compilation&& other) {
289         if (this != &other) {
290             ANeuralNetworksCompilation_free(mCompilation);
291             mCompilation = other.mCompilation;
292             other.mCompilation = nullptr;
293         }
294         return *this;
295     }
296 
setPreference(ExecutePreference preference)297     Result setPreference(ExecutePreference preference) {
298         return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
299                 mCompilation, static_cast<int32_t>(preference)));
300     }
301 
setPriority(ExecutePriority priority)302     Result setPriority(ExecutePriority priority) {
303         return static_cast<Result>(ANeuralNetworksCompilation_setPriority(
304                 mCompilation, static_cast<int32_t>(priority)));
305     }
306 
setCaching(const std::string & cacheDir,const std::vector<uint8_t> & token)307     Result setCaching(const std::string& cacheDir, const std::vector<uint8_t>& token) {
308         if (token.size() != ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN) {
309             return Result::BAD_DATA;
310         }
311         return static_cast<Result>(ANeuralNetworksCompilation_setCaching(
312                 mCompilation, cacheDir.c_str(), token.data()));
313     }
314 
finish()315     Result finish() { return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation)); }
316 
getPreferredMemoryAlignmentForInput(uint32_t index,uint32_t * alignment)317     Result getPreferredMemoryAlignmentForInput(uint32_t index, uint32_t* alignment) const {
318         if (__builtin_available(android __NNAPI_FL5_MIN_ANDROID_API__, *)) {
319             return static_cast<Result>(
320                     NNAPI_CALL(ANeuralNetworksCompilation_getPreferredMemoryAlignmentForInput(
321                             mCompilation, index, alignment)));
322         } else {
323             return Result::FEATURE_LEVEL_TOO_LOW;
324         }
325     };
326 
getPreferredMemoryPaddingForInput(uint32_t index,uint32_t * padding)327     Result getPreferredMemoryPaddingForInput(uint32_t index, uint32_t* padding) const {
328         if (__builtin_available(android __NNAPI_FL5_MIN_ANDROID_API__, *)) {
329             return static_cast<Result>(
330                     NNAPI_CALL(ANeuralNetworksCompilation_getPreferredMemoryPaddingForInput(
331                             mCompilation, index, padding)));
332         } else {
333             return Result::FEATURE_LEVEL_TOO_LOW;
334         }
335     };
336 
getPreferredMemoryAlignmentForOutput(uint32_t index,uint32_t * alignment)337     Result getPreferredMemoryAlignmentForOutput(uint32_t index, uint32_t* alignment) const {
338         if (__builtin_available(android __NNAPI_FL5_MIN_ANDROID_API__, *)) {
339             return static_cast<Result>(
340                     NNAPI_CALL(ANeuralNetworksCompilation_getPreferredMemoryAlignmentForOutput(
341                             mCompilation, index, alignment)));
342         } else {
343             return Result::FEATURE_LEVEL_TOO_LOW;
344         }
345     };
346 
getPreferredMemoryPaddingForOutput(uint32_t index,uint32_t * padding)347     Result getPreferredMemoryPaddingForOutput(uint32_t index, uint32_t* padding) const {
348         if (__builtin_available(android __NNAPI_FL5_MIN_ANDROID_API__, *)) {
349             return static_cast<Result>(
350                     NNAPI_CALL(ANeuralNetworksCompilation_getPreferredMemoryPaddingForOutput(
351                             mCompilation, index, padding)));
352         } else {
353             return Result::FEATURE_LEVEL_TOO_LOW;
354         }
355     };
356 
getHandle()357     ANeuralNetworksCompilation* getHandle() const { return mCompilation; }
358 
359    protected:
360     // Takes the ownership of ANeuralNetworksCompilation.
Compilation(ANeuralNetworksCompilation * compilation)361     Compilation(ANeuralNetworksCompilation* compilation) : mCompilation(compilation) {}
362 
363     ANeuralNetworksCompilation* mCompilation = nullptr;
364 };
365 
366 class Execution {
367    public:
Execution(const Compilation * compilation)368     Execution(const Compilation* compilation) : mCompilation(compilation->getHandle()) {
369         int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
370         if (result != 0) {
371             // TODO Handle the error
372         }
373     }
374 
~Execution()375     ~Execution() { ANeuralNetworksExecution_free(mExecution); }
376 
377     // Disallow copy semantics to ensure the runtime object can only be freed
378     // once. Copy semantics could be enabled if some sort of reference counting
379     // or deep-copy system for runtime objects is added later.
380     Execution(const Execution&) = delete;
381     Execution& operator=(const Execution&) = delete;
382 
383     // Move semantics to remove access to the runtime object from the wrapper
384     // object that is being moved. This ensures the runtime object will be
385     // freed only once.
Execution(Execution && other)386     Execution(Execution&& other) { *this = std::move(other); }
387     Execution& operator=(Execution&& other) {
388         if (this != &other) {
389             ANeuralNetworksExecution_free(mExecution);
390             mCompilation = other.mCompilation;
391             other.mCompilation = nullptr;
392             mExecution = other.mExecution;
393             other.mExecution = nullptr;
394         }
395         return *this;
396     }
397 
398     Result setInput(uint32_t index, const void* buffer, size_t length,
399                     const ANeuralNetworksOperandType* type = nullptr) {
400         return static_cast<Result>(
401                 ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length));
402     }
403 
404     template <typename T>
405     Result setInput(uint32_t index, const T* value,
406                     const ANeuralNetworksOperandType* type = nullptr) {
407         static_assert(!std::is_pointer<T>(), "No operand may have a pointer as its value");
408         return setInput(index, value, sizeof(T), type);
409     }
410 
411     Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
412                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
413         return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory(
414                 mExecution, index, type, memory->get(), offset, length));
415     }
416 
417     Result setOutput(uint32_t index, void* buffer, size_t length,
418                      const ANeuralNetworksOperandType* type = nullptr) {
419         return static_cast<Result>(
420                 ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length));
421     }
422 
423     template <typename T>
424     Result setOutput(uint32_t index, T* value, const ANeuralNetworksOperandType* type = nullptr) {
425         static_assert(!std::is_pointer<T>(), "No operand may have a pointer as its value");
426         return setOutput(index, value, sizeof(T), type);
427     }
428 
429     Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
430                                uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
431         return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory(
432                 mExecution, index, type, memory->get(), offset, length));
433     }
434 
setLoopTimeout(uint64_t duration)435     Result setLoopTimeout(uint64_t duration) {
436         return static_cast<Result>(ANeuralNetworksExecution_setLoopTimeout(mExecution, duration));
437     }
438 
enableInputAndOutputPadding(bool enable)439     Result enableInputAndOutputPadding(bool enable) {
440         if (__builtin_available(android __NNAPI_FL5_MIN_ANDROID_API__, *)) {
441             return static_cast<Result>(
442                     ANeuralNetworksExecution_enableInputAndOutputPadding(mExecution, enable));
443         } else {
444             return Result::FEATURE_LEVEL_TOO_LOW;
445         }
446     }
447 
setReusable(bool reusable)448     Result setReusable(bool reusable) {
449         if (__builtin_available(android __NNAPI_FL5_MIN_ANDROID_API__, *)) {
450             return static_cast<Result>(
451                     NNAPI_CALL(ANeuralNetworksExecution_setReusable(mExecution, reusable)));
452         } else {
453             return Result::FEATURE_LEVEL_TOO_LOW;
454         }
455     }
456 
startCompute(Event * event)457     Result startCompute(Event* event) {
458         ANeuralNetworksEvent* ev = nullptr;
459         Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev));
460         event->set(ev);
461         return result;
462     }
463 
startComputeWithDependencies(const std::vector<const Event * > & dependencies,uint64_t duration,Event * event)464     Result startComputeWithDependencies(const std::vector<const Event*>& dependencies,
465                                         uint64_t duration, Event* event) {
466         std::vector<const ANeuralNetworksEvent*> deps(dependencies.size());
467         std::transform(dependencies.begin(), dependencies.end(), deps.begin(),
468                        [](const Event* e) { return e->getHandle(); });
469         ANeuralNetworksEvent* ev = nullptr;
470         Result result = static_cast<Result>(ANeuralNetworksExecution_startComputeWithDependencies(
471                 mExecution, deps.data(), deps.size(), duration, &ev));
472         event->set(ev);
473         return result;
474     }
475 
476     // By default, compute() uses the synchronous API. Either an argument or
477     // setComputeMode() can be used to change the behavior of compute() to
478     // either:
479     // - use the asynchronous or fenced API and then wait for computation to complete
480     // or
481     // - use the burst API
482     // Returns the previous ComputeMode.
483     enum class ComputeMode { SYNC, ASYNC, BURST, FENCED };
setComputeMode(ComputeMode mode)484     static ComputeMode setComputeMode(ComputeMode mode) {
485         ComputeMode oldComputeMode = mComputeMode;
486         mComputeMode = mode;
487         return oldComputeMode;
488     }
getComputeMode()489     static ComputeMode getComputeMode() { return mComputeMode; }
490 
491     Result compute(ComputeMode computeMode = mComputeMode) {
492         switch (computeMode) {
493             case ComputeMode::SYNC: {
494                 return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution));
495             }
496             case ComputeMode::ASYNC: {
497                 ANeuralNetworksEvent* event = nullptr;
498                 Result result = static_cast<Result>(
499                         ANeuralNetworksExecution_startCompute(mExecution, &event));
500                 if (result != Result::NO_ERROR) {
501                     return result;
502                 }
503                 // TODO how to manage the lifetime of events when multiple waiters is not
504                 // clear.
505                 result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
506                 ANeuralNetworksEvent_free(event);
507                 return result;
508             }
509             case ComputeMode::BURST: {
510                 ANeuralNetworksBurst* burst = nullptr;
511                 Result result =
512                         static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst));
513                 if (result != Result::NO_ERROR) {
514                     return result;
515                 }
516                 result = static_cast<Result>(
517                         ANeuralNetworksExecution_burstCompute(mExecution, burst));
518                 ANeuralNetworksBurst_free(burst);
519                 return result;
520             }
521             case ComputeMode::FENCED: {
522                 ANeuralNetworksEvent* event = nullptr;
523                 Result result =
524                         static_cast<Result>(ANeuralNetworksExecution_startComputeWithDependencies(
525                                 mExecution, nullptr, 0, 0, &event));
526                 if (result != Result::NO_ERROR) {
527                     return result;
528                 }
529                 result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
530                 ANeuralNetworksEvent_free(event);
531                 return result;
532             }
533         }
534         return Result::BAD_DATA;
535     }
536 
getOutputOperandDimensions(uint32_t index,std::vector<uint32_t> * dimensions)537     Result getOutputOperandDimensions(uint32_t index, std::vector<uint32_t>* dimensions) {
538         uint32_t rank = 0;
539         Result result = static_cast<Result>(
540                 ANeuralNetworksExecution_getOutputOperandRank(mExecution, index, &rank));
541         dimensions->resize(rank);
542         if ((result != Result::NO_ERROR && result != Result::OUTPUT_INSUFFICIENT_SIZE) ||
543             rank == 0) {
544             return result;
545         }
546         result = static_cast<Result>(ANeuralNetworksExecution_getOutputOperandDimensions(
547                 mExecution, index, dimensions->data()));
548         return result;
549     }
550 
getHandle()551     ANeuralNetworksExecution* getHandle() { return mExecution; };
552 
553    private:
554     ANeuralNetworksCompilation* mCompilation = nullptr;
555     ANeuralNetworksExecution* mExecution = nullptr;
556 
557     // Initialized to ComputeMode::SYNC in TestNeuralNetworksWrapper.cpp.
558     static ComputeMode mComputeMode;
559 };
560 
561 }  // namespace test_wrapper
562 }  // namespace nn
563 }  // namespace android
564 
565 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H
566