1 #include "Callbacks.h"
2 #include <android-base/logging.h>
3 
4 namespace android {
5 namespace hardware {
6 namespace neuralnetworks {
7 namespace V1_0 {
8 namespace implementation {
9 
CallbackBase()10 CallbackBase::CallbackBase() : mNotified(false) {}
11 
~CallbackBase()12 CallbackBase::~CallbackBase() {
13     // Note that we cannot call CallbackBase::join_thread from here:
14     // CallbackBase is intended to be reference counted, and it is possible that
15     // the reference count drops to zero in the bound thread, causing the
16     // bound thread to call this destructor. If a thread tries to join
17     // itself, it throws an exception, producing a message like the
18     // following:
19     //
20     //     terminating with uncaught exception of type std::__1::system_error:
21     //     thread::join failed: Resource deadlock would occur
22 }
23 
wait()24 void CallbackBase::wait() {
25     std::unique_lock<std::mutex> lock(mMutex);
26     mCondition.wait(lock, [this]{return mNotified;});
27     join_thread_locked();
28 }
29 
on_finish(std::function<bool (void)> post_work)30 bool CallbackBase::on_finish(std::function<bool(void)> post_work) {
31     std::lock_guard<std::mutex> lock(mMutex);
32     if (mPostWork != nullptr) {
33         LOG(ERROR) << "CallbackBase::on_finish -- a post-work function has already been bound to "
34                    "this callback object";
35         return false;
36     }
37     if (post_work == nullptr) {
38         LOG(ERROR) << "CallbackBase::on_finish -- the new post-work function is invalid";
39         return false;
40     }
41     mPostWork = std::move(post_work);
42     return true;
43 }
44 
bind_thread(std::thread && asyncThread)45 bool CallbackBase::bind_thread(std::thread&& asyncThread) {
46     std::lock_guard<std::mutex> lock(mMutex);
47     if (mThread.joinable()) {
48         LOG(ERROR) << "CallbackBase::bind_thread -- a thread has already been bound to this "
49                    "callback object";
50         return false;
51     }
52     if (!asyncThread.joinable()) {
53         LOG(ERROR) << "CallbackBase::bind_thread -- the new thread is not joinable";
54         return false;
55     }
56     mThread = std::move(asyncThread);
57     return true;
58 }
59 
join_thread()60 void CallbackBase::join_thread() {
61     std::lock_guard<std::mutex> lock(mMutex);
62     join_thread_locked();
63 }
64 
notify()65 void CallbackBase::notify() {
66     {
67         std::lock_guard<std::mutex> lock(mMutex);
68         mNotified = true;
69         if (mPostWork != nullptr) {
70             bool success = mPostWork();
71             if (!success) {
72                 LOG(ERROR) << "CallbackBase::notify -- post work failed";
73             }
74         }
75     }
76     mCondition.notify_all();
77 }
78 
join_thread_locked()79 void CallbackBase::join_thread_locked() {
80     if (mThread.joinable()) {
81         mThread.join();
82     }
83 }
84 
PreparedModelCallback()85 PreparedModelCallback::PreparedModelCallback() :
86         mErrorStatus(ErrorStatus::GENERAL_FAILURE), mPreparedModel(nullptr) {}
87 
~PreparedModelCallback()88 PreparedModelCallback::~PreparedModelCallback() {}
89 
notify(ErrorStatus errorStatus,const sp<IPreparedModel> & preparedModel)90 Return<void> PreparedModelCallback::notify(ErrorStatus errorStatus,
91                                            const sp<IPreparedModel>& preparedModel) {
92     mErrorStatus = errorStatus;
93     mPreparedModel = preparedModel;
94     CallbackBase::notify();
95     return Void();
96 }
97 
getStatus()98 ErrorStatus PreparedModelCallback::getStatus() {
99     wait();
100     return mErrorStatus;
101 }
102 
getPreparedModel()103 sp<IPreparedModel> PreparedModelCallback::getPreparedModel() {
104     wait();
105     return mPreparedModel;
106 }
107 
ExecutionCallback()108 ExecutionCallback::ExecutionCallback() : mErrorStatus(ErrorStatus::GENERAL_FAILURE) {}
109 
~ExecutionCallback()110 ExecutionCallback::~ExecutionCallback() {}
111 
notify(ErrorStatus errorStatus)112 Return<void> ExecutionCallback::notify(ErrorStatus errorStatus) {
113     mErrorStatus = errorStatus;
114     CallbackBase::notify();
115     return Void();
116 }
117 
getStatus()118 ErrorStatus ExecutionCallback::getStatus() {
119     wait();
120     return mErrorStatus;
121 }
122 
123 }  // namespace implementation
124 }  // namespace V1_0
125 }  // namespace neuralnetworks
126 }  // namespace hardware
127 }  // namespace android
128