/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H
#define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H

#include <LegacyUtils.h>
#include <android-base/macros.h>
#include <nnapi/IBurst.h>
#include <nnapi/IDevice.h>
#include <nnapi/Types.h>

#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>

#include "ExecutionCallback.h"
#include "Memory.h"

namespace android {
namespace nn {

// Forward declaration
class Device;
class MetaModel;
class ModelArgumentInfo;

// A unified interface for a reusable execution with cached resources.
// This object provides no thread-safety guarantee. The caller must guarantee there is at most one
// call to RuntimeExecution::compute or RuntimeExecution::computeFenced on the same RuntimeExecution
// object in flight at a time.
50 class RuntimeExecution { 51 DISALLOW_COPY_AND_ASSIGN(RuntimeExecution); 52 53 public: 54 RuntimeExecution() = default; 55 virtual ~RuntimeExecution() = default; 56 57 virtual std::tuple<int, std::vector<OutputShape>, Timing> compute( 58 const SharedBurst& burstController, const OptionalTimePoint& deadline) const = 0; 59 60 // The returned timing information is only valid if the callback is nullptr. 61 // Returns error_code, sync_fence, callback and timing. 62 virtual std::tuple<int, int, ExecuteFencedInfoCallback, Timing> computeFenced( 63 const std::vector<int>& waitFor, const OptionalTimePoint& deadline, 64 const OptionalDuration& timeoutDurationAfterFence) const = 0; 65 }; 66 67 // A unified interface for actual driver prepared model as well as the CPU. 68 class RuntimePreparedModel { 69 DISALLOW_COPY_AND_ASSIGN(RuntimePreparedModel); 70 71 public: 72 RuntimePreparedModel() = default; 73 virtual ~RuntimePreparedModel() = default; 74 75 virtual const Device* getDevice() const = 0; 76 virtual SharedPreparedModel getInterface() const = 0; 77 78 // Perform computation with given input/output argument info and memory pools. 79 virtual std::tuple<int, std::vector<OutputShape>, Timing> execute( 80 const std::vector<ModelArgumentInfo>& inputs, 81 const std::vector<ModelArgumentInfo>& outputs, 82 const std::vector<const RuntimeMemory*>& memories, const SharedBurst& burstController, 83 MeasureTiming measure, const OptionalTimePoint& deadline, 84 const OptionalDuration& loopTimeoutDuration, 85 const std::vector<TokenValuePair>& metaData) const = 0; 86 87 // Perform fenced computation with given input/output argument info and memory pools. 88 // The returned timing information is only valid if the callback is nullptr. 89 // Returns error_code, sync_fence, callback and timing. 
90 virtual std::tuple<int, int, ExecuteFencedInfoCallback, Timing> executeFenced( 91 const std::vector<ModelArgumentInfo>& inputs, 92 const std::vector<ModelArgumentInfo>& outputs, 93 const std::vector<const RuntimeMemory*>& memories, const std::vector<int>& waitFor, 94 MeasureTiming measure, const OptionalTimePoint& deadline, 95 const OptionalDuration& loopTimeoutDuration, 96 const OptionalDuration& timeoutDurationAfterFence, 97 const std::vector<TokenValuePair>& metaData) const = 0; 98 99 // Create a reusable execution with given input/output argument info and memory pools. 100 virtual std::pair<int, std::shared_ptr<RuntimeExecution>> createReusableExecution( 101 const std::vector<ModelArgumentInfo>& inputs, 102 const std::vector<ModelArgumentInfo>& outputs, 103 const std::vector<const RuntimeMemory*>& memories, MeasureTiming measure, 104 const OptionalDuration& loopTimeoutDuration, 105 const std::vector<TokenValuePair>& metaData) const = 0; 106 107 virtual GeneralResult<SharedBurst> configureExecutionBurst() const = 0; 108 109 virtual MemoryPreference getMemoryPreference() const = 0; 110 }; 111 112 using ModelFactory = std::function<Model()>; 113 114 struct CacheHandles { 115 std::vector<SharedHandle> modelCache; 116 std::vector<SharedHandle> dataCache; 117 }; 118 119 using CacheDir = std::string; 120 121 struct CacheInfo { 122 std::variant<CacheDir, CacheHandles> variant; 123 }; 124 125 // A unified interface for actual driver devices as well as the CPU 126 class Device { 127 DISALLOW_COPY_AND_ASSIGN(Device); 128 129 public: 130 Device() = default; 131 virtual ~Device() = default; 132 133 // Introspection methods returning device information 134 virtual const std::string& getName() const = 0; 135 virtual const std::string& getVersionString() const = 0; 136 virtual Version getFeatureLevel() const = 0; 137 virtual int32_t getType() const = 0; 138 virtual const std::vector<Extension>& getSupportedExtensions() const = 0; 139 140 // See the MetaModel class in 
MetaModel.h for more details. 141 virtual std::vector<bool> getSupportedOperations(const MetaModel& metaModel) const = 0; 142 143 virtual const Capabilities& getCapabilities() const = 0; 144 virtual Capabilities::PerformanceInfo getPerformance(OperandType type) const = 0; 145 virtual Capabilities::PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const = 0; 146 virtual Capabilities::PerformanceInfo getRelaxedFloat32toFloat16PerformanceTensor() const = 0; 147 virtual Capabilities::PerformanceInfo getIfPerformance() const = 0; 148 virtual Capabilities::PerformanceInfo getWhilePerformance() const = 0; 149 virtual std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const = 0; 150 virtual bool isCachingSupported() const = 0; 151 virtual int wait() const = 0; 152 153 virtual std::pair<int, std::shared_ptr<RuntimePreparedModel>> prepareModel( 154 const ModelFactory& makeModel, ExecutionPreference preference, Priority priority, 155 const OptionalTimePoint& deadline, const CacheInfo& cacheInfo, 156 const std::optional<CacheToken>& maybeToken, 157 const std::vector<TokenValuePair>& metaData, 158 const std::vector<ExtensionNameAndPrefix>& extensionNameAndPrefix) const = 0; 159 160 // The caller is responsible for making sure the MemoryDescriptor only contains 161 // PreparedModels from the same Device. 162 virtual std::pair<int, std::unique_ptr<RuntimeMemory>> allocate(const MemoryDescriptor& desc, 163 OperandType type) const = 0; 164 }; 165 166 // Manages the NN HAL devices. Only one instance of this class will exist. 167 // Use get() to retrieve it. 168 class DeviceManager { 169 public: getDrivers()170 const std::vector<std::shared_ptr<Device>>& getDrivers() const { 171 if (mSetCpuOnly || mDebugNNCpuOnly) { 172 return mDevicesCpuOnly; 173 } 174 return mDevices; 175 } 176 177 // Gets the runtime version corresponding to getServerFeatureLevelFlag (in ServerFlag.h). 
getRuntimeVersion()178 Version getRuntimeVersion() const { return mRuntimeVersion; } 179 180 // Gets the runtime feature level corresponding to getServerFeatureLevelFlag (in ServerFlag.h). 181 int64_t getRuntimeFeatureLevel() const; 182 183 // Convert the internal Version level representation to the NDK representation. 184 static int64_t versionToFeatureLevel(Version::Level versionLevel); 185 186 // Returns whether platform telemetry is enabled. isPlatformTelemetryEnabled()187 bool isPlatformTelemetryEnabled() const { return mIsPlatformTelemetryEnabled; } 188 189 // For testing only: setUseCpuOnly(bool useCpuOnly)190 void setUseCpuOnly(bool useCpuOnly) { mSetCpuOnly = useCpuOnly; } getUseCpuOnly()191 bool getUseCpuOnly() const { return mSetCpuOnly; } 192 syncExecCpu()193 bool syncExecCpu() const { return mSyncExecCpu; } syncExecRuntime()194 bool syncExecRuntime() const { return mSyncExecRuntime; } 195 196 // How to handle graph partitioning? 197 // 0 - Don't do graph partitioning. 198 // 1 - Do graph partitioning; but fall back to non-partitioned 199 // execution if there is a partitioning failure. 200 // 2 - Do graph partitioning, and rely on it; there is no fallback. 201 enum { kPartitioningNo = 0, kPartitioningWithFallback = 1, kPartitioningWithoutFallback = 2 }; getPartitioning()202 uint32_t getPartitioning() const { return mPartitioning; } partitioningAllowsFallback(uint32_t partitioning)203 static bool partitioningAllowsFallback(uint32_t partitioning) { 204 return partitioning == kPartitioningWithFallback; 205 } 206 strictSlicing()207 bool strictSlicing() const { return mStrictSlicing; } 208 209 // Returns the singleton manager. 210 static DeviceManager* get(); 211 212 // Returns the singleton Cpu device. 213 static std::shared_ptr<Device> getCpuDevice(); 214 215 // The forTest_* functions below are solely intended for use by unit tests. 216 217 // Returns all devices (ignores the cpu-only flags). 
forTest_getDevices()218 std::vector<std::shared_ptr<Device>> forTest_getDevices() const { return mDevices; } 219 220 // Sets the device list (does not affect cpu-only queries). forTest_setDevices(std::vector<std::shared_ptr<Device>> devices)221 void forTest_setDevices(std::vector<std::shared_ptr<Device>> devices) { 222 mDevices = std::move(devices); 223 } 224 225 // Register a test device. forTest_registerDevice(const SharedDevice & device)226 void forTest_registerDevice(const SharedDevice& device) { registerDevice(device); } 227 228 // Re-initialize the list of available devices. forTest_reInitializeDeviceList()229 void forTest_reInitializeDeviceList() { 230 mDevices.clear(); 231 mDevicesCpuOnly.clear(); 232 findAvailableDevices(); 233 } 234 235 // Make a test device 236 static std::shared_ptr<Device> forTest_makeDriverDevice(const SharedDevice& device); 237 forTest_isCpuDevice(const ANeuralNetworksDevice * device)238 bool forTest_isCpuDevice(const ANeuralNetworksDevice* device) const { 239 return reinterpret_cast<const Device*>(device) == getCpuDevice().get(); 240 } 241 242 private: 243 // Builds the list of available drivers and queries their capabilities. 244 DeviceManager(); 245 246 // Adds a device for the manager to use. 247 void registerDevice(const SharedDevice& device); 248 249 void findAvailableDevices(); 250 251 // Runtime version corresponding to getServerFeatureLevelFlag (in ServerFlag.h). 252 Version mRuntimeVersion; 253 254 // Holds whether platform telemetry is enabled, as indicated by getServerTelemetryEnableFlag (in 255 // ServerFlag.h). 256 bool mIsPlatformTelemetryEnabled; 257 258 // List of all the devices we discovered (including CpuDevice). 259 std::vector<std::shared_ptr<Device>> mDevices; 260 261 // We set this one to have CpuDevice only. To be used when m*CpuOnly is true. 
262 std::vector<std::shared_ptr<Device>> mDevicesCpuOnly; 263 264 // If either of these is true, we'll ignore the drivers that are 265 // on the device and run everything on the CPU. 266 bool mSetCpuOnly = false; // set by setUseCpuOnly() 267 bool mDebugNNCpuOnly = false; // derived from system property debug.nn.cpuonly 268 269 // synchronous execution 270 bool mSyncExecCpu = true; 271 bool mSyncExecRuntime = false; 272 273 static const uint32_t kPartitioningDefault = kPartitioningWithFallback; 274 uint32_t mPartitioning = kPartitioningDefault; 275 276 bool mStrictSlicing = false; 277 }; 278 279 std::vector<SharedDevice> getDevices(); 280 281 } // namespace nn 282 } // namespace android 283 284 #endif // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H 285