1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_CONTROLLER_H 18 #define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_CONTROLLER_H 19 20 #include "ExecutionBurstUtils.h" 21 22 #include <android-base/thread_annotations.h> 23 #include <android/hardware/neuralnetworks/1.0/types.h> 24 #include <android/hardware/neuralnetworks/1.2/IBurstCallback.h> 25 #include <android/hardware/neuralnetworks/1.2/IBurstContext.h> 26 #include <android/hardware/neuralnetworks/1.2/IPreparedModel.h> 27 #include <android/hardware/neuralnetworks/1.2/types.h> 28 #include <fmq/MessageQueue.h> 29 #include <hidl/MQDescriptor.h> 30 #include <nnapi/IBurst.h> 31 #include <nnapi/IExecution.h> 32 #include <nnapi/IPreparedModel.h> 33 #include <nnapi/Result.h> 34 #include <nnapi/Types.h> 35 #include <nnapi/hal/CommonUtils.h> 36 #include <nnapi/hal/ProtectCallback.h> 37 38 #include <atomic> 39 #include <chrono> 40 #include <functional> 41 #include <map> 42 #include <memory> 43 #include <mutex> 44 #include <stack> 45 #include <tuple> 46 #include <utility> 47 #include <vector> 48 49 namespace android::hardware::neuralnetworks::V1_2::utils { 50 51 /** 52 * The ExecutionBurstController class manages both the serialization and deserialization of data 53 * across FMQ, making it appear to the runtime as a regular synchronous inference. Additionally, 54 * this class manages the burst's memory cache. 55 */ 56 class ExecutionBurstController final 57 : public nn::IBurst, 58 public std::enable_shared_from_this<ExecutionBurstController> { 59 struct PrivateConstructorTag {}; 60 61 public: 62 using FallbackFunction = std::function< 63 nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>()>; 64 65 /** 66 * NN runtime memory cache. 67 * 68 * MemoryCache associates a Memory object with a slot number to be passed across FMQ. The 69 * ExecutionBurstServer can use this callback to retrieve a hidl_memory corresponding to the 70 * slot via HIDL. 71 * 72 * Whenever a hidl_memory object is copied, it will duplicate the underlying file descriptor. 73 * Because the NN runtime currently copies the hidl_memory on each execution, it is difficult to 74 * associate hidl_memory objects with previously cached hidl_memory objects. For this reason, 75 * callers of this class must pair each hidl_memory object with an associated key. For 76 * efficiency, if two hidl_memory objects represent the same underlying buffer, they must use 77 * the same key. 78 * 79 * This class is thread-safe. 80 */ 81 class MemoryCache : public std::enable_shared_from_this<MemoryCache> { 82 struct PrivateConstructorTag {}; 83 84 public: 85 using Task = std::function<void()>; 86 using Cleanup = base::ScopeGuard<Task>; 87 using SharedCleanup = std::shared_ptr<const Cleanup>; 88 using WeakCleanup = std::weak_ptr<const Cleanup>; 89 90 // Custom constructor to pre-allocate cache sizes. 91 MemoryCache(); 92 93 /** 94 * Add a burst context to the MemoryCache object. 95 * 96 * If this method is called, it must be called before the MemoryCache::cacheMemory or 97 * MemoryCache::getMemory is used. 98 * 99 * @param burstContext Burst context to be added to the MemoryCache object. 100 */ 101 void setBurstContext(sp<IBurstContext> burstContext); 102 103 /** 104 * Cache a memory object in the MemoryCache object. 105 * 106 * @param memory Memory object to be cached while the returned `SharedCleanup` is alive. 107 * @return A pair of (1) a unique identifier for the cache entry and (2) a ref-counted 108 * "hold" object which preserves the cache as long as the hold object is alive. 109 */ 110 std::pair<int32_t, SharedCleanup> cacheMemory(const nn::SharedMemory& memory); 111 112 /** 113 * Get the memory object corresponding to a slot identifier. 114 * 115 * @param slot Slot which identifies the memory object to retrieve. 116 * @return The memory object corresponding to slot, otherwise GeneralError. 117 */ 118 nn::GeneralResult<nn::SharedMemory> getMemory(int32_t slot); 119 120 private: 121 void freeMemory(const nn::SharedMemory& memory); 122 int32_t allocateSlotLocked() REQUIRES(mMutex); 123 124 std::mutex mMutex; 125 std::condition_variable mCond; 126 sp<IBurstContext> mBurstContext GUARDED_BY(mMutex); 127 std::stack<int32_t, std::vector<int32_t>> mFreeSlots GUARDED_BY(mMutex); 128 std::map<nn::SharedMemory, int32_t> mMemoryIdToSlot GUARDED_BY(mMutex); 129 std::vector<nn::SharedMemory> mMemoryCache GUARDED_BY(mMutex); 130 std::vector<WeakCleanup> mCacheCleaner GUARDED_BY(mMutex); 131 }; 132 133 /** 134 * HIDL Callback class to pass memory objects to the Burst server when given corresponding 135 * slots. 136 */ 137 class ExecutionBurstCallback : public IBurstCallback { 138 public: 139 // Precondition: memoryCache must be non-null. 140 explicit ExecutionBurstCallback(const std::shared_ptr<MemoryCache>& memoryCache); 141 142 // See IBurstCallback::getMemories for information on this method. 143 Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override; 144 145 private: 146 const std::weak_ptr<MemoryCache> kMemoryCache; 147 }; 148 149 /** 150 * Creates a burst controller on a prepared model. 151 * 152 * @param preparedModel Model prepared for execution to execute on. 153 * @param pollingTimeWindow How much time (in microseconds) the ExecutionBurstController is 154 * allowed to poll the FMQ before waiting on the blocking futex. Polling may result in lower 155 * latencies at the potential cost of more power usage. 156 * @return ExecutionBurstController Execution burst controller object. 157 */ 158 static nn::GeneralResult<std::shared_ptr<const ExecutionBurstController>> create( 159 nn::SharedPreparedModel preparedModel, const sp<IPreparedModel>& hidlPreparedModel, 160 std::chrono::microseconds pollingTimeWindow); 161 162 ExecutionBurstController(PrivateConstructorTag tag, nn::SharedPreparedModel preparedModel, 163 std::unique_ptr<RequestChannelSender> requestChannelSender, 164 std::unique_ptr<ResultChannelReceiver> resultChannelReceiver, 165 sp<ExecutionBurstCallback> callback, sp<IBurstContext> burstContext, 166 std::shared_ptr<MemoryCache> memoryCache, 167 neuralnetworks::utils::DeathHandler deathHandler); 168 169 // See IBurst::cacheMemory for information on this method. 170 OptionalCacheHold cacheMemory(const nn::SharedMemory& memory) const override; 171 172 // See IBurst::execute for information on this method. 173 nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute( 174 const nn::Request& request, nn::MeasureTiming measure, 175 const nn::OptionalTimePoint& deadline, 176 const nn::OptionalDuration& loopTimeoutDuration) const override; 177 178 // See IBurst::createReusableExecution for information on this method. 179 nn::GeneralResult<nn::SharedExecution> createReusableExecution( 180 const nn::Request& request, nn::MeasureTiming measure, 181 const nn::OptionalDuration& loopTimeoutDuration) const override; 182 183 // If fallback is not nullptr, this method will invoke the fallback function to try another 184 // execution path if the packet could not be sent. Otherwise, failing to send the packet will 185 // result in an error. 186 nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> executeInternal( 187 const std::vector<FmqRequestDatum>& requestPacket, 188 const hal::utils::RequestRelocation& relocation, FallbackFunction fallback) const; 189 190 private: 191 mutable std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT; 192 const nn::SharedPreparedModel kPreparedModel; 193 const std::unique_ptr<RequestChannelSender> mRequestChannelSender; 194 const std::unique_ptr<ResultChannelReceiver> mResultChannelReceiver; 195 const sp<ExecutionBurstCallback> mBurstCallback; 196 const sp<IBurstContext> mBurstContext; 197 const std::shared_ptr<MemoryCache> mMemoryCache; 198 // `kDeathHandler` must come after `mRequestChannelSender` and `mResultChannelReceiver` because 199 // it holds references to both objects. 200 const neuralnetworks::utils::DeathHandler kDeathHandler; 201 }; 202 203 } // namespace android::hardware::neuralnetworks::V1_2::utils 204 205 #endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_CONTROLLER_H 206