1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "SampleDriver"
18 
19 #include "SampleDriver.h"
20 
21 #include "CpuExecutor.h"
22 #include "HalInterfaces.h"
23 #include "ValidateHal.h"
24 
25 #include <android-base/logging.h>
26 #include <hidl/LegacySupport.h>
27 #include <thread>
28 
29 namespace android {
30 namespace nn {
31 namespace sample_driver {
32 
getCapabilities(getCapabilities_cb cb)33 Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
34     return getCapabilities_1_1(
35         [&](ErrorStatus error, const V1_1::Capabilities& capabilities) {
36             // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
37             cb(error, convertToV1_0(capabilities));
38         });
39 }
40 
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb cb)41 Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
42                                                   getSupportedOperations_cb cb) {
43     if (!validateModel(model)) {
44         VLOG(DRIVER) << "getSupportedOperations";
45         std::vector<bool> supported;
46         cb(ErrorStatus::INVALID_ARGUMENT, supported);
47         return Void();
48     }
49     return getSupportedOperations_1_1(convertToV1_1(model), cb);
50 }
51 
prepareModel(const V1_0::Model & model,const sp<IPreparedModelCallback> & callback)52 Return<ErrorStatus> SampleDriver::prepareModel(const V1_0::Model& model,
53                                                const sp<IPreparedModelCallback>& callback) {
54     if (callback.get() == nullptr) {
55         VLOG(DRIVER) << "prepareModel";
56         LOG(ERROR) << "invalid callback passed to prepareModel";
57         return ErrorStatus::INVALID_ARGUMENT;
58     }
59     if (!validateModel(model)) {
60         VLOG(DRIVER) << "prepareModel";
61         callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
62         return ErrorStatus::INVALID_ARGUMENT;
63     }
64     return prepareModel_1_1(convertToV1_1(model), ExecutionPreference::FAST_SINGLE_ANSWER,
65                             callback);
66 }
67 
prepareModel_1_1(const V1_1::Model & model,ExecutionPreference preference,const sp<IPreparedModelCallback> & callback)68 Return<ErrorStatus> SampleDriver::prepareModel_1_1(const V1_1::Model& model,
69                                                    ExecutionPreference preference,
70                                                    const sp<IPreparedModelCallback>& callback) {
71     if (VLOG_IS_ON(DRIVER)) {
72         VLOG(DRIVER) << "prepareModel_1_1";
73         logModelToInfo(model);
74     }
75     if (callback.get() == nullptr) {
76         LOG(ERROR) << "invalid callback passed to prepareModel";
77         return ErrorStatus::INVALID_ARGUMENT;
78     }
79     if (!validateModel(model) || !validateExecutionPreference(preference)) {
80         callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
81         return ErrorStatus::INVALID_ARGUMENT;
82     }
83 
84     // TODO: make asynchronous later
85     sp<SamplePreparedModel> preparedModel = new SamplePreparedModel(model);
86     if (!preparedModel->initialize()) {
87        callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
88        return ErrorStatus::INVALID_ARGUMENT;
89     }
90     callback->notify(ErrorStatus::NONE, preparedModel);
91     return ErrorStatus::NONE;
92 }
93 
getStatus()94 Return<DeviceStatus> SampleDriver::getStatus() {
95     VLOG(DRIVER) << "getStatus()";
96     return DeviceStatus::AVAILABLE;
97 }
98 
run()99 int SampleDriver::run() {
100     android::hardware::configureRpcThreadpool(4, true);
101     if (registerAsService(mName) != android::OK) {
102         LOG(ERROR) << "Could not register service";
103         return 1;
104     }
105     android::hardware::joinRpcThreadpool();
106     LOG(ERROR) << "Service exited!";
107     return 1;
108 }
109 
initialize()110 bool SamplePreparedModel::initialize() {
111     return setRunTimePoolInfosFromHidlMemories(&mPoolInfos, mModel.pools);
112 }
113 
asyncExecute(const Request & request,const sp<IExecutionCallback> & callback)114 void SamplePreparedModel::asyncExecute(const Request& request,
115                                        const sp<IExecutionCallback>& callback) {
116     std::vector<RunTimePoolInfo> requestPoolInfos;
117     if (!setRunTimePoolInfosFromHidlMemories(&requestPoolInfos, request.pools)) {
118         callback->notify(ErrorStatus::GENERAL_FAILURE);
119         return;
120     }
121 
122     CpuExecutor executor;
123     int n = executor.run(mModel, request, mPoolInfos, requestPoolInfos);
124     VLOG(DRIVER) << "executor.run returned " << n;
125     ErrorStatus executionStatus =
126             n == ANEURALNETWORKS_NO_ERROR ? ErrorStatus::NONE : ErrorStatus::GENERAL_FAILURE;
127     Return<void> returned = callback->notify(executionStatus);
128     if (!returned.isOk()) {
129         LOG(ERROR) << " hidl callback failed to return properly: " << returned.description();
130     }
131 }
132 
execute(const Request & request,const sp<IExecutionCallback> & callback)133 Return<ErrorStatus> SamplePreparedModel::execute(const Request& request,
134                                                  const sp<IExecutionCallback>& callback) {
135     VLOG(DRIVER) << "execute(" << SHOW_IF_DEBUG(toString(request)) << ")";
136     if (callback.get() == nullptr) {
137         LOG(ERROR) << "invalid callback passed to execute";
138         return ErrorStatus::INVALID_ARGUMENT;
139     }
140     if (!validateRequest(request, mModel)) {
141         callback->notify(ErrorStatus::INVALID_ARGUMENT);
142         return ErrorStatus::INVALID_ARGUMENT;
143     }
144 
145     // This thread is intentionally detached because the sample driver service
146     // is expected to live forever.
147     std::thread([this, request, callback]{ asyncExecute(request, callback); }).detach();
148 
149     return ErrorStatus::NONE;
150 }
151 
152 } // namespace sample_driver
153 } // namespace nn
154 } // namespace android
155