1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
18 #define ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
19 
20 #include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
21 #include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
22 #include <chrono>
23 #include <condition_variable>
24 #include <functional>
25 #include <hidl/MQDescriptor.h>
26 #include <hidl/Status.h>
27 #include <mutex>
28 #include <thread>
29 
30 namespace android {
31 namespace hardware {
32 namespace neuralnetworks {
33 namespace V1_0 {
34 namespace implementation {
35 
36 using ::android::hardware::hidl_array;
37 using ::android::hardware::hidl_memory;
38 using ::android::hardware::hidl_string;
39 using ::android::hardware::hidl_vec;
40 using ::android::hardware::Return;
41 using ::android::hardware::Void;
42 using ::android::sp;
43 
/**
 * The CallbackBase class is used internally by the NeuralNetworks runtime to
 * synchronize between different threads. An asynchronous task is launched
 * paired with a callback object. When a client thread requires the output being
 * generated by the asynchronous task, the client thread can wait for the result
 * and be blocked until it has completed or a timeout condition has been
 * reached. Any wait* may safely be called concurrently, even on the same
 * callback object. When the asynchronous task has finished its workload, it
 * must immediately call "notify". If the asynchronous task has failed to launch,
 * the function that tried to launch the asynchronous task must immediately call
 * "notify". This "notify" call awakens any client threads waiting on the
 * callback object.
 *
 * The CallbackBase class implements some of the base synchronization common to
 * both PreparedModelCallback and ExecutionCallback. For consistency, any HIDL
 * callback class must inherit from CallbackBase as well as the HIDL callback
 * interface it implements.
 *
 * This class exists to enable synchronization across HIDL. When synchronization
 * is only required in the same process, consider using std::future, std::mutex,
 * std::condition_variable, or std::experimental::latch instead.
 *
 * NOTE(review): the destructor is not virtual, so a derived object must never be
 * deleted through a CallbackBase pointer. Derived classes are presumably owned
 * through their HIDL interface (sp<>) instead -- confirm before adding new
 * ownership paths.
 */
class CallbackBase {
 public:
    CallbackBase();
    ~CallbackBase();

    /**
     * CallbackBase::wait blocks until notify has been called on the callback
     * object.
     */
    void wait();

    /**
     * CallbackBase::wait_for blocks until notify has been called on the
     * callback object or the time duration from the time the wait_for function
     * was called has expired, whichever comes first.
     *
     * When std::cv_status::no_timeout is returned, the thread bound with
     * CallbackBase::bind_thread (if any) has also been joined.
     *
     * @return Status std::cv_status::no_timeout if the callback was notified
     *                before the time duration expired, std::cv_status::timeout
     *                otherwise.
     */
    template<class Rep, class Period>
    std::cv_status wait_for(const std::chrono::duration<Rep,Period>& timeout_duration);

    /**
     * CallbackBase::on_finish binds a function to the callback object. This
     * bound function will be executed when CallbackBase::notify is called,
     * before any calls to wait* return. (Note that CallbackBase::wait_for can
     * return std::cv_status::timeout before CallbackBase::notify is called for
     * the first time, and hence before the bound function is executed.)
     *
     * The bound function must not synchronize with or otherwise access the
     * callback object it is bound to, as this could cause a deadlock.
     *
     * CallbackBase::on_finish can be called at most once on a given callback
     * object, and the call to CallbackBase::on_finish must finish before
     * CallbackBase::notify is called.
     *
     * @param post_work Function to be invoked the first time
     *                  CallbackBase::notify is called. Must have a target --
     *                  i.e., must not compare equal to nullptr. post_work
     *                  returns true if it successfully completes, false if it
     *                  fails.
     * @return bool True if the function was successfully bound, false if
     *              unsuccessful.
     *
     * TODO: Why does the return value of the callback matter?
     */
    bool on_finish(std::function<bool(void)> post_work);

    /**
     * CallbackBase::bind_thread binds a thread to the callback object for
     * later use by CallbackBase::join_thread.
     *
     * The thread must be passed using std::move.
     *
     * Once a thread is bound with CallbackBase::bind_thread, the client code
     * should ensure that one of the following occurs before the callback
     * object is destroyed:
     * - CallbackBase::join_thread has been called.
     * - CallbackBase::wait has been called.
     * - CallbackBase::wait_for has been called and returned
     *   std::cv_status::no_timeout (a std::cv_status::timeout result does
     *   NOT join the thread).
     *
     * The bound thread shall not call any CallbackBase method with the
     * exception of CallbackBase::notify, which it must call when the thread has
     * finished its computation.
     *
     * CallbackBase::bind_thread can be called at most once on a given callback
     * object.
     *
     * @param asyncThread Thread to be bound to the callback object. The thread
     *                    object must represent a thread of execution -- i.e.,
     *                    asyncThread.joinable() must be true.
     * @return bool True if successful, false if thread was not properly bound.
     */
    bool bind_thread(std::thread&& asyncThread);

    /**
     * CallbackBase::join_thread ensures that the thread (if any) bound to this
     * callback object with CallbackBase::bind_thread has fully finished and
     * cleaned its resources. It is legal to call this function multiple times,
     * concurrently or sequentially.
     */
    void join_thread();

 protected:
    /**
     * CallbackBase::notify enables all prior and future wait* calls on the
     * callback object to proceed. The call to CallbackBase::notify happens
     * before any wait* calls on this callback object return (except in the case
     * of wait_for timing out). The asynchronous call the callback object is
     * paired with must ensure that any update to state that should be visible
     * to the caller of wait* happens before the call to CallbackBase::notify.
     *
     * CallbackBase::notify must be called exactly once on a given callback
     * object.
     */
    void notify();

 private:
    // Same as CallbackBase::join_thread but assumes we already hold a lock on
    // mMutex.
    void join_thread_locked();

    // Condition checked by the wait_for predicate; presumably set by notify().
    bool                      mNotified;
    // Guards the state below; held while waiting on mCondition.
    std::mutex                mMutex;
    std::condition_variable   mCondition;
    // Presumably the function bound by on_finish -- confirm in the .cpp.
    std::function<bool(void)> mPostWork;
    // Presumably the thread bound by bind_thread -- confirm in the .cpp.
    std::thread               mThread;
};
176 
177 /**
178  * The PreparedModelCallback class is used to receive the error status of
179  * preparing a model as well as the prepared model from a task executing
180  * asynchronously with respect to the runtime. If a calling thread calls wait*
181  * or get* on a PreparedModelCallback object and the corresponding asynchronous
182  * task has not finished preparing the model, the calling thread will block
183  * until the asynchronous task has called notify. For more information on the
184  * synchronization behavior, refer to the CallbackBase class.
185  *
186  * This class inherits the basic blocking and signaling calls from
187  * CallbackBase, and implements the HIDL notify call from
188  * IPreparedModelCallback. This callback object is passed as an argument to
189  * IDevice::prepareModel.
190  */
191 class PreparedModelCallback : public CallbackBase, public IPreparedModelCallback {
192  public:
193     PreparedModelCallback();
194     ~PreparedModelCallback() override;
195 
196     /**
197      * IPreparedModelCallback::notify marks the callback object with the return
198      * status of the asynchronous model preparation along with the prepared
199      * model, and calls CallbackBase::notify, enabling all prior and future
200      * wait* calls on the PreparedModelCallback object to proceed. For more
201      * information on the synchronization behavior, refer to the CallbackBase
202      * class.
203      *
204      * IPreparedModelCallback::notify must be called exactly once on a given
205      * PreparedModelCallback object.
206      *
207      * @param status Error status returned from asynchronously preparing the
208      *               model; will be:
209      *               - NONE if the asynchronous preparation was successful
210      *               - DEVICE_UNAVAILABLE if driver is offline or busy
211      *               - GENERAL_FAILURE if there is an unspecified error
212      *               - INVALID_ARGUMENT if the input model is invalid
213      * @param preparedModel Returned model that has been prepared for execution,
214      *                      nullptr if the model was unable to be prepared.
215      */
216     Return<void> notify(ErrorStatus status, const sp<IPreparedModel>& preparedModel) override;
217 
218     /**
219      * Retrieves the error status returned from the asynchronous task launched
220      * by IDevice::prepareModel. If IDevice::prepareModel has not finished
221      * asynchronously preparing the model, this call will block until the
222      * asynchronous task notifies the object.
223      *
224      * @return status Error status returned from asynchronously preparing the
225      *                model; will be:
226      *                - NONE if the asynchronous preparation was successful
227      *                - DEVICE_UNAVAILABLE if driver is offline or busy
228      *                - GENERAL_FAILURE if there is an unspecified error
229      *                - INVALID_ARGUMENT if the input model is invalid
230      */
231     ErrorStatus getStatus();
232 
233     /**
234      * Retrieves the model that has been prepared for execution from the
235      * asynchronous task launched by IDevice::prepareModel. If
236      * IDevice::prepareModel has not finished asynchronously preparing the
237      * model, this call will block until the asynchronous task notifies the
238      * object.
239      *
240      * @return preparedModel Returned model that has been prepared for
241      *                       execution, nullptr if the model was unable to be
242      *                       prepared.
243      */
244     sp<IPreparedModel> getPreparedModel();
245 
246  private:
247     ErrorStatus        mErrorStatus;
248     sp<IPreparedModel> mPreparedModel;
249 };
250 
251 /**
252  * The ExecutionCallback class is used to receive the error status of the
253  * execution from a task executing asynchronously with respect to the runtime.
254  * If a calling thread calls wait* or get* on a PreparedModelCallback object and
255  * the corresponding asynchronous task has not finished the execution, the
256  * calling thread will block until the asynchronous task has called notify. For
257  * more information on the synchronization behavior, refer to the CallbackBase
258  * class.
259  *
260  * This class inherits the basic blocking and signaling calls from
261  * CallbackBase, and implements the HIDL notify call from
262  * IExecutionCallback. This callback object is passed as an argument to
263  * IPreparedModel::execute.
264  */
265 class ExecutionCallback : public CallbackBase,  public IExecutionCallback {
266  public:
267     ExecutionCallback();
268     ~ExecutionCallback() override;
269 
270     /**
271      * IExecutionCallback::notify marks the callback object with the return
272      * status of the asynchronous execution that held this callback and enables
273      * all prior and future wait* calls on the ExecutionCallback object to
274      * proceed. For more information on the synchronization behavior, refer to
275      * the CallbackBase class.
276      *
277      * IExecutionCallback::notify must be called exactly once on a given
278      * ExecutionCallback object.
279      *
280      * @param status Error status returned from asynchronously preparing the
281      *               model; will be:
282      *               - NONE if the asynchronous execution was successful
283      *               - DEVICE_UNAVAILABLE if driver is offline or busy
284      *               - GENERAL_FAILURE if there is an unspecified error
285      *               - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
286      *                 not large enough to store the resultant values
287      *               - INVALID_ARGUMENT if the input request is invalid
288      */
289     Return<void> notify(ErrorStatus status) override;
290 
291     /**
292      * Retrieves the error status returned from the asynchronous task launched
293      * by IPreparedModel::execute. If IPreparedModel::execute has not finished
294      * asynchronously executing, this call will block until the asynchronous task
295      * notifies the object.
296      *
297      * @return status Error status returned from asynchronously preparing the
298      *                model; will be:
299      *                - NONE if the asynchronous execution was successful
300      *                - DEVICE_UNAVAILABLE if driver is offline or busy
301      *                - GENERAL_FAILURE if there is an unspecified error
302      *                - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
303      *                  not large enough to store the resultant values
304      *                - INVALID_ARGUMENT if the input request is invalid
305      */
306     ErrorStatus getStatus();
307 
308  private:
309     ErrorStatus mErrorStatus;
310 };
311 
312 
313 // template function implementation(s) below this point
314 
315 template<class Rep, class Period>
wait_for(const std::chrono::duration<Rep,Period> & timeout_duration)316 std::cv_status CallbackBase::wait_for(const std::chrono::duration<Rep,Period>& timeout_duration) {
317     std::unique_lock<std::mutex> lock(mMutex);
318     std::cv_status status = mCondition.wait_for(lock, timeout_duration, [this]{return mNotified;});
319     if (status != std::cv_status::timeout) {
320         join_thread_locked();
321     }
322     return status;
323 }
324 
325 }  // namespace implementation
326 }  // namespace V1_0
327 }  // namespace neuralnetworks
328 }  // namespace hardware
329 }  // namespace android
330 
331 #endif  // ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
332