1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ExecutionBurstController"
18 
19 #include "ExecutionBurstController.h"
20 #include "ExecutionBurstUtils.h"
21 
22 #include <android-base/logging.h>
23 #include <android-base/thread_annotations.h>
24 #include <nnapi/IBurst.h>
25 #include <nnapi/IPreparedModel.h>
26 #include <nnapi/Result.h>
27 #include <nnapi/TypeUtils.h>
28 #include <nnapi/Types.h>
29 #include <nnapi/Validation.h>
30 #include <nnapi/hal/1.0/Conversions.h>
31 #include <nnapi/hal/CommonUtils.h>
32 #include <nnapi/hal/HandleError.h>
33 #include <nnapi/hal/ProtectCallback.h>
34 #include <nnapi/hal/TransferValue.h>
35 
36 #include <algorithm>
37 #include <cstring>
38 #include <limits>
39 #include <memory>
40 #include <string>
41 #include <thread>
42 #include <tuple>
43 #include <utility>
44 #include <vector>
45 
46 #include "Callbacks.h"
47 #include "Conversions.h"
48 #include "Tracing.h"
49 #include "Utils.h"
50 
51 namespace android::hardware::neuralnetworks::V1_2::utils {
52 namespace {
53 
// Reusable burst execution: a pre-serialized FMQ request packet bound to a
// specific ExecutionBurstController, so repeated computations skip request
// serialization and memory-slot lookup.
class BurstExecution final : public nn::IExecution,
                             public std::enable_shared_from_this<BurstExecution> {
    struct PrivateConstructorTag {};

  public:
    // Validates |controller| (must be non-null) and wraps the pre-serialized
    // |request| along with the relocation info and cache holds that keep the
    // request's memory slots alive for the lifetime of this object.
    static nn::GeneralResult<std::shared_ptr<const BurstExecution>> create(
            std::shared_ptr<const ExecutionBurstController> controller,
            std::vector<FmqRequestDatum> request, hal::utils::RequestRelocation relocation,
            std::vector<ExecutionBurstController::OptionalCacheHold> cacheHolds);

    BurstExecution(PrivateConstructorTag tag,
                   std::shared_ptr<const ExecutionBurstController> controller,
                   std::vector<FmqRequestDatum> request, hal::utils::RequestRelocation relocation,
                   std::vector<ExecutionBurstController::OptionalCacheHold> cacheHolds);

    // Runs the stored request on the controller. The deadline is ignored
    // because burst executions do not support deadlines.
    nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> compute(
            const nn::OptionalTimePoint& deadline) const override;

    // Always fails: fenced execution is not supported on burst objects.
    nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> computeFenced(
            const std::vector<nn::SyncFence>& waitFor, const nn::OptionalTimePoint& deadline,
            const nn::OptionalDuration& timeoutDurationAfterFence) const override;

  private:
    const std::shared_ptr<const ExecutionBurstController> kController;
    const std::vector<FmqRequestDatum> kRequest;
    const hal::utils::RequestRelocation kRelocation;
    // Holds that pin the request's memories in the controller's MemoryCache.
    const std::vector<ExecutionBurstController::OptionalCacheHold> kCacheHolds;
};
82 
83 nn::GeneralResult<sp<IBurstContext>> executionBurstResultCallback(
84         V1_0::ErrorStatus status, const sp<IBurstContext>& burstContext) {
85     HANDLE_HAL_STATUS(status) << "IPreparedModel::configureExecutionBurst failed with status "
86                               << toString(status);
87     if (burstContext == nullptr) {
88         return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
89                << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
90     }
91     return burstContext;
92 }
93 
94 nn::GeneralResult<hidl_vec<hidl_memory>> getMemoriesHelper(
95         const hidl_vec<int32_t>& slots,
96         const std::shared_ptr<ExecutionBurstController::MemoryCache>& memoryCache) {
97     hidl_vec<hidl_memory> memories(slots.size());
98     for (size_t i = 0; i < slots.size(); ++i) {
99         const int32_t slot = slots[i];
100         const auto memory = NN_TRY(memoryCache->getMemory(slot));
101         memories[i] = NN_TRY(V1_0::utils::unvalidatedConvert(memory));
102         if (!memories[i].valid()) {
103             return NN_ERROR() << "memory at slot " << slot << " is invalid";
104         }
105     }
106     return memories;
107 }
108 
109 }  // namespace
110 
111 // MemoryCache methods
112 
113 ExecutionBurstController::MemoryCache::MemoryCache() {
114     constexpr size_t kPreallocatedCount = 1024;
115     std::vector<int32_t> freeSlotsSpace;
116     freeSlotsSpace.reserve(kPreallocatedCount);
117     mFreeSlots = std::stack<int32_t, std::vector<int32_t>>(std::move(freeSlotsSpace));
118     mMemoryCache.reserve(kPreallocatedCount);
119     mCacheCleaner.reserve(kPreallocatedCount);
120 }
121 
122 void ExecutionBurstController::MemoryCache::setBurstContext(sp<IBurstContext> burstContext) {
123     std::lock_guard guard(mMutex);
124     mBurstContext = std::move(burstContext);
125 }
126 
// Returns the slot associated with |memory|, allocating a new cache entry if
// one does not already exist. The returned SharedCleanup is a reference-counted
// hold on the entry; when the last hold is dropped, the entry frees itself via
// freeMemory().
std::pair<int32_t, ExecutionBurstController::MemoryCache::SharedCleanup>
ExecutionBurstController::MemoryCache::cacheMemory(const nn::SharedMemory& memory) {
    std::unique_lock lock(mMutex);
    base::ScopedLockAssertion lockAssert(mMutex);

    // Use existing cache entry if (1) the Memory object is in the cache and (2) the cache entry is
    // not currently being freed.
    auto iter = mMemoryIdToSlot.find(memory);
    while (iter != mMemoryIdToSlot.end()) {
        const int32_t slot = iter->second;
        // lock() fails only if the Cleanup object is mid-destruction, i.e. the
        // entry is currently being torn down by freeMemory().
        if (auto cleaner = mCacheCleaner.at(slot).lock()) {
            return std::make_pair(slot, std::move(cleaner));
        }

        // If the code reaches this point, the Memory object was in the cache, but is currently
        // being destroyed. This code waits until the cache entry has been freed, then loops to
        // ensure the cache entry has been freed or has been made present by another thread.
        mCond.wait(lock);
        iter = mMemoryIdToSlot.find(memory);
    }

    // Allocate a new cache entry.
    const int32_t slot = allocateSlotLocked();
    mMemoryIdToSlot[memory] = slot;
    mMemoryCache[slot] = memory;

    // Create reference-counted self-cleaning cache object. The cleanup task
    // captures a weak_ptr so it does not extend the MemoryCache's lifetime; if
    // the cache is already gone when the last hold drops, cleanup is a no-op.
    auto self = weak_from_this();
    Task cleanup = [memory, memoryCache = std::move(self)] {
        if (const auto lock = memoryCache.lock()) {
            lock->freeMemory(memory);
        }
    };
    auto cleaner = std::make_shared<const Cleanup>(std::move(cleanup));
    mCacheCleaner[slot] = cleaner;

    return std::make_pair(slot, std::move(cleaner));
}
165 
166 nn::GeneralResult<nn::SharedMemory> ExecutionBurstController::MemoryCache::getMemory(int32_t slot) {
167     std::lock_guard guard(mMutex);
168     if (slot < 0 || static_cast<size_t>(slot) >= mMemoryCache.size()) {
169         return NN_ERROR() << "Invalid slot: " << slot << " vs " << mMemoryCache.size();
170     }
171     return mMemoryCache[slot];
172 }
173 
// Removes |memory|'s cache entry, tells the service-side burst context (if
// one has been set) that the slot can be released, and recycles the slot.
// Invoked by the Cleanup task when the last SharedCleanup hold is dropped.
void ExecutionBurstController::MemoryCache::freeMemory(const nn::SharedMemory& memory) {
    {
        std::lock_guard guard(mMutex);
        const int32_t slot = mMemoryIdToSlot.at(memory);
        if (mBurstContext) {
            mBurstContext->freeMemory(slot);
        }
        mMemoryIdToSlot.erase(memory);
        mMemoryCache[slot] = {};
        mCacheCleaner[slot].reset();
        mFreeSlots.push(slot);
    }
    // Notify after releasing the lock so threads waiting in cacheMemory() do
    // not wake only to immediately block on mMutex.
    mCond.notify_all();
}
188 
189 int32_t ExecutionBurstController::MemoryCache::allocateSlotLocked() {
190     constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
191 
192     // If there is a free slot, use it.
193     if (!mFreeSlots.empty()) {
194         const int32_t slot = mFreeSlots.top();
195         mFreeSlots.pop();
196         return slot;
197     }
198 
199     // Use a slot for the first time.
200     CHECK_LT(mMemoryCache.size(), kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
201     const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
202     mMemoryCache.emplace_back();
203     mCacheCleaner.emplace_back();
204 
205     return slot;
206 }
207 
208 // ExecutionBurstCallback methods
209 
// Stores a weak reference to the (non-null) memory cache; the callback must
// not extend the cache's lifetime, since the service holds the callback.
ExecutionBurstController::ExecutionBurstCallback::ExecutionBurstCallback(
        const std::shared_ptr<MemoryCache>& memoryCache)
    : kMemoryCache(memoryCache) {
    CHECK(memoryCache != nullptr);
}
215 
216 Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
217         const hidl_vec<int32_t>& slots, getMemories_cb cb) {
218     const auto memoryCache = kMemoryCache.lock();
219     if (memoryCache == nullptr) {
220         LOG(ERROR) << "ExecutionBurstController::ExecutionBurstCallback::getMemories called after "
221                       "the MemoryCache has been freed";
222         cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
223         return Void();
224     }
225 
226     const auto maybeMemories = getMemoriesHelper(slots, memoryCache);
227     if (!maybeMemories.has_value()) {
228         const auto& [message, code] = maybeMemories.error();
229         LOG(ERROR) << "ExecutionBurstController::ExecutionBurstCallback::getMemories failed with "
230                    << code << ": " << message;
231         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
232         return Void();
233     }
234 
235     cb(V1_0::ErrorStatus::NONE, maybeMemories.value());
236     return Void();
237 }
238 
239 // ExecutionBurstController methods
240 
// Creates a burst controller: sets up both FMQ channels, registers the memory
// cache callback with the service via configureExecutionBurst, and wires a
// death handler so FMQ waiters are unblocked if the service dies.
// |pollingTimeWindow| controls how long the result receiver busy-polls before
// falling back to blocking waits.
nn::GeneralResult<std::shared_ptr<const ExecutionBurstController>> ExecutionBurstController::create(
        nn::SharedPreparedModel preparedModel, const sp<V1_2::IPreparedModel>& hidlPreparedModel,
        std::chrono::microseconds pollingTimeWindow) {
    // check inputs
    if (preparedModel == nullptr || hidlPreparedModel == nullptr) {
        return NN_ERROR() << "ExecutionBurstController::create passed a nullptr";
    }

    // create FMQ objects
    auto [requestChannelSender, requestChannelDescriptor] =
            NN_TRY(RequestChannelSender::create(kExecutionBurstChannelLength));
    auto [resultChannelReceiver, resultChannelDescriptor] =
            NN_TRY(ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow));

    // check FMQ objects
    CHECK(requestChannelSender != nullptr);
    CHECK(requestChannelDescriptor != nullptr);
    CHECK(resultChannelReceiver != nullptr);
    CHECK(resultChannelDescriptor != nullptr);

    // create memory cache
    auto memoryCache = std::make_shared<MemoryCache>();

    // create callback object; cb synchronously captures the result of the
    // asynchronous-looking configureExecutionBurst call below
    auto burstCallback = sp<ExecutionBurstCallback>::make(memoryCache);
    auto cb = hal::utils::CallbackValue(executionBurstResultCallback);

    // configure burst
    const Return<void> ret = hidlPreparedModel->configureExecutionBurst(
            burstCallback, *requestChannelDescriptor, *resultChannelDescriptor, cb);
    HANDLE_TRANSPORT_FAILURE(ret);

    auto burstContext = NN_TRY(cb.take());
    // The cache needs the context so freeMemory() can notify the service.
    memoryCache->setBurstContext(burstContext);

    // create death handler object; protect both FMQ endpoints so blocked
    // senders/receivers are woken if the service process dies
    auto deathHandler = NN_TRY(neuralnetworks::utils::DeathHandler::create(burstContext));
    deathHandler.protectCallbackForLifetimeOfDeathHandler(requestChannelSender.get());
    deathHandler.protectCallbackForLifetimeOfDeathHandler(resultChannelReceiver.get());

    // make and return controller
    return std::make_shared<const ExecutionBurstController>(
            PrivateConstructorTag{}, std::move(preparedModel), std::move(requestChannelSender),
            std::move(resultChannelReceiver), std::move(burstCallback), std::move(burstContext),
            std::move(memoryCache), std::move(deathHandler));
}
287 
// Private constructor: all members are created and validated by create(), so
// this only takes ownership of them.
ExecutionBurstController::ExecutionBurstController(
        PrivateConstructorTag /*tag*/, nn::SharedPreparedModel preparedModel,
        std::unique_ptr<RequestChannelSender> requestChannelSender,
        std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
        sp<ExecutionBurstCallback> callback, sp<IBurstContext> burstContext,
        std::shared_ptr<MemoryCache> memoryCache, neuralnetworks::utils::DeathHandler deathHandler)
    : kPreparedModel(std::move(preparedModel)),
      mRequestChannelSender(std::move(requestChannelSender)),
      mResultChannelReceiver(std::move(resultChannelReceiver)),
      mBurstCallback(std::move(callback)),
      mBurstContext(std::move(burstContext)),
      mMemoryCache(std::move(memoryCache)),
      kDeathHandler(std::move(deathHandler)) {}
301 
302 ExecutionBurstController::OptionalCacheHold ExecutionBurstController::cacheMemory(
303         const nn::SharedMemory& memory) const {
304     auto [slot, hold] = mMemoryCache->cacheMemory(memory);
305     return hold;
306 }
307 
// Runs a single execution over the burst's FMQ channels. Requests that are too
// new for the V1_2 burst protocol, or that fail to be sent over the FMQ, fall
// back to the regular IPreparedModel::execute path.
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
ExecutionBurstController::execute(const nn::Request& request, nn::MeasureTiming measure,
                                  const nn::OptionalTimePoint& deadline,
                                  const nn::OptionalDuration& loopTimeoutDuration) const {
    // This is the first point when we know an execution is occurring, so begin to collect
    // systraces. Note that the first point we can begin collecting systraces in
    // ExecutionBurstServer is when the RequestChannelReceiver realizes there is data in the FMQ, so
    // ExecutionBurstServer collects systraces at different points in the code.
    NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::execute");

    // if the request is valid but of a higher version than what's supported in burst execution,
    // fall back to another execution path
    if (const auto version = NN_TRY(hal::utils::makeExecutionFailure(nn::validate(request)));
        version > nn::Version::ANDROID_Q) {
        // fallback to another execution path if the packet could not be sent
        return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration);
    }

    // ensure that request is ready for IPC: copy pointer-backed arguments into
    // shared memory, remembering the relocation so outputs can be copied back
    std::optional<nn::Request> maybeRequestInShared;
    hal::utils::RequestRelocation relocation;
    const nn::Request& requestInShared =
            NN_TRY(hal::utils::makeExecutionFailure(hal::utils::convertRequestFromPointerToShared(
                    &request, nn::kDefaultRequestMemoryAlignment, nn::kMinMemoryPadding,
                    &maybeRequestInShared, &relocation)));

    // clear pools field of request, as they will be provided via slots
    const auto requestWithoutPools = nn::Request{
            .inputs = requestInShared.inputs, .outputs = requestInShared.outputs, .pools = {}};
    auto hidlRequest = NN_TRY(
            hal::utils::makeExecutionFailure(V1_0::utils::unvalidatedConvert(requestWithoutPools)));
    const auto hidlMeasure = NN_TRY(hal::utils::makeExecutionFailure(convert(measure)));

    // Register each memory pool with the cache; |holds| keeps the slots alive
    // for the duration of this call.
    std::vector<int32_t> slots;
    std::vector<OptionalCacheHold> holds;
    slots.reserve(requestInShared.pools.size());
    holds.reserve(requestInShared.pools.size());
    for (const auto& memoryPool : requestInShared.pools) {
        auto [slot, hold] = mMemoryCache->cacheMemory(std::get<nn::SharedMemory>(memoryPool));
        slots.push_back(slot);
        holds.push_back(std::move(hold));
    }

    // send request packet
    const auto requestPacket = serialize(hidlRequest, hidlMeasure, slots);
    const auto fallback = [this, &request, measure, &deadline, &loopTimeoutDuration] {
        return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration);
    };
    return executeInternal(requestPacket, relocation, fallback);
}
358 
359 // See IBurst::createReusableExecution for information on this method.
360 nn::GeneralResult<nn::SharedExecution> ExecutionBurstController::createReusableExecution(
361         const nn::Request& request, nn::MeasureTiming measure,
362         const nn::OptionalDuration& loopTimeoutDuration) const {
363     NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::createReusableExecution");
364 
365     // if the request is valid but of a higher version than what's supported in burst execution,
366     // fall back to another execution path
367     if (const auto version = NN_TRY(hal::utils::makeGeneralFailure(nn::validate(request)));
368         version > nn::Version::ANDROID_Q) {
369         // fallback to another execution path if the packet could not be sent
370         return kPreparedModel->createReusableExecution(request, measure, loopTimeoutDuration);
371     }
372 
373     // ensure that request is ready for IPC
374     std::optional<nn::Request> maybeRequestInShared;
375     hal::utils::RequestRelocation relocation;
376     const nn::Request& requestInShared = NN_TRY(hal::utils::convertRequestFromPointerToShared(
377             &request, nn::kDefaultRequestMemoryAlignment, nn::kMinMemoryPadding,
378             &maybeRequestInShared, &relocation));
379 
380     // clear pools field of request, as they will be provided via slots
381     const auto requestWithoutPools = nn::Request{
382             .inputs = requestInShared.inputs, .outputs = requestInShared.outputs, .pools = {}};
383     auto hidlRequest = NN_TRY(V1_0::utils::unvalidatedConvert(requestWithoutPools));
384     const auto hidlMeasure = NN_TRY(convert(measure));
385 
386     std::vector<int32_t> slots;
387     std::vector<OptionalCacheHold> holds;
388     slots.reserve(requestInShared.pools.size());
389     holds.reserve(requestInShared.pools.size());
390     for (const auto& memoryPool : requestInShared.pools) {
391         auto [slot, hold] = mMemoryCache->cacheMemory(std::get<nn::SharedMemory>(memoryPool));
392         slots.push_back(slot);
393         holds.push_back(std::move(hold));
394     }
395 
396     const auto requestPacket = serialize(hidlRequest, hidlMeasure, slots);
397     return BurstExecution::create(shared_from_this(), std::move(requestPacket),
398                                   std::move(relocation), std::move(holds));
399 }
400 
// Sends a pre-serialized request packet over the FMQ and blocks for the
// result. |relocation| copies pointer-backed inputs/outputs to/from shared
// memory around the IPC. If sending fails and |fallback| is non-null, the
// fallback path is used instead of reporting an error.
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
ExecutionBurstController::executeInternal(const std::vector<FmqRequestDatum>& requestPacket,
                                          const hal::utils::RequestRelocation& relocation,
                                          FallbackFunction fallback) const {
    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
                 "ExecutionBurstController::executeInternal");

    // Ensure that at most one execution is in flight at any given time.
    const bool alreadyInFlight = mExecutionInFlight.test_and_set();
    if (alreadyInFlight) {
        return NN_ERROR() << "IBurst already has an execution in flight";
    }
    // Clear the in-flight flag on every exit path, including fallback/error.
    const auto guard = base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    // Copy pointer-backed input data into the shared memory the service reads.
    if (relocation.input) {
        relocation.input->flush();
    }

    // send request packet
    const auto sendStatus = mRequestChannelSender->sendPacket(requestPacket);
    if (!sendStatus.ok()) {
        // fallback to another execution path if the packet could not be sent
        if (fallback) {
            return fallback();
        }
        return NN_ERROR() << "Error sending FMQ packet: " << sendStatus.error();
    }

    // get result packet (blocking, subject to the receiver's polling window)
    const auto [status, outputShapes, timing] =
            NN_TRY(hal::utils::makeExecutionFailure(mResultChannelReceiver->getBlocking()));

    // Copy output data from shared memory back to the caller's pointers.
    if (relocation.output) {
        relocation.output->flush();
    }
    return executionCallback(status, outputShapes, timing);
}
438 
439 nn::GeneralResult<std::shared_ptr<const BurstExecution>> BurstExecution::create(
440         std::shared_ptr<const ExecutionBurstController> controller,
441         std::vector<FmqRequestDatum> request, hal::utils::RequestRelocation relocation,
442         std::vector<ExecutionBurstController::OptionalCacheHold> cacheHolds) {
443     if (controller == nullptr) {
444         return NN_ERROR() << "V1_2::utils::BurstExecution::create must have non-null controller";
445     }
446 
447     return std::make_shared<const BurstExecution>(PrivateConstructorTag{}, std::move(controller),
448                                                   std::move(request), std::move(relocation),
449                                                   std::move(cacheHolds));
450 }
451 
// Private constructor: arguments are validated by BurstExecution::create, so
// this only takes ownership of them.
BurstExecution::BurstExecution(PrivateConstructorTag /*tag*/,
                               std::shared_ptr<const ExecutionBurstController> controller,
                               std::vector<FmqRequestDatum> request,
                               hal::utils::RequestRelocation relocation,
                               std::vector<ExecutionBurstController::OptionalCacheHold> cacheHolds)
    : kController(std::move(controller)),
      kRequest(std::move(request)),
      kRelocation(std::move(relocation)),
      kCacheHolds(std::move(cacheHolds)) {}
461 
// Runs the stored request on the controller. No fallback is supplied: a
// reusable execution has no alternative path if the FMQ send fails. The
// deadline is ignored because burst executions do not support deadlines.
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> BurstExecution::compute(
        const nn::OptionalTimePoint& /*deadline*/) const {
    return kController->executeInternal(kRequest, kRelocation, /*fallback=*/nullptr);
}
466 
// Fenced execution is not part of the V1_2 burst protocol, so this always
// fails with GENERAL_FAILURE.
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
BurstExecution::computeFenced(const std::vector<nn::SyncFence>& /*waitFor*/,
                              const nn::OptionalTimePoint& /*deadline*/,
                              const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
    return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
           << "IExecution::computeFenced is not supported on burst object";
}
474 
475 }  // namespace android::hardware::neuralnetworks::V1_2::utils
476