/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Burst.h"

#include "Conversions.h"
#include "Utils.h"

#include <android-base/logging.h>
#include <android-base/scope_guard.h>
#include <android/binder_auto_utils.h>
#include <nnapi/IBurst.h>
#include <nnapi/IExecution.h>
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/hal/HandleError.h>

#include <limits>
#include <memory>
#include <mutex>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

namespace aidl::android::hardware::neuralnetworks::utils {
namespace {

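// Adapter that implements the canonical nn::IExecution interface on top of Burst. It stores the
// request in its already-converted AIDL form, together with the memory identifier tokens and the
// cache holds produced by Burst::createReusableExecution, so that repeated compute() calls can be
// forwarded to Burst::executeInternal without redoing conversion or memory caching.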
class BurstExecution final : public nn::IExecution,
                             public std::enable_shared_from_this<BurstExecution> {
    struct PrivateConstructorTag {};

  public:
    static nn::GeneralResult<std::shared_ptr<const BurstExecution>> create(
            std::shared_ptr<const Burst> burst, Request request,
            std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
            hal::utils::RequestRelocation relocation,
            std::vector<Burst::OptionalCacheHold> cacheHolds);

    BurstExecution(PrivateConstructorTag tag, std::shared_ptr<const Burst> burst, Request request,
                   std::vector<int64_t> memoryIdentifierTokens, bool measure,
                   int64_t loopTimeoutDuration, hal::utils::RequestRelocation relocation,
                   std::vector<Burst::OptionalCacheHold> cacheHolds);

    nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> compute(
            const nn::OptionalTimePoint& deadline) const override;

    nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> computeFenced(
            const std::vector<nn::SyncFence>& waitFor, const nn::OptionalTimePoint& deadline,
            const nn::OptionalDuration& timeoutDurationAfterFence) const override;

  private:
    const std::shared_ptr<const Burst> kBurst;
    const Request kRequest;
    const std::vector<int64_t> kMemoryIdentifierTokens;
    const bool kMeasure;
    const int64_t kLoopTimeoutDuration;
    const hal::utils::RequestRelocation kRelocation;
    const std::vector<Burst::OptionalCacheHold> kCacheHolds;
};

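// Converts the AIDL output shapes and timing returned by the driver into their canonical NNAPI
// equivalents, propagating any conversion failure through NN_TRY.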
nn::GeneralResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> convertExecutionResults(
        const std::vector<OutputShape>& outputShapes, const Timing& timing) {
    return std::make_pair(NN_TRY(nn::convert(outputShapes)), NN_TRY(nn::convert(timing)));
}

}  // namespace

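// MemoryCache maps each nn::SharedMemory object to the int64_t slot identifier that the remote
// IBurst uses to refer to it across executions. Entries are reference counted: callers keep the
// returned SharedCleanup alive for as long as they need the slot, and once the last holder drops
// it, the cleanup task calls tryFreeMemory, which asks the driver to release the memory resource
// for that identifier.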
Burst::MemoryCache::MemoryCache(std::shared_ptr<aidl_hal::IBurst> burst)
    : kBurst(std::move(burst)) {}

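// Returns the slot identifier for `memory`, assigning a fresh one if the memory is not currently
// cached, along with a hold that keeps the slot alive. Illustrative sketch of how the returned
// pair is consumed (it mirrors Burst::execute below; `memoryIdentifierTokens` and `holds` are the
// caller's local vectors, not part of this file):
//
//   auto [identifier, hold] = memoryCache->getOrCacheMemory(sharedMemory);
//   memoryIdentifierTokens.push_back(identifier);  // which cached slot the driver should use
//   holds.push_back(std::move(hold));              // keep the slot alive until the call finishes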
std::pair<int64_t, Burst::MemoryCache::SharedCleanup> Burst::MemoryCache::getOrCacheMemory(
        const nn::SharedMemory& memory) {
    std::lock_guard lock(mMutex);

    // Get the cache payload or create it (with default values) if it does not exist.
    auto& cachedPayload = mCache[memory];
    {
        const auto& [identifier, maybeCleaner] = cachedPayload;
        // If cache payload already exists, reuse it.
        if (auto cleaner = maybeCleaner.lock()) {
            return std::make_pair(identifier, std::move(cleaner));
        }
    }

    // If the code reaches this point, the cached payload either did not exist or expired prior to
    // this call.

    // Allocate a new identifier.
    CHECK_LT(mUnusedIdentifier, std::numeric_limits<int64_t>::max());
    const int64_t identifier = mUnusedIdentifier++;

    // Create reference-counted self-cleaning cache object.
    auto self = weak_from_this();
    Task cleanup = [memory, identifier, maybeMemoryCache = std::move(self)] {
        if (const auto memoryCache = maybeMemoryCache.lock()) {
            memoryCache->tryFreeMemory(memory, identifier);
        }
    };
    auto cleaner = std::make_shared<const Cleanup>(std::move(cleanup));

    // Store the result in the cache and return it.
    auto result = std::make_pair(identifier, std::move(cleaner));
    cachedPayload = result;
    return result;
}

std::optional<std::pair<int64_t, Burst::MemoryCache::SharedCleanup>>
Burst::MemoryCache::getMemoryIfAvailable(const nn::SharedMemory& memory) {
    std::lock_guard lock(mMutex);

    // Get the existing cached entry if it exists.
    const auto iter = mCache.find(memory);
    if (iter != mCache.end()) {
        const auto& [identifier, maybeCleaner] = iter->second;
        if (auto cleaner = maybeCleaner.lock()) {
            return std::make_pair(identifier, std::move(cleaner));
        }
    }

    // If the code reaches this point, the cached payload did not exist or was actively being
    // deleted.
    return std::nullopt;
}

void Burst::MemoryCache::tryFreeMemory(const nn::SharedMemory& memory, int64_t identifier) {
    {
        std::lock_guard guard(mMutex);
        // Remove the cached memory and payload if it is present but expired. Note that it may not
        // be present or may not be expired because another thread may have removed or cached the
        // same memory object before the current thread locked mMutex in tryFreeMemory.
        const auto iter = mCache.find(memory);
        if (iter != mCache.end()) {
            if (std::get<WeakCleanup>(iter->second).expired()) {
                mCache.erase(iter);
            }
        }
    }
    kBurst->releaseMemoryResource(identifier);
}

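// Minimal construction sketch, assuming an aidl_hal::IBurst obtained elsewhere (for example from
// IPreparedModel::configureExecutionBurst); `aidlBurst` is a placeholder name:
//
//   std::shared_ptr<aidl_hal::IBurst> aidlBurst = /* obtained from the driver */;
//   const std::shared_ptr<const Burst> burst = NN_TRY(Burst::create(std::move(aidlBurst)));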
nn::GeneralResult<std::shared_ptr<const Burst>> Burst::create(
        std::shared_ptr<aidl_hal::IBurst> burst) {
    if (burst == nullptr) {
        return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
               << "aidl_hal::utils::Burst::create must have non-null burst";
    }

    return std::make_shared<const Burst>(PrivateConstructorTag{}, std::move(burst));
}

Burst::Burst(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBurst> burst)
    : kBurst(std::move(burst)), kMemoryCache(std::make_shared<MemoryCache>(kBurst)) {
    CHECK(kBurst != nullptr);
}

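// Pins `memory` in the driver-side cache for as long as the returned hold is kept alive, so that
// later executions referencing the same nn::SharedMemory can pass its cached slot identifier
// instead of re-sending the memory over IPC.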
Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& memory) const {
    auto [identifier, hold] = kMemoryCache->getOrCacheMemory(memory);
    return hold;
}

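// Converts the canonical request to its AIDL representation and builds one memory identifier
// token per request pool: the cached slot identifier when the pool's memory is already known to
// the driver, or -1 when it is not. The matching cache holds are kept on the stack so the slots
// remain valid for the duration of the synchronous call.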
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
        const nn::Request& request, nn::MeasureTiming measure,
        const nn::OptionalTimePoint& deadline,
        const nn::OptionalDuration& loopTimeoutDuration) const {
    // Ensure that request is ready for IPC.
    std::optional<nn::Request> maybeRequestInShared;
    hal::utils::RequestRelocation relocation;
    const nn::Request& requestInShared =
            NN_TRY(hal::utils::makeExecutionFailure(hal::utils::convertRequestFromPointerToShared(
                    &request, nn::kDefaultRequestMemoryAlignment, nn::kDefaultRequestMemoryPadding,
                    &maybeRequestInShared, &relocation)));

    const auto aidlRequest = NN_TRY(hal::utils::makeExecutionFailure(convert(requestInShared)));
    const auto aidlMeasure = NN_TRY(hal::utils::makeExecutionFailure(convert(measure)));
    const auto aidlDeadline = NN_TRY(hal::utils::makeExecutionFailure(convert(deadline)));
    const auto aidlLoopTimeoutDuration =
            NN_TRY(hal::utils::makeExecutionFailure(convert(loopTimeoutDuration)));

    std::vector<int64_t> memoryIdentifierTokens;
    std::vector<OptionalCacheHold> holds;
    memoryIdentifierTokens.reserve(requestInShared.pools.size());
    holds.reserve(requestInShared.pools.size());
    for (const auto& memoryPool : requestInShared.pools) {
        if (const auto* memory = std::get_if<nn::SharedMemory>(&memoryPool)) {
            if (auto cached = kMemoryCache->getMemoryIfAvailable(*memory)) {
                auto& [identifier, hold] = *cached;
                memoryIdentifierTokens.push_back(identifier);
                holds.push_back(std::move(hold));
                continue;
            }
        }
        memoryIdentifierTokens.push_back(-1);
    }
    CHECK_EQ(requestInShared.pools.size(), memoryIdentifierTokens.size());

    return executeInternal(aidlRequest, memoryIdentifierTokens, aidlMeasure, aidlDeadline,
                           aidlLoopTimeoutDuration, relocation);
}

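// Issues the actual executeSynchronously call. At most one execution may be in flight per burst,
// so the atomic flag below rejects concurrent callers. The relocation flushes copy caller-owned
// input memory into the shared pools before the IPC and copy the outputs back afterwards.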
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::executeInternal(
        const Request& request, const std::vector<int64_t>& memoryIdentifierTokens, bool measure,
        int64_t deadline, int64_t loopTimeoutDuration,
        const hal::utils::RequestRelocation& relocation) const {
    // Ensure that at most one execution is in flight at any given time.
    const bool alreadyInFlight = mExecutionInFlight.test_and_set();
    if (alreadyInFlight) {
        return NN_ERROR() << "IBurst already has an execution in flight";
    }
    const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    if (relocation.input) {
        relocation.input->flush();
    }

    ExecutionResult executionResult;
    const auto ret = kBurst->executeSynchronously(request, memoryIdentifierTokens, measure,
                                                  deadline, loopTimeoutDuration, &executionResult);
    HANDLE_ASTATUS(ret) << "execute failed";
    if (!executionResult.outputSufficientSize) {
        auto canonicalOutputShapes =
                nn::convert(executionResult.outputShapes).value_or(std::vector<nn::OutputShape>{});
        return NN_ERROR(nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, std::move(canonicalOutputShapes))
               << "execution failed with " << nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
    }
    auto [outputShapes, timing] = NN_TRY(hal::utils::makeExecutionFailure(
            convertExecutionResults(executionResult.outputShapes, executionResult.timing)));

    if (relocation.output) {
        relocation.output->flush();
    }
    return std::make_pair(std::move(outputShapes), timing);
}

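// Performs the request conversion and memory-cache lookups once and packages the results into a
// BurstExecution, so repeated compute() calls only need to convert the per-call deadline.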
nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
        const nn::Request& request, nn::MeasureTiming measure,
        const nn::OptionalDuration& loopTimeoutDuration) const {
    // Ensure that request is ready for IPC.
    std::optional<nn::Request> maybeRequestInShared;
    hal::utils::RequestRelocation relocation;
    const nn::Request& requestInShared = NN_TRY(hal::utils::convertRequestFromPointerToShared(
            &request, nn::kDefaultRequestMemoryAlignment, nn::kDefaultRequestMemoryPadding,
            &maybeRequestInShared, &relocation));

    auto aidlRequest = NN_TRY(convert(requestInShared));
    const auto aidlMeasure = NN_TRY(convert(measure));
    const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));

    std::vector<int64_t> memoryIdentifierTokens;
    std::vector<OptionalCacheHold> holds;
    memoryIdentifierTokens.reserve(requestInShared.pools.size());
    holds.reserve(requestInShared.pools.size());
    for (const auto& memoryPool : requestInShared.pools) {
        if (const auto* memory = std::get_if<nn::SharedMemory>(&memoryPool)) {
            if (auto cached = kMemoryCache->getMemoryIfAvailable(*memory)) {
                auto& [identifier, hold] = *cached;
                memoryIdentifierTokens.push_back(identifier);
                holds.push_back(std::move(hold));
                continue;
            }
        }
        memoryIdentifierTokens.push_back(-1);
    }
    CHECK_EQ(requestInShared.pools.size(), memoryIdentifierTokens.size());

    return BurstExecution::create(shared_from_this(), std::move(aidlRequest),
                                  std::move(memoryIdentifierTokens), aidlMeasure,
                                  aidlLoopTimeoutDuration, std::move(relocation), std::move(holds));
}

nn::GeneralResult<std::shared_ptr<const BurstExecution>> BurstExecution::create(
        std::shared_ptr<const Burst> burst, Request request,
        std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
        hal::utils::RequestRelocation relocation,
        std::vector<Burst::OptionalCacheHold> cacheHolds) {
    if (burst == nullptr) {
        return NN_ERROR() << "aidl::utils::BurstExecution::create must have non-null burst";
    }

    return std::make_shared<const BurstExecution>(
            PrivateConstructorTag{}, std::move(burst), std::move(request),
            std::move(memoryIdentifierTokens), measure, loopTimeoutDuration, std::move(relocation),
            std::move(cacheHolds));
}

BurstExecution::BurstExecution(PrivateConstructorTag /*tag*/, std::shared_ptr<const Burst> burst,
                               Request request, std::vector<int64_t> memoryIdentifierTokens,
                               bool measure, int64_t loopTimeoutDuration,
                               hal::utils::RequestRelocation relocation,
                               std::vector<Burst::OptionalCacheHold> cacheHolds)
    : kBurst(std::move(burst)),
      kRequest(std::move(request)),
      kMemoryIdentifierTokens(std::move(memoryIdentifierTokens)),
      kMeasure(measure),
      kLoopTimeoutDuration(loopTimeoutDuration),
      kRelocation(std::move(relocation)),
      kCacheHolds(std::move(cacheHolds)) {}

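// Illustrative reuse pattern for the execution object returned by createReusableExecution
// (`request`, `measure`, `loopTimeoutDuration`, and `kNumRuns` are placeholder names):
//
//   const auto execution =
//           NN_TRY(burst->createReusableExecution(request, measure, loopTimeoutDuration));
//   for (int i = 0; i < kNumRuns; ++i) {
//       const auto [outputShapes, timing] = NN_TRY(execution->compute(/*deadline=*/{}));
//   }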
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> BurstExecution::compute(
        const nn::OptionalTimePoint& deadline) const {
    const auto aidlDeadline = NN_TRY(hal::utils::makeExecutionFailure(convert(deadline)));
    return kBurst->executeInternal(kRequest, kMemoryIdentifierTokens, kMeasure, aidlDeadline,
                                   kLoopTimeoutDuration, kRelocation);
}

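// The AIDL burst protocol has no fenced execution path, so this adapter reports a general failure
// instead of forwarding the call.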
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
BurstExecution::computeFenced(const std::vector<nn::SyncFence>& /*waitFor*/,
                              const nn::OptionalTimePoint& /*deadline*/,
                              const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
    return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
           << "IExecution::computeFenced is not supported on burst object";
}

}  // namespace aidl::android::hardware::neuralnetworks::utils