1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Provides C++ classes to more easily use the Neural Networks API.
18 // TODO(b/117845862): this should be auto generated from NeuralNetworksWrapper.h.
19 
20 #ifndef ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H
21 #define ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H
22 
#include "NeuralNetworks.h"
#include "NeuralNetworksWrapper.h"
#include "NeuralNetworksWrapperExtensions.h"

#include <math.h>
#include <optional>
#include <string>
#include <utility>
#include <vector>
30 #include <vector>
31 
32 namespace android {
33 namespace nn {
34 namespace test_wrapper {
35 
36 using wrapper::Event;
37 using wrapper::ExecutePreference;
38 using wrapper::ExtensionModel;
39 using wrapper::ExtensionOperandParams;
40 using wrapper::ExtensionOperandType;
41 using wrapper::Memory;
42 using wrapper::Model;
43 using wrapper::OperandType;
44 using wrapper::Result;
45 using wrapper::SymmPerChannelQuantParams;
46 using wrapper::Type;
47 
48 class Compilation {
49    public:
Compilation(const Model * model)50     Compilation(const Model* model) {
51         int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
52         if (result != 0) {
53             // TODO Handle the error
54         }
55     }
56 
Compilation()57     Compilation() {}
58 
~Compilation()59     ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }
60 
61     // Disallow copy semantics to ensure the runtime object can only be freed
62     // once. Copy semantics could be enabled if some sort of reference counting
63     // or deep-copy system for runtime objects is added later.
64     Compilation(const Compilation&) = delete;
65     Compilation& operator=(const Compilation&) = delete;
66 
67     // Move semantics to remove access to the runtime object from the wrapper
68     // object that is being moved. This ensures the runtime object will be
69     // freed only once.
Compilation(Compilation && other)70     Compilation(Compilation&& other) { *this = std::move(other); }
71     Compilation& operator=(Compilation&& other) {
72         if (this != &other) {
73             ANeuralNetworksCompilation_free(mCompilation);
74             mCompilation = other.mCompilation;
75             other.mCompilation = nullptr;
76         }
77         return *this;
78     }
79 
setPreference(ExecutePreference preference)80     Result setPreference(ExecutePreference preference) {
81         return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
82                 mCompilation, static_cast<int32_t>(preference)));
83     }
84 
setCaching(const std::string & cacheDir,const std::vector<uint8_t> & token)85     Result setCaching(const std::string& cacheDir, const std::vector<uint8_t>& token) {
86         if (token.size() != ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN) {
87             return Result::BAD_DATA;
88         }
89         return static_cast<Result>(ANeuralNetworksCompilation_setCaching(
90                 mCompilation, cacheDir.c_str(), token.data()));
91     }
92 
finish()93     Result finish() { return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation)); }
94 
getHandle()95     ANeuralNetworksCompilation* getHandle() const { return mCompilation; }
96 
97    protected:
98     ANeuralNetworksCompilation* mCompilation = nullptr;
99 };
100 
101 class Execution {
102    public:
Execution(const Compilation * compilation)103     Execution(const Compilation* compilation) : mCompilation(compilation->getHandle()) {
104         int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
105         if (result != 0) {
106             // TODO Handle the error
107         }
108     }
109 
~Execution()110     ~Execution() { ANeuralNetworksExecution_free(mExecution); }
111 
112     // Disallow copy semantics to ensure the runtime object can only be freed
113     // once. Copy semantics could be enabled if some sort of reference counting
114     // or deep-copy system for runtime objects is added later.
115     Execution(const Execution&) = delete;
116     Execution& operator=(const Execution&) = delete;
117 
118     // Move semantics to remove access to the runtime object from the wrapper
119     // object that is being moved. This ensures the runtime object will be
120     // freed only once.
Execution(Execution && other)121     Execution(Execution&& other) { *this = std::move(other); }
122     Execution& operator=(Execution&& other) {
123         if (this != &other) {
124             ANeuralNetworksExecution_free(mExecution);
125             mCompilation = other.mCompilation;
126             other.mCompilation = nullptr;
127             mExecution = other.mExecution;
128             other.mExecution = nullptr;
129         }
130         return *this;
131     }
132 
133     Result setInput(uint32_t index, const void* buffer, size_t length,
134                     const ANeuralNetworksOperandType* type = nullptr) {
135         return static_cast<Result>(
136                 ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length));
137     }
138 
139     Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
140                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
141         return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory(
142                 mExecution, index, type, memory->get(), offset, length));
143     }
144 
145     Result setOutput(uint32_t index, void* buffer, size_t length,
146                      const ANeuralNetworksOperandType* type = nullptr) {
147         return static_cast<Result>(
148                 ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length));
149     }
150 
151     Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
152                                uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
153         return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory(
154                 mExecution, index, type, memory->get(), offset, length));
155     }
156 
startCompute(Event * event)157     Result startCompute(Event* event) {
158         ANeuralNetworksEvent* ev = nullptr;
159         Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev));
160         event->set(ev);
161         return result;
162     }
163 
compute()164     Result compute() {
165         switch (mComputeMode) {
166             case ComputeMode::SYNC: {
167                 return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution));
168             }
169             case ComputeMode::ASYNC: {
170                 ANeuralNetworksEvent* event = nullptr;
171                 Result result = static_cast<Result>(
172                         ANeuralNetworksExecution_startCompute(mExecution, &event));
173                 if (result != Result::NO_ERROR) {
174                     return result;
175                 }
176                 // TODO how to manage the lifetime of events when multiple waiters is not
177                 // clear.
178                 result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
179                 ANeuralNetworksEvent_free(event);
180                 return result;
181             }
182             case ComputeMode::BURST: {
183                 ANeuralNetworksBurst* burst = nullptr;
184                 Result result =
185                         static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst));
186                 if (result != Result::NO_ERROR) {
187                     return result;
188                 }
189                 result = static_cast<Result>(
190                         ANeuralNetworksExecution_burstCompute(mExecution, burst));
191                 ANeuralNetworksBurst_free(burst);
192                 return result;
193             }
194         }
195         return Result::BAD_DATA;
196     }
197 
198     // By default, compute() uses the synchronous API. setComputeMode() can be
199     // used to change the behavior of compute() to either:
200     // - use the asynchronous API and then wait for computation to complete
201     // or
202     // - use the burst API
203     // Returns the previous ComputeMode.
204     enum class ComputeMode { SYNC, ASYNC, BURST };
setComputeMode(ComputeMode mode)205     static ComputeMode setComputeMode(ComputeMode mode) {
206         ComputeMode oldComputeMode = mComputeMode;
207         mComputeMode = mode;
208         return oldComputeMode;
209     }
210 
getOutputOperandDimensions(uint32_t index,std::vector<uint32_t> * dimensions)211     Result getOutputOperandDimensions(uint32_t index, std::vector<uint32_t>* dimensions) {
212         uint32_t rank = 0;
213         Result result = static_cast<Result>(
214                 ANeuralNetworksExecution_getOutputOperandRank(mExecution, index, &rank));
215         dimensions->resize(rank);
216         if ((result != Result::NO_ERROR && result != Result::OUTPUT_INSUFFICIENT_SIZE) ||
217             rank == 0) {
218             return result;
219         }
220         result = static_cast<Result>(ANeuralNetworksExecution_getOutputOperandDimensions(
221                 mExecution, index, dimensions->data()));
222         return result;
223     }
224 
225    private:
226     ANeuralNetworksCompilation* mCompilation = nullptr;
227     ANeuralNetworksExecution* mExecution = nullptr;
228 
229     // Initialized to ComputeMode::SYNC in TestNeuralNetworksWrapper.cpp.
230     static ComputeMode mComputeMode;
231 };
232 
233 }  // namespace test_wrapper
234 }  // namespace nn
235 }  // namespace android
236 
237 #endif  // ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H
238