1 #ifndef ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
2 #define ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
3
4 #include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
5 #include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
6 #include <chrono>
7 #include <condition_variable>
8 #include <functional>
9 #include <hidl/MQDescriptor.h>
10 #include <hidl/Status.h>
11 #include <mutex>
12 #include <thread>
13
14 namespace android {
15 namespace hardware {
16 namespace neuralnetworks {
17 namespace V1_0 {
18 namespace implementation {
19
/**
 * The CallbackBase class is used internally by the NeuralNetworks runtime to
 * synchronize between different threads. An asynchronous task is launched
 * paired with a callback object. When a client thread requires the output being
 * generated by the asynchronous task, the client thread can wait for the result
 * and be blocked until it has completed or a timeout condition has been
 * reached. Any wait* may safely be called concurrently, even on the same
 * callback object. When the asynchronous task has finished its workload, it
 * must immediately call "notify". If the asynchronous task has failed to launch,
 * the function that tried to launch the asynchronous task must immediately call
 * "notify". This "notify" call awakens any client threads waiting on the
 * callback object.
 *
 * The CallbackBase class implements some of the base synchronization common to
 * both PreparedModelCallback and ExecutionCallback. For consistency, any HIDL
 * callback class must inherit from CallbackBase as well as the HIDL callback
 * interface it implements.
 *
 * This class exists to enable synchronization across HIDL. When synchronization
 * is only required in the same process, consider using std::future, std::mutex,
 * std::condition_variable, or std::experimental::latch instead.
 *
 * Note: the destructor is not virtual, so objects of derived callback classes
 * must not be deleted through a CallbackBase pointer.
 */
class CallbackBase {
public:
    CallbackBase();
    ~CallbackBase();

    /**
     * CallbackBase::wait blocks until notify has been called on the callback
     * object.
     */
    void wait();

    /**
     * CallbackBase::wait_for blocks until notify has been called on the
     * callback object or the time duration from the time the wait_for function
     * was called has expired, whichever comes first. When the callback was
     * notified, the thread bound with CallbackBase::bind_thread (if any) is
     * joined before this function returns.
     *
     * @return Status std::cv_status::no_timeout if the callback was notified
     *                before the time duration expired, std::cv_status::timeout
     *                otherwise.
     */
    template<class Rep, class Period>
    std::cv_status wait_for(const std::chrono::duration<Rep,Period>& timeout_duration);

    /**
     * CallbackBase::on_finish binds a function to the callback object. This
     * bound function will be executed when CallbackBase::notify is called,
     * before any calls to wait* return. (Note that CallbackBase::wait_for can
     * return std::cv_status::timeout before CallbackBase::notify is called for
     * the first time, and hence before the bound function is executed.)
     *
     * The bound function must not synchronize with or otherwise access the
     * callback object it is bound to, as this could cause a deadlock.
     *
     * CallbackBase::on_finish can be called at most once on a given callback
     * object, and the call to CallbackBase::on_finish must finish before
     * CallbackBase::notify is called.
     *
     * @param post_work Function to be invoked the first time
     *                  CallbackBase::notify is called. Must have a target --
     *                  i.e., must not compare equal to nullptr. post_work
     *                  returns true if it successfully completes, false if it
     *                  fails.
     * @return bool True if the function was successfully bound, false if
     *              unsuccessful.
     *
     * TODO: Why does the return value of the callback matter?
     */
    bool on_finish(std::function<bool(void)> post_work);

    /**
     * CallbackBase::bind_thread binds a thread to the event for later use by
     * CallbackBase::join_thread.
     *
     * The thread must be passed using std::move.
     *
     * Once a thread is bound with CallbackBase::bind_thread, the client code
     * should ensure that one of the following occurs before the event is
     * destroyed:
     * - CallbackBase::join_thread has been called.
     * - CallbackBase::wait has been called.
     * - CallbackBase::wait_for has been called and returned
     *   std::cv_status::no_timeout. (A wait_for that returns
     *   std::cv_status::timeout does not join the bound thread.)
     *
     * The bound thread shall not call any CallbackBase method with the
     * exception of CallbackBase::notify, which it must call when the thread has
     * finished its computation.
     *
     * CallbackBase::bind_thread can be called at most once on a given callback
     * object.
     *
     * @param asyncThread Thread to be bound to the callback object. The thread
     *                    object must represent a thread of execution -- i.e.,
     *                    asyncThread.joinable() must be true.
     * @return bool True if successful, false if thread was not properly bound.
     */
    bool bind_thread(std::thread&& asyncThread);

    /**
     * CallbackBase::join_thread ensures that the thread (if any) bound to this
     * event with CallbackBase::bind_thread has fully finished and cleaned its
     * resources. It is legal to call this function multiple times, concurrently
     * or sequentially.
     */
    void join_thread();

protected:
    /**
     * CallbackBase::notify enables all prior and future wait* calls on the
     * callback object to proceed. The call to CallbackBase::notify happens
     * before any wait* calls on this callback object return (except in the case
     * of wait_for timing out). The asynchronous call the callback object is
     * paired with must ensure that any update to state that should be visible
     * to the caller of wait* happens before the call to CallbackBase::notify.
     *
     * CallbackBase::notify must be called exactly once on a given callback
     * object.
     */
    void notify();

private:
    // Same as CallbackBase::join_thread but assumes we already hold a lock on
    // mMutex.
    void join_thread_locked();

    // Set once notify has run; this is the predicate that wait* blocks on.
    // Defaults to false so the object never starts out "notified" even if a
    // constructor forgets to initialize it.
    bool mNotified = false;
    // Guards every member below and is the mutex mCondition waits with.
    std::mutex mMutex;
    std::condition_variable mCondition;
    // Work bound by on_finish, run when notify is called.
    std::function<bool(void)> mPostWork;
    // Thread bound by bind_thread, joined by join_thread / wait*.
    std::thread mThread;
};
152
153 /**
154 * The PreparedModelCallback class is used to receive the error status of
155 * preparing a model as well as the prepared model from a task executing
156 * asynchronously with respect to the runtime. If a calling thread calls wait*
157 * or get* on a PreparedModelCallback object and the corresponding asynchronous
158 * task has not finished preparing the model, the calling thread will block
159 * until the asynchronous task has called notify. For more information on the
160 * synchronization behavior, refer to the CallbackBase class.
161 *
162 * This class inherits the basic blocking and signaling calls from
163 * CallbackBase, and implements the HIDL notify call from
164 * IPreparedModelCallback. This callback object is passed as an argument to
165 * IDevice::prepareModel.
166 */
167 class PreparedModelCallback : public CallbackBase, public IPreparedModelCallback {
168 public:
169 PreparedModelCallback();
170 ~PreparedModelCallback() override;
171
172 /**
173 * IPreparedModelCallback::notify marks the callback object with the return
174 * status of the asynchronous model preparation along with the prepared
175 * model, and calls CallbackBase::notify, enabling all prior and future
176 * wait* calls on the PreparedModelCallback object to proceed. For more
177 * information on the synchronization behavior, refer to the CallbackBase
178 * class.
179 *
180 * IPreparedModelCallback::notify must be called exactly once on a given
181 * PreparedModelCallback object.
182 *
183 * @param status Error status returned from asynchronously preparing the
184 * model; will be:
185 * - NONE if the asynchronous preparation was successful
186 * - DEVICE_UNAVAILABLE if driver is offline or busy
187 * - GENERAL_FAILURE if there is an unspecified error
188 * - INVALID_ARGUMENT if the input model is invalid
189 * @param preparedModel Returned model that has been prepared for execution,
190 * nullptr if the model was unable to be prepared.
191 */
192 Return<void> notify(ErrorStatus status, const sp<IPreparedModel>& preparedModel) override;
193
194 /**
195 * Retrieves the error status returned from the asynchronous task launched
196 * by IDevice::prepareModel. If IDevice::prepareModel has not finished
197 * asynchronously preparing the model, this call will block until the
198 * asynchronous task notifies the object.
199 *
200 * @return status Error status returned from asynchronously preparing the
201 * model; will be:
202 * - NONE if the asynchronous preparation was successful
203 * - DEVICE_UNAVAILABLE if driver is offline or busy
204 * - GENERAL_FAILURE if there is an unspecified error
205 * - INVALID_ARGUMENT if the input model is invalid
206 */
207 ErrorStatus getStatus();
208
209 /**
210 * Retrieves the model that has been prepared for execution from the
211 * asynchronous task launched by IDevice::prepareModel. If
212 * IDevice::prepareModel has not finished asynchronously preparing the
213 * model, this call will block until the asynchronous task notifies the
214 * object.
215 *
216 * @return preparedModel Returned model that has been prepared for
217 * execution, nullptr if the model was unable to be
218 * prepared.
219 */
220 sp<IPreparedModel> getPreparedModel();
221
222 private:
223 ErrorStatus mErrorStatus;
224 sp<IPreparedModel> mPreparedModel;
225 };
226
227 /**
228 * The ExecutionCallback class is used to receive the error status of the
229 * execution from a task executing asynchronously with respect to the runtime.
230 * If a calling thread calls wait* or get* on a PreparedModelCallback object and
231 * the corresponding asynchronous task has not finished the execution, the
232 * calling thread will block until the asynchronous task has called notify. For
233 * more information on the synchronization behavior, refer to the CallbackBase
234 * class.
235 *
236 * This class inherits the basic blocking and signaling calls from
237 * CallbackBase, and implements the HIDL notify call from
238 * IExecutionCallback. This callback object is passed as an argument to
239 * IPreparedModel::execute.
240 */
241 class ExecutionCallback : public CallbackBase, public IExecutionCallback {
242 public:
243 ExecutionCallback();
244 ~ExecutionCallback() override;
245
246 /**
247 * IExecutionCallback::notify marks the callback object with the return
248 * status of the asynchronous execution that held this callback and enables
249 * all prior and future wait* calls on the ExecutionCallback object to
250 * proceed. For more information on the synchronization behavior, refer to
251 * the CallbackBase class.
252 *
253 * IExecutionCallback::notify must be called exactly once on a given
254 * ExecutionCallback object.
255 *
256 * @param status Error status returned from asynchronously preparing the
257 * model; will be:
258 * - NONE if the asynchronous execution was successful
259 * - DEVICE_UNAVAILABLE if driver is offline or busy
260 * - GENERAL_FAILURE if there is an unspecified error
261 * - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
262 * not large enough to store the resultant values
263 * - INVALID_ARGUMENT if the input request is invalid
264 */
265 Return<void> notify(ErrorStatus status) override;
266
267 /**
268 * Retrieves the error status returned from the asynchronous task launched
269 * by IPreparedModel::execute. If IPreparedModel::execute has not finished
270 * asynchronously executing, this call will block until the asynchronous task
271 * notifies the object.
272 *
273 * @return status Error status returned from asynchronously preparing the
274 * model; will be:
275 * - NONE if the asynchronous execution was successful
276 * - DEVICE_UNAVAILABLE if driver is offline or busy
277 * - GENERAL_FAILURE if there is an unspecified error
278 * - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
279 * not large enough to store the resultant values
280 * - INVALID_ARGUMENT if the input request is invalid
281 */
282 ErrorStatus getStatus();
283
284 private:
285 ErrorStatus mErrorStatus;
286 };
287
288
289 // template function implementation(s) below this point
290
291 template<class Rep, class Period>
wait_for(const std::chrono::duration<Rep,Period> & timeout_duration)292 std::cv_status CallbackBase::wait_for(const std::chrono::duration<Rep,Period>& timeout_duration) {
293 std::unique_lock<std::mutex> lock(mMutex);
294 std::cv_status status = mCondition.wait_for(lock, timeout_duration, [this]{return mNotified;});
295 if (status != std::cv_status::timeout) {
296 join_thread_locked();
297 }
298 return status;
299 }
300
301 } // namespace implementation
302 } // namespace V1_0
303 } // namespace neuralnetworks
304 } // namespace hardware
305 } // namespace android
306
307 #endif // ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
308