1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_CONTROLLER_H
18 #define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_CONTROLLER_H
19 
20 #include "ExecutionBurstUtils.h"
21 
22 #include <android-base/thread_annotations.h>
23 #include <android/hardware/neuralnetworks/1.0/types.h>
24 #include <android/hardware/neuralnetworks/1.2/IBurstCallback.h>
25 #include <android/hardware/neuralnetworks/1.2/IBurstContext.h>
26 #include <android/hardware/neuralnetworks/1.2/IPreparedModel.h>
27 #include <android/hardware/neuralnetworks/1.2/types.h>
28 #include <fmq/MessageQueue.h>
29 #include <hidl/MQDescriptor.h>
30 #include <nnapi/IBurst.h>
31 #include <nnapi/IExecution.h>
32 #include <nnapi/IPreparedModel.h>
33 #include <nnapi/Result.h>
34 #include <nnapi/Types.h>
35 #include <nnapi/hal/CommonUtils.h>
36 #include <nnapi/hal/ProtectCallback.h>
37 
38 #include <atomic>
39 #include <chrono>
40 #include <functional>
41 #include <map>
42 #include <memory>
43 #include <mutex>
44 #include <stack>
45 #include <tuple>
46 #include <utility>
47 #include <vector>
48 
49 namespace android::hardware::neuralnetworks::V1_2::utils {
50 
51 /**
52  * The ExecutionBurstController class manages both the serialization and deserialization of data
53  * across FMQ, making it appear to the runtime as a regular synchronous inference. Additionally,
54  * this class manages the burst's memory cache.
    *
    * The class is `final`, implements the nn::IBurst interface, and derives from
    * std::enable_shared_from_this. Instances are produced by the static create() factory, which
    * returns a shared pointer to a const controller; accordingly, the IBurst overrides below are
    * all const member functions.
55  */
56 class ExecutionBurstController final
57     : public nn::IBurst,
58       public std::enable_shared_from_this<ExecutionBurstController> {
       // Tag type declared in the private section of the class. The public constructor below takes
       // it as its first parameter; presumably this keeps outside code from calling the constructor
       // directly while still letting create() construct the object -- TODO confirm against the
       // implementation file.
59     struct PrivateConstructorTag {};
60 
61   public:
       // Callable that takes no arguments and returns the result of an execution (output shapes
       // and timing, or an execution error). executeInternal() accepts one as its alternative
       // execution path.
62     using FallbackFunction = std::function<
63             nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>()>;
64 
65     /**
66      * NN runtime memory cache.
67      *
68      * MemoryCache associates a Memory object with a slot number to be passed across FMQ. The
69      * ExecutionBurstServer can use this callback to retrieve a hidl_memory corresponding to the
70      * slot via HIDL.
71      *
72      * Whenever a hidl_memory object is copied, it will duplicate the underlying file descriptor.
73      * Because the NN runtime currently copies the hidl_memory on each execution, it is difficult to
74      * associate hidl_memory objects with previously cached hidl_memory objects. For this reason,
75      * callers of this class must pair each hidl_memory object with an associated key. For
76      * efficiency, if two hidl_memory objects represent the same underlying buffer, they must use
77      * the same key.
78      *
        * NOTE(review): cacheMemory() below takes only an nn::SharedMemory and no separate key, and
        * mMemoryIdToSlot is keyed by nn::SharedMemory, so the paragraph above about caller-supplied
        * keys may be out of date -- confirm against the implementation.
        *
79      * This class is thread-safe.
80      */
81     class MemoryCache : public std::enable_shared_from_this<MemoryCache> {
           // NOTE(review): this tag type is not referenced anywhere else in this header; the
           // MemoryCache constructor below takes no arguments. Check the implementation file before
           // relying on it or removing it.
82         struct PrivateConstructorTag {};
83 
84       public:
           // A unit of deferred work: a callable taking no arguments and returning nothing.
85         using Task = std::function<void()>;
           // A base::ScopeGuard wrapping a Task (presumably runs the task when the guard is
           // destroyed -- see android-base's ScopeGuard for the exact semantics).
86         using Cleanup = base::ScopeGuard<Task>;
           // Shared-ownership handle to an immutable Cleanup; this is the "hold" object returned by
           // cacheMemory().
87         using SharedCleanup = std::shared_ptr<const Cleanup>;
           // Non-owning counterpart of SharedCleanup; this is the element type of mCacheCleaner.
88         using WeakCleanup = std::weak_ptr<const Cleanup>;
89 
90         // Custom constructor to pre-allocate cache sizes.
91         MemoryCache();
92 
93         /**
94          * Add a burst context to the MemoryCache object.
95          *
96          * If this method is called, it must be called before the MemoryCache::cacheMemory or
97          * MemoryCache::getMemory is used.
98          *
99          * @param burstContext Burst context to be added to the MemoryCache object.
100          */
101         void setBurstContext(sp<IBurstContext> burstContext);
102 
103         /**
104          * Cache a memory object in the MemoryCache object.
105          *
106          * @param memory Memory object to be cached while the returned `SharedCleanup` is alive.
107          * @return A pair of (1) a unique identifier for the cache entry and (2) a ref-counted
108          *     "hold" object which preserves the cache as long as the hold object is alive.
109          */
110         std::pair<int32_t, SharedCleanup> cacheMemory(const nn::SharedMemory& memory);
111 
112         /**
113          * Get the memory object corresponding to a slot identifier.
114          *
115          * @param slot Slot which identifies the memory object to retrieve.
116          * @return The memory object corresponding to slot, otherwise GeneralError.
117          */
118         nn::GeneralResult<nn::SharedMemory> getMemory(int32_t slot);
119 
120       private:
            // Removes the cache entry for `memory`. Presumably invoked from the Cleanup task once
            // the last hold returned by cacheMemory() goes away -- TODO confirm in the
            // implementation.
121         void freeMemory(const nn::SharedMemory& memory);
            // Chooses the slot number for a new cache entry (presumably reusing an entry from
            // mFreeSlots when one is available -- confirm). The REQUIRES annotation means the
            // caller must already hold mMutex.
122         int32_t allocateSlotLocked() REQUIRES(mMutex);
123 
            // Protects every member below that is annotated with GUARDED_BY(mMutex).
124         std::mutex mMutex;
            // NOTE(review): std::condition_variable is declared in <condition_variable>, which this
            // header does not include directly; it presumably arrives through another include.
            // Consider adding the include explicitly. This member is also not referenced anywhere
            // else in this header.
125         std::condition_variable mCond;
            // Burst context supplied through setBurstContext().
126         sp<IBurstContext> mBurstContext GUARDED_BY(mMutex);
            // Stack of slot numbers, backed by a std::vector (presumably the slots that are
            // currently unused and available for reuse).
127         std::stack<int32_t, std::vector<int32_t>> mFreeSlots GUARDED_BY(mMutex);
            // Lookup from a memory object to the slot number assigned to it.
128         std::map<nn::SharedMemory, int32_t> mMemoryIdToSlot GUARDED_BY(mMutex);
            // Cached memory objects (presumably indexed by slot number -- confirm).
129         std::vector<nn::SharedMemory> mMemoryCache GUARDED_BY(mMutex);
            // Weak references to the hold objects handed out by cacheMemory() (presumably indexed
            // by slot number -- confirm).
130         std::vector<WeakCleanup> mCacheCleaner GUARDED_BY(mMutex);
131     };
132 
133     /**
134      * HIDL Callback class to pass memory objects to the Burst server when given corresponding
135      * slots.
136      */
137     class ExecutionBurstCallback : public IBurstCallback {
138       public:
139         // Precondition: memoryCache must be non-null.
140         explicit ExecutionBurstCallback(const std::shared_ptr<MemoryCache>& memoryCache);
141 
142         // See IBurstCallback::getMemories for information on this method.
143         Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;
144 
145       private:
            // Held as a weak_ptr, so this callback object does not by itself keep the MemoryCache
            // alive.
146         const std::weak_ptr<MemoryCache> kMemoryCache;
147     };
148 
149     /**
150      * Creates a burst controller on a prepared model.
151      *
152      * @param preparedModel Model prepared for execution to execute on.
         * @param hidlPreparedModel HIDL (V1_2) prepared model object (presumably the remote object
         *     on which the burst is configured -- TODO confirm against the implementation).
153      * @param pollingTimeWindow How much time (in microseconds) the ExecutionBurstController is
154      *     allowed to poll the FMQ before waiting on the blocking futex. Polling may result in lower
155      *     latencies at the potential cost of more power usage.
156      * @return Shared pointer to a const ExecutionBurstController object, otherwise GeneralError.
157      */
158     static nn::GeneralResult<std::shared_ptr<const ExecutionBurstController>> create(
159             nn::SharedPreparedModel preparedModel, const sp<IPreparedModel>& hidlPreparedModel,
160             std::chrono::microseconds pollingTimeWindow);
161 
        // Constructor. It is declared public, but its first parameter is of the private type
        // PrivateConstructorTag, so a caller needs access to that type to invoke it; use create()
        // to obtain an instance. The remaining parameters correspond one-to-one to the data members
        // declared in the private section at the bottom of the class.
162     ExecutionBurstController(PrivateConstructorTag tag, nn::SharedPreparedModel preparedModel,
163                              std::unique_ptr<RequestChannelSender> requestChannelSender,
164                              std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
165                              sp<ExecutionBurstCallback> callback, sp<IBurstContext> burstContext,
166                              std::shared_ptr<MemoryCache> memoryCache,
167                              neuralnetworks::utils::DeathHandler deathHandler);
168 
169     // See IBurst::cacheMemory for information on this method.
170     OptionalCacheHold cacheMemory(const nn::SharedMemory& memory) const override;
171 
172     // See IBurst::execute for information on this method.
173     nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
174             const nn::Request& request, nn::MeasureTiming measure,
175             const nn::OptionalTimePoint& deadline,
176             const nn::OptionalDuration& loopTimeoutDuration) const override;
177 
178     // See IBurst::createReusableExecution for information on this method.
179     nn::GeneralResult<nn::SharedExecution> createReusableExecution(
180             const nn::Request& request, nn::MeasureTiming measure,
181             const nn::OptionalDuration& loopTimeoutDuration) const override;
182 
        // Runs an execution from an already-built FMQ request packet (`requestPacket`) together
        // with its accompanying `relocation` information, returning the output shapes and timing
        // or an execution error.
183     // If fallback is not nullptr, this method will invoke the fallback function to try another
184     // execution path if the packet could not be sent. Otherwise, failing to send the packet will
185     // result in an error.
186     nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> executeInternal(
187             const std::vector<FmqRequestDatum>& requestPacket,
188             const hal::utils::RequestRelocation& relocation, FallbackFunction fallback) const;
189 
190   private:
        // Atomic flag, initially clear. It is `mutable` so that the const member functions above
        // can modify it (presumably it marks that an execution is currently in progress on this
        // burst -- TODO confirm against the implementation).
191     mutable std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
        // The prepared model this burst controller was created on.
192     const nn::SharedPreparedModel kPreparedModel;
        // Owned sender/receiver objects for the request and result channels.
193     const std::unique_ptr<RequestChannelSender> mRequestChannelSender;
194     const std::unique_ptr<ResultChannelReceiver> mResultChannelReceiver;
        // HIDL callback object (see ExecutionBurstCallback above).
195     const sp<ExecutionBurstCallback> mBurstCallback;
        // HIDL burst context object.
196     const sp<IBurstContext> mBurstContext;
        // Memory cache shared with the callback, which refers to it through a weak_ptr.
197     const std::shared_ptr<MemoryCache> mMemoryCache;
198     // `kDeathHandler` must come after `mRequestChannelSender` and `mResultChannelReceiver` because
199     // it holds references to both objects.
        // (Non-static data members are destroyed in reverse order of declaration, so declaring it
        // last means it is destroyed before the two channel objects. Do not reorder these members.)
200     const neuralnetworks::utils::DeathHandler kDeathHandler;
201 };
202 
203 }  // namespace android::hardware::neuralnetworks::V1_2::utils
204 
205 #endif  // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_CONTROLLER_H
206