1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H
18 #define ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H
19 
20 #include "HalInterfaces.h"
21 
22 #include <android-base/macros.h>
23 #include <fmq/MessageQueue.h>
24 #include <hidl/MQDescriptor.h>
25 
26 #include <atomic>
27 #include <map>
28 #include <memory>
29 #include <mutex>
30 #include <stack>
31 #include <tuple>
32 
33 namespace android::nn {
34 
35 /**
36  * Number of elements in the FMQ.
    *
    * The value is a count of queue elements, not a size in bytes.
    *
    * NOTE(review): presumably this is the value supplied as the "channelLength"
    * argument of ResultChannelReceiver::create and RequestChannelSender::create
    * -- confirm against the implementation file.
37  */
38 constexpr const size_t kExecutionBurstChannelLength = 1024;
39 
40 /**
41  * Function to serialize a request.
42  *
43  * Prefer calling RequestChannelSender::send.
44  *
45  * @param request Request object without the pool information.
46  * @param measure Whether to collect timing information for the execution.
47  * @param slots Slot identifiers corresponding to memory resources for the
48  *     request.
49  * @return Serialized FMQ request data.
50  */
51 std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure,
52                                        const std::vector<int32_t>& slots);
53 
54 /**
55  * Deserialize the FMQ result data.
56  *
57  * The three resulting fields are the status of the execution, the dynamic
58  * shapes of the output tensors, and the timing information of the execution.
59  *
60  * @param data Serialized FMQ result data.
61  * @return Tuple of (execution status, output shapes, timing) if successfully
    *     deserialized, std::nullopt otherwise.
62  */
63 std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
64         const std::vector<FmqResultDatum>& data);
65 
66 /**
67  * ResultChannelReceiver is responsible for waiting on the channel until the
68  * packet is available, extracting the packet from the channel, and
69  * deserializing the packet.
70  *
71  * Because the receiver can wait on a packet that may never come (e.g., because
72  * the sending side of the packet has been closed), this object can be
73  * invalidated, unblocking the receiver.
74  */
75 class ResultChannelReceiver {
76     using FmqResultDescriptor = ::android::hardware::MQDescriptorSync<FmqResultDatum>;
77     using FmqResultChannel =
78             hardware::MessageQueue<FmqResultDatum, hardware::kSynchronizedReadWrite>;
79 
80    public:
81     /**
82      * Create the receiving end of a result channel.
83      *
84      * Prefer this call over the constructor.
85      *
86      * @param channelLength Number of elements in the FMQ.
87      * @param blocking 'true' if FMQ should use futex, 'false' if it should
88      *     spin-wait.
89      * @return A pair of ResultChannelReceiver and the FMQ descriptor on
90      *     successful creation, both nullptr otherwise.
91      */
92     static std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*> create(
93             size_t channelLength, bool blocking);
94 
95     /**
96      * Get the result from the channel.
97      *
98      * This method will block until either:
99      * 1) The packet has been retrieved, or
100      * 2) The receiver has been invalidated
101      *
102      * @return Result object if successfully received, std::nullopt if error or
103      *     if the receiver object was invalidated.
104      */
105     std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> getBlocking();
106 
107     /**
108      * Method to mark the channel as invalid, unblocking any current or future
109      * calls to ResultChannelReceiver::getBlocking.
110      */
111     void invalidate();
112 
113     // prefer calling ResultChannelReceiver::getBlocking
        // (this variant returns the packet as serialized FMQ result data instead of
        // deserializing it)
114     std::optional<std::vector<FmqResultDatum>> getPacketBlocking();
115 
        // prefer calling ResultChannelReceiver::create
116     ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel, bool blocking);
117 
118    private:
119     const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
120     std::atomic<bool> mValid{true};  // initially valid; presumably cleared by invalidate() (TODO confirm)
121     const bool mBlocking;  // presumably the futex vs. spin-wait choice described on create()
122 };
123 
124 /**
125  * RequestChannelSender is responsible for serializing the request packet of
126  * information, sending it on the request channel, and signaling that the data
127  * is available.
128  */
129 class RequestChannelSender {
130     using FmqRequestDescriptor = ::android::hardware::MQDescriptorSync<FmqRequestDatum>;
131     using FmqRequestChannel =
132             hardware::MessageQueue<FmqRequestDatum, hardware::kSynchronizedReadWrite>;
133 
134    public:
135     /**
136      * Create the sending end of a request channel.
137      *
138      * Prefer this call over the constructor.
139      *
140      * @param channelLength Number of elements in the FMQ.
141      * @param blocking 'true' if FMQ should use futex, 'false' if it should
142      *     spin-wait.
143      * @return A pair of RequestChannelSender and the FMQ descriptor on
144      *     successful creation, both nullptr otherwise.
145      */
146     static std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*> create(
147             size_t channelLength, bool blocking);
148 
149     /**
150      * Send the request to the channel.
151      *
152      * @param request Request object without the pool information.
153      * @param measure Whether to collect timing information for the execution.
154      * @param slots Slot identifiers corresponding to memory resources for
155      *     the request.
156      * @return 'true' on successful send, 'false' otherwise.
157      */
158     bool send(const Request& request, MeasureTiming measure, const std::vector<int32_t>& slots);
159 
160     /**
161      * Method to mark the channel as invalid, causing all future calls to
162      * RequestChannelSender::send to immediately return false without attempting
163      * to send a message across the FMQ.
164      */
165     void invalidate();
166 
167     // prefer calling RequestChannelSender::send
        // (this variant takes a packet that is already serialized FMQ request data)
168     bool sendPacket(const std::vector<FmqRequestDatum>& packet);
169 
        // prefer calling RequestChannelSender::create
170     RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, bool blocking);
171 
172    private:
173     const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
174     std::atomic<bool> mValid{true};  // initially valid; presumably cleared by invalidate() (TODO confirm)
175     const bool mBlocking;  // presumably the futex vs. spin-wait choice described on create()
176 };
177 
178 /**
179  * The ExecutionBurstController class manages both the serialization and
180  * deserialization of data across FMQ, making it appear to the runtime as a
181  * regular synchronous inference. Additionally, this class manages the burst's
182  * memory cache.
183  */
184 class ExecutionBurstController {
185     DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstController);
186 
187    public:
188     /**
189      * NN runtime burst callback object and memory cache.
190      *
191      * ExecutionBurstCallback associates a hidl_memory object with a slot number
192      * to be passed across FMQ. The ExecutionBurstServer can use this callback
193      * to retrieve this hidl_memory corresponding to the slot via HIDL.
194      *
195      * Whenever a hidl_memory object is copied, it will duplicate the underlying
196      * file descriptor. Because the NN runtime currently copies the hidl_memory
197      * on each execution, it is difficult to associate hidl_memory objects with
198      * previously cached hidl_memory objects. For this reason, callers of this
199      * class must pair each hidl_memory object with an associated key. For
200      * efficiency, if two hidl_memory objects represent the same underlying
201      * buffer, they must use the same key.
202      */
203     class ExecutionBurstCallback : public IBurstCallback {
204         DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback);
205 
206        public:
207         ExecutionBurstCallback() = default;
208 
            // IBurstCallback override: hands back the memories bound to the
            // requested slots through the "cb" callback.
            // NOTE(review): behavior for unknown slots is not visible in this
            // header -- confirm against the implementation.
209         Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;
210 
211         /**
212          * This function performs one of two different actions:
213          * 1) If a key corresponding to a memory resource is unrecognized by the
214          *    ExecutionBurstCallback object, the ExecutionBurstCallback object
215          *    will allocate a slot, bind the memory to the slot, and return the
216          *    slot identifier.
217          * 2) If a key corresponding to a memory resource is recognized by the
218          *    ExecutionBurstCallback object, the ExecutionBurstCallback object
219          *    will return the existing slot identifier.
220          *
221          * @param memories Memory resources used in an inference.
222          * @param keys Unique identifiers where each element corresponds to a
223          *     memory resource element in "memories".
224          * @return Unique slot identifiers where each returned slot element
225          *     corresponds to a memory resource element in "memories".
226          */
227         std::vector<int32_t> getSlots(const hidl_vec<hidl_memory>& memories,
228                                       const std::vector<intptr_t>& keys);
229 
230         /**
231          * This function performs two different actions:
232          * 1) Removes an entry from the cache (if present), including the local
233          *    storage of the hidl_memory object. Note that this call does not
234          *    free any corresponding hidl_memory object in ExecutionBurstServer,
235          *    which is separately freed via IBurstContext::freeMemory.
236          * 2) Return whether a cache entry was removed and which slot was removed if
237          *    found. If the key did not correspond to any entry in the cache, a
238          *    slot number of 0 is returned. The slot number and whether the entry
239          *    existed is useful so the same slot can be freed in the
240          *    ExecutionBurstServer's cache via IBurstContext::freeMemory.
             *
             * @param key Key corresponding to the memory resource to remove.
             * @return A pair of (whether a cache entry was removed, the slot that
             *     was removed, or 0 if the key matched no entry).
241          */
242         std::pair<bool, int32_t> freeMemory(intptr_t key);
243 
244        private:
            // NOTE(review): the "Locked" suffix suggests these helpers expect
            // mMutex to already be held by the caller -- confirm against the
            // implementation.
245         int32_t getSlotLocked(const hidl_memory& memory, intptr_t key);
246         int32_t allocateSlotLocked();
247 
248         std::mutex mMutex;
249         std::stack<int32_t, std::vector<int32_t>> mFreeSlots;  // presumably slots available for reuse
250         std::map<intptr_t, int32_t> mMemoryIdToSlot;  // caller-provided key -> slot identifier
251         std::vector<hidl_memory> mMemoryCache;  // presumably indexed by slot (TODO confirm)
252     };
253 
254     /**
255      * Creates a burst controller on a prepared model.
256      *
257      * Prefer this over ExecutionBurstController's constructor.
258      *
259      * @param preparedModel Model prepared for execution to execute on.
260      * @param blocking 'true' if the FMQ should use a futex to perform blocking
261      *     until data is available in a less responsive, but more energy
262      *     efficient manner. 'false' if the FMQ should use spin-looping to
263      *     wait until data is available in a more responsive, but less energy
264      *     efficient manner.
265      * @return ExecutionBurstController Execution burst controller object.
266      */
267     static std::unique_ptr<ExecutionBurstController> create(const sp<IPreparedModel>& preparedModel,
268                                                             bool blocking);
269 
270     // prefer calling ExecutionBurstController::create
271     ExecutionBurstController(const std::shared_ptr<RequestChannelSender>& requestChannelSender,
272                              const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
273                              const sp<IBurstContext>& burstContext,
274                              const sp<ExecutionBurstCallback>& callback,
275                              const sp<hardware::hidl_death_recipient>& deathHandler = nullptr);
276 
277     // explicit destructor to unregister the death recipient
278     ~ExecutionBurstController();
279 
280     /**
281      * Execute a request on a model.
282      *
283      * @param request Arguments to be executed on a model.
284      * @param measure Whether to collect timing measurements, either YES or NO
285      * @param memoryIds Identifiers corresponding to each memory object in the
286      *     request's pools.
287      * @return A tuple of:
288      *     - status of the execution
289      *     - dynamic output shapes from the execution
290      *     - any execution time measurements of the execution
291      */
292     std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> compute(
293             const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds);
294 
295     // TODO: combine "compute" and "tryCompute" back into a single function.
296     // "tryCompute" was created later to return the "fallback" boolean. This
297     // could not be done directly in "compute" because the VTS test cases (which
298     // test burst using "compute") had already been locked down and could not be
299     // changed.
300     /**
301      * Execute a request on a model.
302      *
303      * @param request Arguments to be executed on a model.
304      * @param measure Whether to collect timing measurements, either YES or NO
305      * @param memoryIds Identifiers corresponding to each memory object in the
306      *     request's pools.
307      * @return A tuple of:
308      *     - status of the execution
309      *     - dynamic output shapes from the execution
310      *     - any execution time measurements of the execution
311      *     - whether or not a failed burst execution should be re-run using a
312      *       different path (e.g., IPreparedModel::executeSynchronously)
313      */
314     std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool> tryCompute(
315             const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds);
316 
317     /**
318      * Propagate a user's freeing of memory to the service.
319      *
320      * @param key Key corresponding to the memory object.
321      */
322     void freeMemory(intptr_t key);
323 
324    private:
325     std::mutex mMutex;
326     const std::shared_ptr<RequestChannelSender> mRequestChannelSender;
327     const std::shared_ptr<ResultChannelReceiver> mResultChannelReceiver;
328     const sp<IBurstContext> mBurstContext;
329     const sp<ExecutionBurstCallback> mMemoryCache;  // burst callback object and memory cache
330     const sp<hardware::hidl_death_recipient> mDeathHandler;  // may be nullptr (constructor default)
331 };
332 
333 }  // namespace android::nn
334 
335 #endif  // ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H
336