1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ExecutionCallback"
18 
19 #include "ExecutionCallback.h"
20 
21 #include <android-base/logging.h>
22 
23 #include <limits>
24 #include <utility>
25 #include <vector>
26 
27 namespace android::nn {
28 
notify(ErrorStatus status,const std::vector<OutputShape> & outputShapes,const Timing & timing)29 void ExecutionCallback::notify(ErrorStatus status, const std::vector<OutputShape>& outputShapes,
30                                const Timing& timing) {
31     notifyInternal(status, outputShapes, timing);
32 }
33 
wait() const34 void ExecutionCallback::wait() const {
35     std::unique_lock<std::mutex> lock(mMutex);
36     mCondition.wait(lock, [this] { return mNotified; });
37 
38     /*
39      * Note that we cannot call std::thread::join from ExecutionCallback's
40      * destructor: ExecutionCallback is intended to be reference counted, and it
41      * is possible that the reference count drops to zero in the bound thread,
42      * causing the bound thread to call this destructor. If a thread tries to
43      * join itself, it throws an exception, producing a message like the
44      * following:
45      *
46      *     terminating with uncaught exception of type std::__1::system_error:
47      *     thread::join failed: Resource deadlock would occur
48      */
49     if (mThread.joinable()) {
50         mThread.join();
51     }
52 }
53 
getStatus() const54 ErrorStatus ExecutionCallback::getStatus() const {
55     wait();
56     return mErrorStatus;
57 }
58 
getOutputShapes() const59 const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() const {
60     wait();
61     return mOutputShapes;
62 }
63 
getTiming() const64 Timing ExecutionCallback::getTiming() const {
65     wait();
66     return mTiming;
67 }
68 
bindThread(std::thread asyncThread)69 bool ExecutionCallback::bindThread(std::thread asyncThread) {
70     std::lock_guard<std::mutex> lock(mMutex);
71 
72     // Ensure ExecutionCallback object does not already have a thread bound
73     if (mThread.joinable()) {
74         LOG(ERROR) << "ExecutionCallback::bindThread -- a thread has already been bound to this "
75                       "callback object";
76         return false;
77     }
78 
79     // Ensure the new thread is valid
80     if (!asyncThread.joinable()) {
81         LOG(ERROR) << "ExecutionCallback::bindThread -- the new thread is not joinable";
82         return false;
83     }
84 
85     mThread = std::move(asyncThread);
86     return true;
87 }
88 
setOnFinish(const ExecutionFinish & finish)89 void ExecutionCallback::setOnFinish(const ExecutionFinish& finish) {
90     std::lock_guard<std::mutex> hold(mMutex);
91 
92     // Ensure ExecutionCallback object does not already have a "finish" callback
93     if (mOnFinish != nullptr) {
94         LOG(ERROR) << "ExecutionCallback::setOnFinish -- object already has a \"finish\" callback";
95         return;
96     }
97 
98     // Ensure new "finish" callback is valid
99     if (finish == nullptr) {
100         LOG(ERROR) << "ExecutionCallback::setOnFinish -- \"finish\" callback is invalid";
101         return;
102     }
103 
104     // Essure ExecutionCallback object has not already been notified
105     if (mNotified) {
106         LOG(ERROR) << "ExecutionCallback::setOnFinish -- ExecutionCallback has already been "
107                       "notified with results";
108         return;
109     }
110 
111     mOnFinish = finish;
112 }
113 
notifyInternal(ErrorStatus errorStatus,std::vector<OutputShape> outputShapes,Timing timing)114 void ExecutionCallback::notifyInternal(ErrorStatus errorStatus,
115                                        std::vector<OutputShape> outputShapes, Timing timing) {
116     // check results
117     {
118         if (errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
119             // outputShapes must not be empty if OUTPUT_INSUFFICIENT_SIZE.
120             if (outputShapes.size() == 0) {
121                 LOG(ERROR)
122                         << "Notified with empty output shape vector when OUTPUT_INSUFFICIENT_SIZE";
123                 errorStatus = ErrorStatus::GENERAL_FAILURE;
124                 outputShapes = {};
125                 timing = {};
126             }
127         } else if (errorStatus != ErrorStatus::NONE) {
128             // outputShapes must be empty if errorStatus is neither NONE nor
129             // OUTPUT_INSUFFICIENT_SIZE.
130             if (outputShapes.size() != 0) {
131                 LOG(ERROR) << "Notified with non-empty output shape vector when error status is "
132                               "neither NONE nor OUTPUT_INSUFFICIENT_SIZE";
133                 errorStatus = ErrorStatus::GENERAL_FAILURE;
134                 outputShapes = {};
135                 timing = {};
136             }
137         }
138     }
139 
140     // store results
141     {
142         std::lock_guard<std::mutex> hold(mMutex);
143 
144         // quick-return if object has already been notified
145         if (mNotified) {
146             return;
147         }
148 
149         mErrorStatus = errorStatus;
150         mOutputShapes = std::move(outputShapes);
151         mTiming = timing;
152         mNotified = true;
153 
154         if (mOnFinish != nullptr) {
155             ErrorStatus status = mOnFinish(mErrorStatus, mOutputShapes);
156             mOnFinish = nullptr;
157             if (status != ErrorStatus::NONE) {
158                 mErrorStatus = status;
159             }
160         }
161     }
162     mCondition.notify_all();
163 }
164 
165 }  // namespace android::nn
166