1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_DRIVER_SAMPLE_SAMPLE_DRIVER_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_DRIVER_SAMPLE_SAMPLE_DRIVER_H
19 
20 #include <CpuExecutor.h>
21 #include <HalBufferTracker.h>
22 #include <HalInterfaces.h>
23 #include <hwbinder/IPCThreadState.h>
24 
25 #include <memory>
26 #include <string>
27 #include <utility>
28 #include <vector>
29 
30 #include "NeuralNetworks.h"
31 
32 namespace android {
33 namespace nn {
34 namespace sample_driver {
35 
36 using hardware::MQDescriptorSync;
37 
38 // Manages the data buffer for an operand.
39 class SampleBuffer : public V1_3::IBuffer {
40    public:
SampleBuffer(std::shared_ptr<HalManagedBuffer> buffer,std::unique_ptr<HalBufferTracker::Token> token)41     SampleBuffer(std::shared_ptr<HalManagedBuffer> buffer,
42                  std::unique_ptr<HalBufferTracker::Token> token)
43         : kBuffer(std::move(buffer)), kToken(std::move(token)) {
44         CHECK(kBuffer != nullptr);
45         CHECK(kToken != nullptr);
46     }
47     hardware::Return<V1_3::ErrorStatus> copyTo(const hardware::hidl_memory& dst) override;
48     hardware::Return<V1_3::ErrorStatus> copyFrom(
49             const hardware::hidl_memory& src,
50             const hardware::hidl_vec<uint32_t>& dimensions) override;
51 
52    private:
53     const std::shared_ptr<HalManagedBuffer> kBuffer;
54     const std::unique_ptr<HalBufferTracker::Token> kToken;
55 };
56 
57 // Base class used to create sample drivers for the NN HAL.  This class
58 // provides some implementation of the more common functions.
59 //
60 // Since these drivers simulate hardware, they must run the computations
61 // on the CPU.  An actual driver would not do that.
62 class SampleDriver : public V1_3::IDevice {
63    public:
64     SampleDriver(const char* name,
65                  const IOperationResolver* operationResolver = BuiltinOperationResolver::get())
mName(name)66         : mName(name),
67           mOperationResolver(operationResolver),
68           mHalBufferTracker(HalBufferTracker::create()) {
69         android::nn::initVLogMask();
70     }
71     hardware::Return<void> getCapabilities(getCapabilities_cb cb) override;
72     hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override;
73     hardware::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override;
74     hardware::Return<void> getVersionString(getVersionString_cb cb) override;
75     hardware::Return<void> getType(getType_cb cb) override;
76     hardware::Return<void> getSupportedExtensions(getSupportedExtensions_cb) override;
77     hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
78                                                   getSupportedOperations_cb cb) override;
79     hardware::Return<void> getSupportedOperations_1_1(const V1_1::Model& model,
80                                                       getSupportedOperations_1_1_cb cb) override;
81     hardware::Return<void> getSupportedOperations_1_2(const V1_2::Model& model,
82                                                       getSupportedOperations_1_2_cb cb) override;
83     hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override;
84     hardware::Return<V1_0::ErrorStatus> prepareModel(
85             const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) override;
86     hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
87             const V1_1::Model& model, V1_1::ExecutionPreference preference,
88             const sp<V1_0::IPreparedModelCallback>& callback) override;
89     hardware::Return<V1_0::ErrorStatus> prepareModel_1_2(
90             const V1_2::Model& model, V1_1::ExecutionPreference preference,
91             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
92             const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
93             const sp<V1_2::IPreparedModelCallback>& callback) override;
94     hardware::Return<V1_3::ErrorStatus> prepareModel_1_3(
95             const V1_3::Model& model, V1_1::ExecutionPreference preference, V1_3::Priority priority,
96             const V1_3::OptionalTimePoint& deadline,
97             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
98             const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
99             const sp<V1_3::IPreparedModelCallback>& callback) override;
100     hardware::Return<V1_0::ErrorStatus> prepareModelFromCache(
101             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
102             const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
103             const sp<V1_2::IPreparedModelCallback>& callback) override;
104     hardware::Return<V1_3::ErrorStatus> prepareModelFromCache_1_3(
105             const V1_3::OptionalTimePoint& deadline,
106             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
107             const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
108             const sp<V1_3::IPreparedModelCallback>& callback) override;
109     hardware::Return<V1_0::DeviceStatus> getStatus() override;
110     hardware::Return<void> allocate(
111             const V1_3::BufferDesc& desc,
112             const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
113             const hardware::hidl_vec<V1_3::BufferRole>& inputRoles,
114             const hardware::hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) override;
115 
getExecutor()116     CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); }
getHalBufferTracker()117     const std::shared_ptr<HalBufferTracker>& getHalBufferTracker() const {
118         return mHalBufferTracker;
119     }
120 
121    protected:
122     std::string mName;
123     const IOperationResolver* mOperationResolver;
124     const std::shared_ptr<HalBufferTracker> mHalBufferTracker;
125 };
126 
127 class SamplePreparedModel : public V1_3::IPreparedModel {
128    public:
SamplePreparedModel(const V1_3::Model & model,const SampleDriver * driver,V1_1::ExecutionPreference preference,uid_t userId,V1_3::Priority priority)129     SamplePreparedModel(const V1_3::Model& model, const SampleDriver* driver,
130                         V1_1::ExecutionPreference preference, uid_t userId, V1_3::Priority priority)
131         : mModel(model),
132           mDriver(driver),
133           kPreference(preference),
134           kUserId(userId),
135           kPriority(priority) {
136         (void)kUserId;
137         (void)kPriority;
138     }
139     bool initialize();
140     hardware::Return<V1_0::ErrorStatus> execute(
141             const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override;
142     hardware::Return<V1_0::ErrorStatus> execute_1_2(
143             const V1_0::Request& request, V1_2::MeasureTiming measure,
144             const sp<V1_2::IExecutionCallback>& callback) override;
145     hardware::Return<V1_3::ErrorStatus> execute_1_3(
146             const V1_3::Request& request, V1_2::MeasureTiming measure,
147             const V1_3::OptionalTimePoint& deadline,
148             const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
149             const sp<V1_3::IExecutionCallback>& callback) override;
150     hardware::Return<void> executeSynchronously(const V1_0::Request& request,
151                                                 V1_2::MeasureTiming measure,
152                                                 executeSynchronously_cb cb) override;
153     hardware::Return<void> executeSynchronously_1_3(
154             const V1_3::Request& request, V1_2::MeasureTiming measure,
155             const V1_3::OptionalTimePoint& deadline,
156             const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
157             executeSynchronously_1_3_cb cb) override;
158     hardware::Return<void> configureExecutionBurst(
159             const sp<V1_2::IBurstCallback>& callback,
160             const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
161             const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
162             configureExecutionBurst_cb cb) override;
163     hardware::Return<void> executeFenced(const V1_3::Request& request,
164                                          const hardware::hidl_vec<hardware::hidl_handle>& wait_for,
165                                          V1_2::MeasureTiming measure,
166                                          const V1_3::OptionalTimePoint& deadline,
167                                          const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
168                                          const V1_3::OptionalTimeoutDuration& duration,
169                                          executeFenced_cb callback) override;
getModel()170     const V1_3::Model* getModel() const { return &mModel; }
171 
172    protected:
173     V1_3::Model mModel;
174     const SampleDriver* mDriver;
175     std::vector<RunTimePoolInfo> mPoolInfos;
176     const V1_1::ExecutionPreference kPreference;
177     const uid_t kUserId;
178     const V1_3::Priority kPriority;
179 };
180 
181 class SampleFencedExecutionCallback : public V1_3::IFencedExecutionCallback {
182    public:
SampleFencedExecutionCallback(V1_2::Timing timingSinceLaunch,V1_2::Timing timingAfterFence,V1_3::ErrorStatus error)183     SampleFencedExecutionCallback(V1_2::Timing timingSinceLaunch, V1_2::Timing timingAfterFence,
184                                   V1_3::ErrorStatus error)
185         : kTimingSinceLaunch(timingSinceLaunch),
186           kTimingAfterFence(timingAfterFence),
187           kErrorStatus(error) {}
getExecutionInfo(getExecutionInfo_cb callback)188     hardware::Return<void> getExecutionInfo(getExecutionInfo_cb callback) override {
189         callback(kErrorStatus, kTimingSinceLaunch, kTimingAfterFence);
190         return hardware::Void();
191     }
192 
193    private:
194     const V1_2::Timing kTimingSinceLaunch;
195     const V1_2::Timing kTimingAfterFence;
196     const V1_3::ErrorStatus kErrorStatus;
197 };
198 
199 }  // namespace sample_driver
200 }  // namespace nn
201 }  // namespace android
202 
203 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_DRIVER_SAMPLE_SAMPLE_DRIVER_H
204