/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H
#define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H

#include <LegacyUtils.h>
#include <android-base/macros.h>
#include <nnapi/IBurst.h>
#include <nnapi/IDevice.h>
#include <nnapi/Types.h>

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#include "ExecutionCallback.h"
#include "Memory.h"

namespace android {
namespace nn {

// Forward declaration
class Device;
class MetaModel;
class ModelArgumentInfo;

// A unified interface for a reusable execution with cached resources.
// This object provides no thread-safety guarantee. The caller must guarantee there is at most one
// call to RuntimeExecution::compute or RuntimeExecution::computeFenced on the same RuntimeExecution
// object in flight at a time.
class RuntimeExecution {
    DISALLOW_COPY_AND_ASSIGN(RuntimeExecution);

   public:
    RuntimeExecution() = default;
    virtual ~RuntimeExecution() = default;

    virtual std::tuple<int, std::vector<OutputShape>, Timing> compute(
            const SharedBurst& burstController, const OptionalTimePoint& deadline) const = 0;

    // The returned timing information is only valid if the callback is nullptr.
    // Returns error_code, sync_fence, callback and timing.
    virtual std::tuple<int, int, ExecuteFencedInfoCallback, Timing> computeFenced(
            const std::vector<int>& waitFor, const OptionalTimePoint& deadline,
            const OptionalDuration& timeoutDurationAfterFence) const = 0;
};
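
// Illustrative usage sketch (not part of the interface): `execution` is assumed to be a
// std::shared_ptr<RuntimeExecution> obtained from RuntimePreparedModel::createReusableExecution
// below, and `burst` is a SharedBurst (which may be nullptr when no burst controller is used).
//
//     const auto [status, outputShapes, timing] = execution->compute(burst, /*deadline=*/{});
//     if (status != ANEURALNETWORKS_NO_ERROR) {
//         // Handle the failure; outputShapes may still describe the required output dimensions.
//     }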

// A unified interface for an actual driver prepared model as well as the CPU.
class RuntimePreparedModel {
    DISALLOW_COPY_AND_ASSIGN(RuntimePreparedModel);

   public:
    RuntimePreparedModel() = default;
    virtual ~RuntimePreparedModel() = default;

    virtual const Device* getDevice() const = 0;
    virtual SharedPreparedModel getInterface() const = 0;

    // Perform computation with given input/output argument info and memory pools.
    virtual std::tuple<int, std::vector<OutputShape>, Timing> execute(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, const SharedBurst& burstController,
            MeasureTiming measure, const OptionalTimePoint& deadline,
            const OptionalDuration& loopTimeoutDuration,
            const std::vector<TokenValuePair>& metaData) const = 0;

    // Perform fenced computation with given input/output argument info and memory pools.
    // The returned timing information is only valid if the callback is nullptr.
    // Returns error_code, sync_fence, callback and timing.
    virtual std::tuple<int, int, ExecuteFencedInfoCallback, Timing> executeFenced(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, const std::vector<int>& waitFor,
            MeasureTiming measure, const OptionalTimePoint& deadline,
            const OptionalDuration& loopTimeoutDuration,
            const OptionalDuration& timeoutDurationAfterFence,
            const std::vector<TokenValuePair>& metaData) const = 0;

    // Create a reusable execution with given input/output argument info and memory pools.
    virtual std::pair<int, std::shared_ptr<RuntimeExecution>> createReusableExecution(
            const std::vector<ModelArgumentInfo>& inputs,
            const std::vector<ModelArgumentInfo>& outputs,
            const std::vector<const RuntimeMemory*>& memories, MeasureTiming measure,
            const OptionalDuration& loopTimeoutDuration,
            const std::vector<TokenValuePair>& metaData) const = 0;

    virtual GeneralResult<SharedBurst> configureExecutionBurst() const = 0;

    virtual MemoryPreference getMemoryPreference() const = 0;
};
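
// Illustrative usage sketch (not part of the interface), assuming `preparedModel` is a
// std::shared_ptr<RuntimePreparedModel> and `inputs`, `outputs`, and `memories` have already been
// validated by the caller.
//
//     const auto [status, execution] = preparedModel->createReusableExecution(
//             inputs, outputs, memories, MeasureTiming::NO,
//             /*loopTimeoutDuration=*/{}, /*metaData=*/{});
//     if (status == ANEURALNETWORKS_NO_ERROR) {
//         // The resulting RuntimeExecution may be computed repeatedly, with at most one call in
//         // flight at a time (see RuntimeExecution above).
//         const auto [computeStatus, shapes, timing] =
//                 execution->compute(/*burstController=*/nullptr, /*deadline=*/{});
//     }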

using ModelFactory = std::function<Model()>;

struct CacheHandles {
    std::vector<SharedHandle> modelCache;
    std::vector<SharedHandle> dataCache;
};

using CacheDir = std::string;

struct CacheInfo {
    std::variant<CacheDir, CacheHandles> variant;
};
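
// Illustrative sketch (not normative): CacheInfo carries either a cache directory managed by the
// runtime or caching file handles already opened by the application. The directory path below is
// hypothetical, and `modelHandles` and `dataHandles` are assumed to be std::vector<SharedHandle>
// provided by the caller.
//
//     CacheInfo fromDir{CacheDir{"/data/local/tmp/nn-cache/"}};
//     CacheInfo fromHandles{CacheHandles{modelHandles, dataHandles}};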

// A unified interface for actual driver devices as well as the CPU.
class Device {
    DISALLOW_COPY_AND_ASSIGN(Device);

   public:
    Device() = default;
    virtual ~Device() = default;

    // Introspection methods returning device information.
    virtual const std::string& getName() const = 0;
    virtual const std::string& getVersionString() const = 0;
    virtual Version getFeatureLevel() const = 0;
    virtual int32_t getType() const = 0;
    virtual const std::vector<Extension>& getSupportedExtensions() const = 0;

    // See the MetaModel class in MetaModel.h for more details.
    virtual std::vector<bool> getSupportedOperations(const MetaModel& metaModel) const = 0;

    virtual const Capabilities& getCapabilities() const = 0;
    virtual Capabilities::PerformanceInfo getPerformance(OperandType type) const = 0;
    virtual Capabilities::PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const = 0;
    virtual Capabilities::PerformanceInfo getRelaxedFloat32toFloat16PerformanceTensor() const = 0;
    virtual Capabilities::PerformanceInfo getIfPerformance() const = 0;
    virtual Capabilities::PerformanceInfo getWhilePerformance() const = 0;
    virtual std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const = 0;
    virtual bool isCachingSupported() const = 0;
    virtual int wait() const = 0;

    virtual std::pair<int, std::shared_ptr<RuntimePreparedModel>> prepareModel(
            const ModelFactory& makeModel, ExecutionPreference preference, Priority priority,
            const OptionalTimePoint& deadline, const CacheInfo& cacheInfo,
            const std::optional<CacheToken>& maybeToken,
            const std::vector<TokenValuePair>& metaData,
            const std::vector<ExtensionNameAndPrefix>& extensionNameAndPrefix) const = 0;

    // The caller is responsible for making sure the MemoryDescriptor only contains
    // PreparedModels from the same Device.
    virtual std::pair<int, std::unique_ptr<RuntimeMemory>> allocate(const MemoryDescriptor& desc,
                                                                    OperandType type) const = 0;
};
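
// Illustrative usage sketch (not part of the interface), assuming `device` is a
// std::shared_ptr<Device> selected from DeviceManager::getDrivers() below and `makeModel` is a
// ModelFactory that builds the canonical Model on demand. Here compilation caching is not
// requested (empty CacheInfo and no token).
//
//     const auto [status, preparedModel] = device->prepareModel(
//             makeModel, ExecutionPreference::FAST_SINGLE_ANSWER, Priority::MEDIUM,
//             /*deadline=*/{}, CacheInfo{}, /*maybeToken=*/std::nullopt,
//             /*metaData=*/{}, /*extensionNameAndPrefix=*/{});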

// Manages the NN HAL devices.  Only one instance of this class will exist.
// Use get() to retrieve it.
class DeviceManager {
   public:
    const std::vector<std::shared_ptr<Device>>& getDrivers() const {
        if (mSetCpuOnly || mDebugNNCpuOnly) {
            return mDevicesCpuOnly;
        }
        return mDevices;
    }

    // Gets the runtime version corresponding to getServerFeatureLevelFlag (in ServerFlag.h).
    Version getRuntimeVersion() const { return mRuntimeVersion; }

    // Gets the runtime feature level corresponding to getServerFeatureLevelFlag (in ServerFlag.h).
    int64_t getRuntimeFeatureLevel() const;

    // Convert the internal Version level representation to the NDK representation.
    static int64_t versionToFeatureLevel(Version::Level versionLevel);

    // Returns whether platform telemetry is enabled.
    bool isPlatformTelemetryEnabled() const { return mIsPlatformTelemetryEnabled; }

    // For testing only:
    void setUseCpuOnly(bool useCpuOnly) { mSetCpuOnly = useCpuOnly; }
    bool getUseCpuOnly() const { return mSetCpuOnly; }

    bool syncExecCpu() const { return mSyncExecCpu; }
    bool syncExecRuntime() const { return mSyncExecRuntime; }

    // How to handle graph partitioning?
    // 0 - Don't do graph partitioning.
    // 1 - Do graph partitioning; but fall back to non-partitioned
    //     execution if there is a partitioning failure.
    // 2 - Do graph partitioning, and rely on it; there is no fallback.
    enum { kPartitioningNo = 0, kPartitioningWithFallback = 1, kPartitioningWithoutFallback = 2 };
    uint32_t getPartitioning() const { return mPartitioning; }
    static bool partitioningAllowsFallback(uint32_t partitioning) {
        return partitioning == kPartitioningWithFallback;
    }

    bool strictSlicing() const { return mStrictSlicing; }

    // Returns the singleton manager.
    static DeviceManager* get();

    // Returns the singleton Cpu device.
    static std::shared_ptr<Device> getCpuDevice();

    // The forTest_* functions below are solely intended for use by unit tests.

    // Returns all devices (ignores the cpu-only flags).
    std::vector<std::shared_ptr<Device>> forTest_getDevices() const { return mDevices; }

    // Sets the device list (does not affect cpu-only queries).
    void forTest_setDevices(std::vector<std::shared_ptr<Device>> devices) {
        mDevices = std::move(devices);
    }

    // Register a test device.
    void forTest_registerDevice(const SharedDevice& device) { registerDevice(device); }

    // Re-initialize the list of available devices.
    void forTest_reInitializeDeviceList() {
        mDevices.clear();
        mDevicesCpuOnly.clear();
        findAvailableDevices();
    }

    // Make a test device.
    static std::shared_ptr<Device> forTest_makeDriverDevice(const SharedDevice& device);

    bool forTest_isCpuDevice(const ANeuralNetworksDevice* device) const {
        return reinterpret_cast<const Device*>(device) == getCpuDevice().get();
    }

   private:
    // Builds the list of available drivers and queries their capabilities.
    DeviceManager();

    // Adds a device for the manager to use.
    void registerDevice(const SharedDevice& device);

    void findAvailableDevices();

    // Runtime version corresponding to getServerFeatureLevelFlag (in ServerFlag.h).
    Version mRuntimeVersion;

    // Holds whether platform telemetry is enabled, as indicated by getServerTelemetryEnableFlag (in
    // ServerFlag.h).
    bool mIsPlatformTelemetryEnabled;

    // List of all the devices we discovered (including CpuDevice).
    std::vector<std::shared_ptr<Device>> mDevices;
    // This list contains only CpuDevice. It is used when either of the m*CpuOnly flags is true.
    std::vector<std::shared_ptr<Device>> mDevicesCpuOnly;

    // If either of these is true, we'll ignore the drivers that are
    // on the device and run everything on the CPU.
    bool mSetCpuOnly = false;      // set by setUseCpuOnly()
    bool mDebugNNCpuOnly = false;  // derived from system property debug.nn.cpuonly

    // synchronous execution
    bool mSyncExecCpu = true;
    bool mSyncExecRuntime = false;

    static const uint32_t kPartitioningDefault = kPartitioningWithFallback;
    uint32_t mPartitioning = kPartitioningDefault;

    bool mStrictSlicing = false;
};
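
// Illustrative usage sketch (not part of the interface), showing how a caller typically consumes
// DeviceManager:
//
//     DeviceManager* manager = DeviceManager::get();
//     for (const std::shared_ptr<Device>& device : manager->getDrivers()) {
//         // Inspect each candidate, e.g. device->getName() or device->getFeatureLevel(),
//         // before partitioning work across devices.
//     }
//     const bool allowFallback =
//             DeviceManager::partitioningAllowsFallback(manager->getPartitioning());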

std::vector<SharedDevice> getDevices();

}  // namespace nn
}  // namespace android

#endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_MANAGER_H