1 #ifndef ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
2 #define ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
3 
4 #include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
5 #include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
6 #include <chrono>
7 #include <condition_variable>
8 #include <functional>
9 #include <hidl/MQDescriptor.h>
10 #include <hidl/Status.h>
11 #include <mutex>
12 #include <thread>
13 
14 namespace android {
15 namespace hardware {
16 namespace neuralnetworks {
17 namespace V1_0 {
18 namespace implementation {
19 
/**
 * CallbackBase is used internally by the NeuralNetworks runtime to synchronize
 * between threads. An asynchronous task is launched together with a callback
 * object. A client thread that needs the output of the asynchronous task can
 * wait on the callback object and will block until the task has completed or
 * a timeout has been reached. Any wait* method may safely be called
 * concurrently, even on the same callback object.
 *
 * When the asynchronous task has finished its workload it must immediately
 * call "notify". If the asynchronous task failed to launch, the function that
 * tried to launch it must immediately call "notify" instead. The "notify" call
 * awakens every client thread waiting on the callback object.
 *
 * CallbackBase implements the synchronization that is common to
 * PreparedModelCallback and ExecutionCallback. For consistency, any HIDL
 * callback class must inherit from CallbackBase as well as from the HIDL
 * callback interface it implements.
 *
 * This class exists to enable synchronization across HIDL. When
 * synchronization is only required within a single process, consider using
 * std::future, std::mutex, std::condition_variable, or
 * std::experimental::latch instead.
 */
class CallbackBase {
 public:
    CallbackBase();
    ~CallbackBase();

    /**
     * CallbackBase::wait blocks until notify has been called on the callback
     * object.
     */
    void wait();

    /**
     * CallbackBase::wait_for blocks until notify has been called on the
     * callback object or until timeout_duration (measured from the moment
     * wait_for was called) has elapsed, whichever comes first.
     *
     * @param timeout_duration Maximum amount of time to block.
     * @return std::cv_status::no_timeout if the callback was notified before
     *         the time duration expired, std::cv_status::timeout otherwise.
     */
    template<class Rep, class Period>
    std::cv_status wait_for(const std::chrono::duration<Rep,Period>& timeout_duration);

    /**
     * CallbackBase::on_finish binds a function to the callback object. The
     * bound function is executed when CallbackBase::notify is called, before
     * any call to wait* returns. (Note that CallbackBase::wait_for can return
     * std::cv_status::timeout before CallbackBase::notify is called for the
     * first time, and hence before the bound function is executed.)
     *
     * The bound function must not synchronize with or otherwise access the
     * callback object it is bound to, as this could cause a deadlock.
     *
     * CallbackBase::on_finish can be called at most once on a given callback
     * object, and the call to CallbackBase::on_finish must finish before
     * CallbackBase::notify is called.
     *
     * @param post_work Function to be invoked the first time
     *                  CallbackBase::notify is called. Must have a target --
     *                  i.e., must not compare equal to nullptr. post_work
     *                  returns true if it completes successfully, false if it
     *                  fails.
     * @return bool True if the function was successfully bound, false
     *              otherwise.
     *
     * TODO: Why does the return value of the callback matter?
     */
    bool on_finish(std::function<bool(void)> post_work);

    /**
     * CallbackBase::bind_thread binds a thread to the callback object for
     * later use by CallbackBase::join_thread.
     *
     * The thread must be passed using std::move.
     *
     * Once a thread is bound with CallbackBase::bind_thread, the client code
     * should ensure that one of the following occurs before the callback
     * object is destroyed:
     * - CallbackBase::join_thread has been called.
     * - CallbackBase::wait has been called.
     * - CallbackBase::wait_for has been called and returned a value other
     *   than std::cv_status::timeout (i.e. the callback was notified before
     *   the wait expired; a timed-out wait_for does not join the thread).
     *
     * The bound thread shall not call any CallbackBase method with the
     * exception of CallbackBase::notify, which it must call when the thread
     * has finished its computation.
     *
     * CallbackBase::bind_thread can be called at most once on a given callback
     * object.
     *
     * @param asyncThread Thread to be bound to the callback object. The thread
     *                    object must represent a thread of execution -- i.e.,
     *                    asyncThread.joinable() must be true.
     * @return bool True if successful, false if the thread was not properly
     *              bound.
     */
    bool bind_thread(std::thread&& asyncThread);

    /**
     * CallbackBase::join_thread ensures that the thread (if any) bound to this
     * callback object with CallbackBase::bind_thread has fully finished and
     * cleaned up its resources. It is legal to call this function multiple
     * times, concurrently or sequentially.
     */
    void join_thread();

 protected:
    /**
     * CallbackBase::notify enables all prior and future wait* calls on the
     * callback object to proceed. The call to CallbackBase::notify happens
     * before any wait* call on this callback object returns (except in the
     * case of wait_for timing out). The asynchronous call the callback object
     * is paired with must ensure that any update to state that should be
     * visible to the caller of wait* happens before the call to
     * CallbackBase::notify.
     *
     * CallbackBase::notify must be called exactly once on a given callback
     * object.
     */
    void notify();

 private:
    // Same as CallbackBase::join_thread but assumes the caller already holds
    // a lock on mMutex.
    void join_thread_locked();

    // Flag tested by the wait predicate in wait_for.
    bool                      mNotified;
    // Mutex paired with mCondition; wait_for locks it before waiting.
    std::mutex                mMutex;
    // Condition variable that wait_for blocks on.
    std::condition_variable   mCondition;
    // Same type as the function handed to on_finish.
    std::function<bool(void)> mPostWork;
    // Same type as the thread handed to bind_thread.
    std::thread               mThread;
};
152 
153 /**
154  * The PreparedModelCallback class is used to receive the error status of
155  * preparing a model as well as the prepared model from a task executing
156  * asynchronously with respect to the runtime. If a calling thread calls wait*
157  * or get* on a PreparedModelCallback object and the corresponding asynchronous
158  * task has not finished preparing the model, the calling thread will block
159  * until the asynchronous task has called notify. For more information on the
160  * synchronization behavior, refer to the CallbackBase class.
161  *
162  * This class inherits the basic blocking and signaling calls from
163  * CallbackBase, and implements the HIDL notify call from
164  * IPreparedModelCallback. This callback object is passed as an argument to
165  * IDevice::prepareModel.
166  */
167 class PreparedModelCallback : public CallbackBase, public IPreparedModelCallback {
168  public:
169     PreparedModelCallback();
170     ~PreparedModelCallback() override;
171 
172     /**
173      * IPreparedModelCallback::notify marks the callback object with the return
174      * status of the asynchronous model preparation along with the prepared
175      * model, and calls CallbackBase::notify, enabling all prior and future
176      * wait* calls on the PreparedModelCallback object to proceed. For more
177      * information on the synchronization behavior, refer to the CallbackBase
178      * class.
179      *
180      * IPreparedModelCallback::notify must be called exactly once on a given
181      * PreparedModelCallback object.
182      *
183      * @param status Error status returned from asynchronously preparing the
184      *               model; will be:
185      *               - NONE if the asynchronous preparation was successful
186      *               - DEVICE_UNAVAILABLE if driver is offline or busy
187      *               - GENERAL_FAILURE if there is an unspecified error
188      *               - INVALID_ARGUMENT if the input model is invalid
189      * @param preparedModel Returned model that has been prepared for execution,
190      *                      nullptr if the model was unable to be prepared.
191      */
192     Return<void> notify(ErrorStatus status, const sp<IPreparedModel>& preparedModel) override;
193 
194     /**
195      * Retrieves the error status returned from the asynchronous task launched
196      * by IDevice::prepareModel. If IDevice::prepareModel has not finished
197      * asynchronously preparing the model, this call will block until the
198      * asynchronous task notifies the object.
199      *
200      * @return status Error status returned from asynchronously preparing the
201      *                model; will be:
202      *                - NONE if the asynchronous preparation was successful
203      *                - DEVICE_UNAVAILABLE if driver is offline or busy
204      *                - GENERAL_FAILURE if there is an unspecified error
205      *                - INVALID_ARGUMENT if the input model is invalid
206      */
207     ErrorStatus getStatus();
208 
209     /**
210      * Retrieves the model that has been prepared for execution from the
211      * asynchronous task launched by IDevice::prepareModel. If
212      * IDevice::prepareModel has not finished asynchronously preparing the
213      * model, this call will block until the asynchronous task notifies the
214      * object.
215      *
216      * @return preparedModel Returned model that has been prepared for
217      *                       execution, nullptr if the model was unable to be
218      *                       prepared.
219      */
220     sp<IPreparedModel> getPreparedModel();
221 
222  private:
223     ErrorStatus        mErrorStatus;
224     sp<IPreparedModel> mPreparedModel;
225 };
226 
227 /**
228  * The ExecutionCallback class is used to receive the error status of the
229  * execution from a task executing asynchronously with respect to the runtime.
230  * If a calling thread calls wait* or get* on a PreparedModelCallback object and
231  * the corresponding asynchronous task has not finished the execution, the
232  * calling thread will block until the asynchronous task has called notify. For
233  * more information on the synchronization behavior, refer to the CallbackBase
234  * class.
235  *
236  * This class inherits the basic blocking and signaling calls from
237  * CallbackBase, and implements the HIDL notify call from
238  * IExecutionCallback. This callback object is passed as an argument to
239  * IPreparedModel::execute.
240  */
241 class ExecutionCallback : public CallbackBase,  public IExecutionCallback {
242  public:
243     ExecutionCallback();
244     ~ExecutionCallback() override;
245 
246     /**
247      * IExecutionCallback::notify marks the callback object with the return
248      * status of the asynchronous execution that held this callback and enables
249      * all prior and future wait* calls on the ExecutionCallback object to
250      * proceed. For more information on the synchronization behavior, refer to
251      * the CallbackBase class.
252      *
253      * IExecutionCallback::notify must be called exactly once on a given
254      * ExecutionCallback object.
255      *
256      * @param status Error status returned from asynchronously preparing the
257      *               model; will be:
258      *               - NONE if the asynchronous execution was successful
259      *               - DEVICE_UNAVAILABLE if driver is offline or busy
260      *               - GENERAL_FAILURE if there is an unspecified error
261      *               - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
262      *                 not large enough to store the resultant values
263      *               - INVALID_ARGUMENT if the input request is invalid
264      */
265     Return<void> notify(ErrorStatus status) override;
266 
267     /**
268      * Retrieves the error status returned from the asynchronous task launched
269      * by IPreparedModel::execute. If IPreparedModel::execute has not finished
270      * asynchronously executing, this call will block until the asynchronous task
271      * notifies the object.
272      *
273      * @return status Error status returned from asynchronously preparing the
274      *                model; will be:
275      *                - NONE if the asynchronous execution was successful
276      *                - DEVICE_UNAVAILABLE if driver is offline or busy
277      *                - GENERAL_FAILURE if there is an unspecified error
278      *                - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
279      *                  not large enough to store the resultant values
280      *                - INVALID_ARGUMENT if the input request is invalid
281      */
282     ErrorStatus getStatus();
283 
284  private:
285     ErrorStatus mErrorStatus;
286 };
287 
288 
289 // template function implementation(s) below this point
290 
291 template<class Rep, class Period>
wait_for(const std::chrono::duration<Rep,Period> & timeout_duration)292 std::cv_status CallbackBase::wait_for(const std::chrono::duration<Rep,Period>& timeout_duration) {
293     std::unique_lock<std::mutex> lock(mMutex);
294     std::cv_status status = mCondition.wait_for(lock, timeout_duration, [this]{return mNotified;});
295     if (status != std::cv_status::timeout) {
296         join_thread_locked();
297     }
298     return status;
299 }
300 
301 }  // namespace implementation
302 }  // namespace V1_0
303 }  // namespace neuralnetworks
304 }  // namespace hardware
305 }  // namespace android
306 
307 #endif  // ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
308