1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_UTILS_H
18 #define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_UTILS_H
19 
20 #include <android/hardware/neuralnetworks/1.0/types.h>
21 #include <android/hardware/neuralnetworks/1.2/types.h>
22 #include <fmq/MessageQueue.h>
23 #include <hidl/MQDescriptor.h>
24 #include <nnapi/Result.h>
25 #include <nnapi/Types.h>
26 #include <nnapi/hal/ProtectCallback.h>
27 
28 #include <atomic>
29 #include <chrono>
30 #include <memory>
31 #include <tuple>
32 #include <utility>
33 #include <vector>
34 
35 namespace android::hardware::neuralnetworks::V1_2::utils {
36 
/**
 * Number of elements in the FMQ.
 *
 * Declared `inline` so that all translation units including this header share a single
 * definition of the constant.
 */
inline constexpr size_t kExecutionBurstChannelLength = 1024;
41 
/**
 * Get how long the burst controller should poll while waiting for results to be returned.
 *
 * This time can be affected by the property "debug.nn.burst-controller-polling-window".
 *
 * The returned duration is the only output of this call, so discarding it is flagged.
 *
 * @return Polling time in microseconds.
 */
[[nodiscard]] std::chrono::microseconds getBurstControllerPollingTimeWindow();
50 
/**
 * Get how long the burst server should poll while waiting for a request to be received.
 *
 * This time can be affected by the property "debug.nn.burst-server-polling-window".
 *
 * The returned duration is the only output of this call, so discarding it is flagged.
 *
 * @return Polling time in microseconds.
 */
[[nodiscard]] std::chrono::microseconds getBurstServerPollingTimeWindow();
59 
60 /**
61  * Function to serialize a request.
62  *
63  * @param request Request object without the pool information.
64  * @param measure Whether to collect timing information for the execution.
65  * @param memoryIds Slot identifiers corresponding to memory resources for the request.
66  * @return Serialized FMQ request data.
67  */
68 std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, MeasureTiming measure,
69                                        const std::vector<int32_t>& slots);
70 
71 /**
72  * Deserialize the FMQ request data.
73  *
74  * The three resulting fields are the Request object (where Request::pools is empty), slot
75  * identifiers (which are stand-ins for Request::pools), and whether timing information must be
76  * collected for the run.
77  *
78  * @param data Serialized FMQ request data.
79  * @return Request object if successfully deserialized, otherwise an error message.
80  */
81 nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, MeasureTiming>> deserialize(
82         const std::vector<FmqRequestDatum>& data);
83 
84 /**
85  * Function to serialize results.
86  *
87  * @param errorStatus Status of the execution.
88  * @param outputShapes Dynamic shapes of the output tensors.
89  * @param timing Timing information of the execution.
90  * @return Serialized FMQ result data.
91  */
92 std::vector<FmqResultDatum> serialize(V1_0::ErrorStatus errorStatus,
93                                       const std::vector<OutputShape>& outputShapes, Timing timing);
94 
95 /**
96  * Deserialize the FMQ result data.
97  *
98  * The three resulting fields are the status of the execution, the dynamic shapes of the output
99  * tensors, and the timing information of the execution.
100  *
101  * @param data Serialized FMQ result data.
102  * @return Result object if successfully deserialized, otherwise an error message.
103  */
104 nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
105         const std::vector<FmqResultDatum>& data);
106 
107 /**
108  * RequestChannelSender is responsible for serializing the result packet of information, sending it
109  * on the result channel, and signaling that the data is available.
110  */
111 class RequestChannelSender final : public neuralnetworks::utils::IProtectedCallback {
112     struct PrivateConstructorTag {};
113 
114   public:
115     /**
116      * Create the sending end of a request channel.
117      *
118      * @param channelLength Number of elements in the FMQ.
119      * @return A pair of ResultChannelReceiver and the FMQ descriptor on successful creation,
120      *     GeneralError otherwise.
121      */
122     static nn::GeneralResult<std::pair<std::unique_ptr<RequestChannelSender>,
123                                        const MQDescriptorSync<FmqRequestDatum>*>>
124     create(size_t channelLength);
125 
126     /**
127      * Send the request to the channel.
128      *
129      * @param request Request object without the pool information.
130      * @param measure Whether to collect timing information for the execution.
131      * @param slots Slot identifiers corresponding to memory resources for the request.
132      * @return An empty `Result` on successful send, otherwise an error message.
133      */
134     nn::Result<void> send(const V1_0::Request& request, MeasureTiming measure,
135                           const std::vector<int32_t>& slots);
136 
137     /**
138      * Method to mark the channel as invalid, causing all future calls to RequestChannelSender::send
139      * to immediately return false without attempting to send a message across the FMQ.
140      */
141     void notifyAsDeadObject() override;
142 
143     // prefer calling RequestChannelSender::send
144     nn::Result<void> sendPacket(const std::vector<FmqRequestDatum>& packet);
145 
146     RequestChannelSender(PrivateConstructorTag tag, size_t channelLength);
147 
148   private:
149     MessageQueue<FmqRequestDatum, kSynchronizedReadWrite> mFmqRequestChannel;
150     std::atomic<bool> mValid{true};
151 };
152 
153 /**
154  * RequestChannelReceiver is responsible for waiting on the channel until the packet is available,
155  * extracting the packet from the channel, and deserializing the packet.
156  *
157  * Because the receiver can wait on a packet that may never come (e.g., because the sending side of
158  * the packet has been closed), this object can be invalidated, unblocking the receiver.
159  */
160 class RequestChannelReceiver final {
161     struct PrivateConstructorTag {};
162 
163   public:
164     /**
165      * Create the receiving end of a request channel.
166      *
167      * @param requestChannel Descriptor for the request channel.
168      * @param pollingTimeWindow How much time (in microseconds) the RequestChannelReceiver is
169      *     allowed to poll the FMQ before waiting on the blocking futex. Polling may result in lower
170      *     latencies at the potential cost of more power usage.
171      * @return RequestChannelReceiver on successful creation, nullptr otherwise.
172      */
173     static nn::GeneralResult<std::unique_ptr<RequestChannelReceiver>> create(
174             const MQDescriptorSync<FmqRequestDatum>& requestChannel,
175             std::chrono::microseconds pollingTimeWindow);
176 
177     /**
178      * Get the request from the channel.
179      *
180      * This method will block until either:
181      * 1) The packet has been retrieved, or
182      * 2) The receiver has been invalidated
183      *
184      * @return Request object if successfully received, an appropriate message if error or if the
185      *     receiver object was invalidated.
186      */
187     nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, MeasureTiming>> getBlocking();
188 
189     /**
190      * Method to mark the channel as invalid, unblocking any current or future calls to
191      * RequestChannelReceiver::getBlocking.
192      */
193     void invalidate();
194 
195     RequestChannelReceiver(PrivateConstructorTag tag,
196                            const MQDescriptorSync<FmqRequestDatum>& requestChannel,
197                            std::chrono::microseconds pollingTimeWindow);
198 
199   private:
200     nn::Result<std::vector<FmqRequestDatum>> getPacketBlocking();
201 
202     MessageQueue<FmqRequestDatum, kSynchronizedReadWrite> mFmqRequestChannel;
203     std::atomic<bool> mTeardown{false};
204     const std::chrono::microseconds kPollingTimeWindow;
205 };
206 
207 /**
208  * ResultChannelSender is responsible for serializing the result packet of information, sending it
209  * on the result channel, and signaling that the data is available.
210  */
211 class ResultChannelSender final {
212     struct PrivateConstructorTag {};
213 
214   public:
215     /**
216      * Create the sending end of a result channel.
217      *
218      * @param resultChannel Descriptor for the result channel.
219      * @return ResultChannelSender on successful creation, nullptr otherwise.
220      */
221     static nn::GeneralResult<std::unique_ptr<ResultChannelSender>> create(
222             const MQDescriptorSync<FmqResultDatum>& resultChannel);
223 
224     /**
225      * Send the result to the channel.
226      *
227      * @param errorStatus Status of the execution.
228      * @param outputShapes Dynamic shapes of the output tensors.
229      * @param timing Timing information of the execution.
230      */
231     void send(V1_0::ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes,
232               Timing timing);
233 
234     // prefer calling ResultChannelSender::send
235     void sendPacket(const std::vector<FmqResultDatum>& packet);
236 
237     ResultChannelSender(PrivateConstructorTag tag,
238                         const MQDescriptorSync<FmqResultDatum>& resultChannel);
239 
240   private:
241     MessageQueue<FmqResultDatum, kSynchronizedReadWrite> mFmqResultChannel;
242 };
243 
244 /**
245  * ResultChannelReceiver is responsible for waiting on the channel until the packet is available,
246  * extracting the packet from the channel, and deserializing the packet.
247  *
248  * Because the receiver can wait on a packet that may never come (e.g., because the sending side of
249  * the packet has been closed), this object can be invalidated, unblocking the receiver.
250  */
251 class ResultChannelReceiver final : public neuralnetworks::utils::IProtectedCallback {
252     struct PrivateConstructorTag {};
253 
254   public:
255     /**
256      * Create the receiving end of a result channel.
257      *
258      * @param channelLength Number of elements in the FMQ.
259      * @param pollingTimeWindow How much time (in microseconds) the ResultChannelReceiver is allowed
260      *     to poll the FMQ before waiting on the blocking futex. Polling may result in lower
261      *     latencies at the potential cost of more power usage.
262      * @return A pair of ResultChannelReceiver and the FMQ descriptor on successful creation, or
263      *     GeneralError otherwise.
264      */
265     static nn::GeneralResult<std::pair<std::unique_ptr<ResultChannelReceiver>,
266                                        const MQDescriptorSync<FmqResultDatum>*>>
267     create(size_t channelLength, std::chrono::microseconds pollingTimeWindow);
268 
269     /**
270      * Get the result from the channel.
271      *
272      * This method will block until either:
273      * 1) The packet has been retrieved, or
274      * 2) The receiver has been invalidated
275      *
276      * @return Result object if successfully received, otherwise an appropriate message if error or
277      *     if the receiver object was invalidated.
278      */
279     nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>> getBlocking();
280 
281     /**
282      * Method to mark the channel as invalid, unblocking any current or future calls to
283      * ResultChannelReceiver::getBlocking.
284      */
285     void notifyAsDeadObject() override;
286 
287     // prefer calling ResultChannelReceiver::getBlocking
288     nn::Result<std::vector<FmqResultDatum>> getPacketBlocking();
289 
290     ResultChannelReceiver(PrivateConstructorTag tag, size_t channelLength,
291                           std::chrono::microseconds pollingTimeWindow);
292 
293   private:
294     MessageQueue<FmqResultDatum, kSynchronizedReadWrite> mFmqResultChannel;
295     std::atomic<bool> mValid{true};
296     const std::chrono::microseconds kPollingTimeWindow;
297 };
298 
299 }  // namespace android::hardware::neuralnetworks::V1_2::utils
300 
301 #endif  // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_UTILS_H
302