1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H 18 #define ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H 19 20 #include "HalInterfaces.h" 21 22 #include <android-base/macros.h> 23 #include <fmq/MessageQueue.h> 24 #include <hidl/MQDescriptor.h> 25 26 #include <atomic> 27 #include <map> 28 #include <memory> 29 #include <mutex> 30 #include <stack> 31 #include <tuple> 32 33 namespace android::nn { 34 35 /** 36 * Number of elements in the FMQ. 37 */ 38 constexpr const size_t kExecutionBurstChannelLength = 1024; 39 40 /** 41 * Function to serialize a request. 42 * 43 * Prefer calling RequestChannelSender::send. 44 * 45 * @param request Request object without the pool information. 46 * @param measure Whether to collect timing information for the execution. 47 * @param memoryIds Slot identifiers corresponding to memory resources for the 48 * request. 49 * @return Serialized FMQ request data. 50 */ 51 std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure, 52 const std::vector<int32_t>& slots); 53 54 /** 55 * Deserialize the FMQ result data. 56 * 57 * The three resulting fields are the status of the execution, the dynamic 58 * shapes of the output tensors, and the timing information of the execution. 59 * 60 * @param data Serialized FMQ result data. 61 * @return Result object if successfully deserialized, std::nullopt otherwise. 62 */ 63 std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> deserialize( 64 const std::vector<FmqResultDatum>& data); 65 66 /** 67 * ResultChannelReceiver is responsible for waiting on the channel until the 68 * packet is available, extracting the packet from the channel, and 69 * deserializing the packet. 70 * 71 * Because the receiver can wait on a packet that may never come (e.g., because 72 * the sending side of the packet has been closed), this object can be 73 * invalidating, unblocking the receiver. 74 */ 75 class ResultChannelReceiver { 76 using FmqResultDescriptor = ::android::hardware::MQDescriptorSync<FmqResultDatum>; 77 using FmqResultChannel = 78 hardware::MessageQueue<FmqResultDatum, hardware::kSynchronizedReadWrite>; 79 80 public: 81 /** 82 * Create the receiving end of a result channel. 83 * 84 * Prefer this call over the constructor. 85 * 86 * @param channelLength Number of elements in the FMQ. 87 * @param blocking 'true' if FMQ should use futex, 'false' if it should 88 * spin-wait. 89 * @return A pair of ResultChannelReceiver and the FMQ descriptor on 90 * successful creation, both nullptr otherwise. 91 */ 92 static std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*> create( 93 size_t channelLength, bool blocking); 94 95 /** 96 * Get the result from the channel. 97 * 98 * This method will block until either: 99 * 1) The packet has been retrieved, or 100 * 2) The receiver has been invalidated 101 * 102 * @return Result object if successfully received, std::nullopt if error or 103 * if the receiver object was invalidated. 104 */ 105 std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> getBlocking(); 106 107 /** 108 * Method to mark the channel as invalid, unblocking any current or future 109 * calls to ResultChannelReceiver::getBlocking. 110 */ 111 void invalidate(); 112 113 // prefer calling ResultChannelReceiver::getBlocking 114 std::optional<std::vector<FmqResultDatum>> getPacketBlocking(); 115 116 ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel, bool blocking); 117 118 private: 119 const std::unique_ptr<FmqResultChannel> mFmqResultChannel; 120 std::atomic<bool> mValid{true}; 121 const bool mBlocking; 122 }; 123 124 /** 125 * RequestChannelSender is responsible for serializing the result packet of 126 * information, sending it on the result channel, and signaling that the data is 127 * available. 128 */ 129 class RequestChannelSender { 130 using FmqRequestDescriptor = ::android::hardware::MQDescriptorSync<FmqRequestDatum>; 131 using FmqRequestChannel = 132 hardware::MessageQueue<FmqRequestDatum, hardware::kSynchronizedReadWrite>; 133 134 public: 135 /** 136 * Create the sending end of a request channel. 137 * 138 * Prefer this call over the constructor. 139 * 140 * @param channelLength Number of elements in the FMQ. 141 * @param blocking 'true' if FMQ should use futex, 'false' if it should 142 * spin-wait. 143 * @return A pair of ResultChannelReceiver and the FMQ descriptor on 144 * successful creation, both nullptr otherwise. 145 */ 146 static std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*> create( 147 size_t channelLength, bool blocking); 148 149 /** 150 * Send the request to the channel. 151 * 152 * @param request Request object without the pool information. 153 * @param measure Whether to collect timing information for the execution. 154 * @param memoryIds Slot identifiers corresponding to memory resources for 155 * the request. 156 * @return 'true' on successful send, 'false' otherwise. 157 */ 158 bool send(const Request& request, MeasureTiming measure, const std::vector<int32_t>& slots); 159 160 /** 161 * Method to mark the channel as invalid, causing all future calls to 162 * RequestChannelSender::send to immediately return false without attempting 163 * to send a message across the FMQ. 164 */ 165 void invalidate(); 166 167 // prefer calling RequestChannelSender::send 168 bool sendPacket(const std::vector<FmqRequestDatum>& packet); 169 170 RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, bool blocking); 171 172 private: 173 const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel; 174 std::atomic<bool> mValid{true}; 175 const bool mBlocking; 176 }; 177 178 /** 179 * The ExecutionBurstController class manages both the serialization and 180 * deserialization of data across FMQ, making it appear to the runtime as a 181 * regular synchronous inference. Additionally, this class manages the burst's 182 * memory cache. 183 */ 184 class ExecutionBurstController { 185 DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstController); 186 187 public: 188 /** 189 * NN runtime burst callback object and memory cache. 190 * 191 * ExecutionBurstCallback associates a hidl_memory object with a slot number 192 * to be passed across FMQ. The ExecutionBurstServer can use this callback 193 * to retrieve this hidl_memory corresponding to the slot via HIDL. 194 * 195 * Whenever a hidl_memory object is copied, it will duplicate the underlying 196 * file descriptor. Because the NN runtime currently copies the hidl_memory 197 * on each execution, it is difficult to associate hidl_memory objects with 198 * previously cached hidl_memory objects. For this reason, callers of this 199 * class must pair each hidl_memory object with an associated key. For 200 * efficiency, if two hidl_memory objects represent the same underlying 201 * buffer, they must use the same key. 202 */ 203 class ExecutionBurstCallback : public IBurstCallback { 204 DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback); 205 206 public: 207 ExecutionBurstCallback() = default; 208 209 Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override; 210 211 /** 212 * This function performs one of two different actions: 213 * 1) If a key corresponding to a memory resource is unrecognized by the 214 * ExecutionBurstCallback object, the ExecutionBurstCallback object 215 * will allocate a slot, bind the memory to the slot, and return the 216 * slot identifier. 217 * 2) If a key corresponding to a memory resource is recognized by the 218 * ExecutionBurstCallback object, the ExecutionBurstCallback object 219 * will return the existing slot identifier. 220 * 221 * @param memories Memory resources used in an inference. 222 * @param keys Unique identifiers where each element corresponds to a 223 * memory resource element in "memories". 224 * @return Unique slot identifiers where each returned slot element 225 * corresponds to a memory resource element in "memories". 226 */ 227 std::vector<int32_t> getSlots(const hidl_vec<hidl_memory>& memories, 228 const std::vector<intptr_t>& keys); 229 230 /* 231 * This function performs two different actions: 232 * 1) Removes an entry from the cache (if present), including the local 233 * storage of the hidl_memory object. Note that this call does not 234 * free any corresponding hidl_memory object in ExecutionBurstServer, 235 * which is separately freed via IBurstContext::freeMemory. 236 * 2) Return whether a cache entry was removed and which slot was removed if 237 * found. If the key did not to correspond to any entry in the cache, a 238 * slot number of 0 is returned. The slot number and whether the entry 239 * existed is useful so the same slot can be freed in the 240 * ExecutionBurstServer's cache via IBurstContext::freeMemory. 241 */ 242 std::pair<bool, int32_t> freeMemory(intptr_t key); 243 244 private: 245 int32_t getSlotLocked(const hidl_memory& memory, intptr_t key); 246 int32_t allocateSlotLocked(); 247 248 std::mutex mMutex; 249 std::stack<int32_t, std::vector<int32_t>> mFreeSlots; 250 std::map<intptr_t, int32_t> mMemoryIdToSlot; 251 std::vector<hidl_memory> mMemoryCache; 252 }; 253 254 /** 255 * Creates a burst controller on a prepared model. 256 * 257 * Prefer this over ExecutionBurstController's constructor. 258 * 259 * @param preparedModel Model prepared for execution to execute on. 260 * @param blocking 'true' if the FMQ should use a futex to perform blocking 261 * until data is available in a less responsive, but more energy 262 * efficient manner. 'false' if the FMQ should use spin-looping to 263 * wait until data is available in a more responsive, but less energy 264 * efficient manner. 265 * @return ExecutionBurstController Execution burst controller object. 266 */ 267 static std::unique_ptr<ExecutionBurstController> create(const sp<IPreparedModel>& preparedModel, 268 bool blocking); 269 270 // prefer calling ExecutionBurstController::create 271 ExecutionBurstController(const std::shared_ptr<RequestChannelSender>& requestChannelSender, 272 const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver, 273 const sp<IBurstContext>& burstContext, 274 const sp<ExecutionBurstCallback>& callback, 275 const sp<hardware::hidl_death_recipient>& deathHandler = nullptr); 276 277 // explicit destructor to unregister the death recipient 278 ~ExecutionBurstController(); 279 280 /** 281 * Execute a request on a model. 282 * 283 * @param request Arguments to be executed on a model. 284 * @param measure Whether to collect timing measurements, either YES or NO 285 * @param memoryIds Identifiers corresponding to each memory object in the 286 * request's pools. 287 * @return A tuple of: 288 * - status of the execution 289 * - dynamic output shapes from the execution 290 * - any execution time measurements of the execution 291 */ 292 std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> compute( 293 const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds); 294 295 // TODO: combine "compute" and "tryCompute" back into a single function. 296 // "tryCompute" was created later to return the "fallback" boolean. This 297 // could not be done directly in "compute" because the VTS test cases (which 298 // test burst using "compute") had already been locked down and could not be 299 // changed. 300 /** 301 * Execute a request on a model. 302 * 303 * @param request Arguments to be executed on a model. 304 * @param measure Whether to collect timing measurements, either YES or NO 305 * @param memoryIds Identifiers corresponding to each memory object in the 306 * request's pools. 307 * @return A tuple of: 308 * - status of the execution 309 * - dynamic output shapes from the execution 310 * - any execution time measurements of the execution 311 * - whether or not a failed burst execution should be re-run using a 312 * different path (e.g., IPreparedModel::executeSynchronously) 313 */ 314 std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool> tryCompute( 315 const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds); 316 317 /** 318 * Propagate a user's freeing of memory to the service. 319 * 320 * @param key Key corresponding to the memory object. 321 */ 322 void freeMemory(intptr_t key); 323 324 private: 325 std::mutex mMutex; 326 const std::shared_ptr<RequestChannelSender> mRequestChannelSender; 327 const std::shared_ptr<ResultChannelReceiver> mResultChannelReceiver; 328 const sp<IBurstContext> mBurstContext; 329 const sp<ExecutionBurstCallback> mMemoryCache; 330 const sp<hardware::hidl_death_recipient> mDeathHandler; 331 }; 332 333 } // namespace android::nn 334 335 #endif // ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H 336