/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "SampleDriver"

#include "SampleDriver.h"

#include <CpuExecutor.h>
#include <ExecutionBurstServer.h>
#include <HalBufferTracker.h>
#include <HalInterfaces.h>
#include <Tracing.h>
#include <ValidateHal.h>
#include <android-base/logging.h>
#include <android-base/properties.h>
#include <hidl/LegacySupport.h>
#include <nnapi/Types.h>
#include <nnapi/hal/1.3/Conversions.h>

#include <algorithm>
#include <chrono>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>

#include "SampleDriverUtils.h"

namespace android {
namespace nn {
namespace sample_driver {

namespace {

uint64_t microsecondsDuration(TimePoint end, TimePoint start) {
    using Microseconds = std::chrono::duration<uint64_t, std::micro>;
    return std::chrono::duration_cast<Microseconds>(end - start).count();
}

} // namespace

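// A Timing whose fields are both UINT64_MAX indicates that no timing
// measurement is available.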
static const V1_2::Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};

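// The getCapabilities* entry points for older HAL versions forward to
// getCapabilities_1_3 and down-convert the result to the requested version.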
hardware::Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
                 "SampleDriver::getCapabilities");
    return getCapabilities_1_3(
            [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
                // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
                cb(convertToV1_0(error), convertToV1_0(capabilities));
            });
}

hardware::Return<void> SampleDriver::getCapabilities_1_1(getCapabilities_1_1_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
                 "SampleDriver::getCapabilities_1_1");
    return getCapabilities_1_3(
            [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
                // TODO(dgross): Do we need to check compliantWithV1_1(capabilities)?
                cb(convertToV1_0(error), convertToV1_1(capabilities));
            });
}

hardware::Return<void> SampleDriver::getCapabilities_1_2(getCapabilities_1_2_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
                 "SampleDriver::getCapabilities_1_2");
    return getCapabilities_1_3(
            [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
                // TODO(dgross): Do we need to check compliantWithV1_2(capabilities)?
                cb(convertToV1_0(error), convertToV1_2(capabilities));
            });
}

hardware::Return<void> SampleDriver::getVersionString(getVersionString_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
                 "SampleDriver::getVersionString");
    cb(V1_0::ErrorStatus::NONE, "JUST_AN_EXAMPLE");
    return hardware::Void();
}

hardware::Return<void> SampleDriver::getType(getType_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION, "SampleDriver::getType");
    cb(V1_0::ErrorStatus::NONE, V1_2::DeviceType::CPU);
    return hardware::Void();
}

hardware::Return<void> SampleDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
                 "SampleDriver::getSupportedExtensions");
    cb(V1_0::ErrorStatus::NONE, {/* No extensions. */});
    return hardware::Void();
}

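// Likewise, the older getSupportedOperations* entry points validate the model,
// up-convert it to V1_3, and forward to getSupportedOperations_1_3.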
hardware::Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
                                                            getSupportedOperations_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
                 "SampleDriver::getSupportedOperations");
    if (!validateModel(model)) {
        VLOG(DRIVER) << "getSupportedOperations";
        cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
        return hardware::Void();
    }
    return getSupportedOperations_1_3(
            convertToV1_3(model),
            [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
                cb(convertToV1_0(status), supported);
            });
}

hardware::Return<void> SampleDriver::getSupportedOperations_1_1(const V1_1::Model& model,
                                                                getSupportedOperations_1_1_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
                 "SampleDriver::getSupportedOperations_1_1");
    if (!validateModel(model)) {
        VLOG(DRIVER) << "getSupportedOperations_1_1";
        cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
        return hardware::Void();
    }
    return getSupportedOperations_1_3(
            convertToV1_3(model),
            [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
                cb(convertToV1_0(status), supported);
            });
}

hardware::Return<void> SampleDriver::getSupportedOperations_1_2(const V1_2::Model& model,
                                                                getSupportedOperations_1_2_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
                 "SampleDriver::getSupportedOperations_1_2");
    if (!validateModel(model)) {
        VLOG(DRIVER) << "getSupportedOperations_1_2";
        cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
        return hardware::Void();
    }
    return getSupportedOperations_1_3(
            convertToV1_3(model),
            [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
                cb(convertToV1_0(status), supported);
            });
}

hardware::Return<void> SampleDriver::getNumberOfCacheFilesNeeded(
        getNumberOfCacheFilesNeeded_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
                 "SampleDriver::getNumberOfCacheFilesNeeded");
    // Return 0 for both numbers to indicate that compilation caching is not supported.
    cb(V1_0::ErrorStatus::NONE, /*numModelCache=*/0, /*numDataCache=*/0);
    return hardware::Void();
}

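// All prepareModel* entry points funnel into prepareModelBase. Versions before
// 1.3 use the default priority and no deadline, and the caching handles and
// token accepted by the 1.2+ overloads are ignored since this driver does not
// support compilation caching (see getNumberOfCacheFilesNeeded above).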
hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel(
        const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel");
    const V1_3::ErrorStatus status =
            prepareModelBase(model, this, V1_1::ExecutionPreference::FAST_SINGLE_ANSWER,
                             kDefaultPriority13, {}, callback);
    return convertToV1_0(status);
}

hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_1(
        const V1_1::Model& model, V1_1::ExecutionPreference preference,
        const sp<V1_0::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_1");
    const V1_3::ErrorStatus status =
            prepareModelBase(model, this, preference, kDefaultPriority13, {}, callback);
    return convertToV1_0(status);
}

hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_2(
        const V1_2::Model& model, V1_1::ExecutionPreference preference,
        const hardware::hidl_vec<hardware::hidl_handle>&,
        const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
        const sp<V1_2::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_2");
    const V1_3::ErrorStatus status =
            prepareModelBase(model, this, preference, kDefaultPriority13, {}, callback);
    return convertToV1_0(status);
}

hardware::Return<V1_3::ErrorStatus> SampleDriver::prepareModel_1_3(
        const V1_3::Model& model, V1_1::ExecutionPreference preference, V1_3::Priority priority,
        const V1_3::OptionalTimePoint& deadline, const hardware::hidl_vec<hardware::hidl_handle>&,
        const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
        const sp<V1_3::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_3");
    return prepareModelBase(model, this, preference, priority, deadline, callback);
}

hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModelFromCache(
        const hardware::hidl_vec<hardware::hidl_handle>&,
        const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
        const sp<V1_2::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
                 "SampleDriver::prepareModelFromCache");
    notify(callback, V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
    return V1_0::ErrorStatus::GENERAL_FAILURE;
}

hardware::Return<V1_3::ErrorStatus> SampleDriver::prepareModelFromCache_1_3(
        const V1_3::OptionalTimePoint& /*deadline*/,
        const hardware::hidl_vec<hardware::hidl_handle>&,
        const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
        const sp<V1_3::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
                 "SampleDriver::prepareModelFromCache_1_3");
    notify(callback, V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
    return V1_3::ErrorStatus::GENERAL_FAILURE;
}

hardware::Return<V1_0::DeviceStatus> SampleDriver::getStatus() {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, "SampleDriver::getStatus");
    VLOG(DRIVER) << "getStatus()";
    return V1_0::DeviceStatus::AVAILABLE;
}

// Safely downcast an IPreparedModel object to SamplePreparedModel.
// This function returns nullptr if the IPreparedModel object did not originate
// from the sample driver process.
static const SamplePreparedModel* castToSamplePreparedModel(
        const sp<V1_3::IPreparedModel>& preparedModel) {
    if (preparedModel->isRemote()) {
        return nullptr;
    } else {
        // This static_cast is safe because SamplePreparedModel is the only class that implements
        // the IPreparedModel interface in the sample driver process.
        return static_cast<const SamplePreparedModel*>(preparedModel.get());
    }
}

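// Allocates a driver-managed buffer usable in the given input/output roles.
// The descriptor is validated against the models it will be used with, and
// extension operand types as well as operands of unknown size are rejected.
// On success, the buffer is registered with the HalBufferTracker and returned
// to the caller together with its token.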
hardware::Return<void> SampleDriver::allocate(
        const V1_3::BufferDesc& desc,
        const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
        const hardware::hidl_vec<V1_3::BufferRole>& inputRoles,
        const hardware::hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) {
    constexpr uint32_t kInvalidBufferToken = 0;

    VLOG(DRIVER) << "SampleDriver::allocate";
    std::set<HalPreparedModelRole> roles;
    V1_3::Operand operand;
    auto getModel = [](const sp<V1_3::IPreparedModel>& preparedModel) -> const V1_3::Model* {
        const auto* samplePreparedModel = castToSamplePreparedModel(preparedModel);
        if (samplePreparedModel == nullptr) {
            LOG(ERROR) << "SampleDriver::allocate -- unknown remote IPreparedModel.";
            return nullptr;
        }
        return samplePreparedModel->getModel();
    };
    if (!validateMemoryDesc(desc, preparedModels, inputRoles, outputRoles, getModel, &roles,
                            &operand)) {
        LOG(ERROR) << "SampleDriver::allocate -- validation failed.";
        cb(V1_3::ErrorStatus::INVALID_ARGUMENT, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    if (isExtensionOperandType(operand.type)) {
        LOG(ERROR) << "SampleDriver::allocate -- does not support extension type.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    // TODO(xusongw): Support allocating buffers with unknown dimensions or rank.
    uint32_t size = nonExtensionOperandSizeOfData(operand.type, operand.dimensions);
    VLOG(DRIVER) << "SampleDriver::allocate -- type = " << toString(operand.type)
                 << ", dimensions = " << toString(operand.dimensions) << ", size = " << size;
    if (size == 0) {
        LOG(ERROR) << "SampleDriver::allocate -- does not support dynamic output shape.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    auto bufferWrapper =
            HalManagedBuffer::create(size, std::move(roles), uncheckedConvert(operand));
    if (bufferWrapper == nullptr) {
        LOG(ERROR) << "SampleDriver::allocate -- not enough memory.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    auto token = mHalBufferTracker->add(bufferWrapper);
    if (token == nullptr) {
        LOG(ERROR) << "SampleDriver::allocate -- HalBufferTracker returned invalid token.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    const uint32_t tokenValue = token->get();
    sp<SampleBuffer> sampleBuffer = new SampleBuffer(std::move(bufferWrapper), std::move(token));
    VLOG(DRIVER) << "SampleDriver::allocate -- successfully allocated the requested memory";
    cb(V1_3::ErrorStatus::NONE, std::move(sampleBuffer), tokenValue);
    return hardware::Void();
}

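// Copies the contents of one mapped memory pool into another pool of the same
// size, then flushes the destination so the data is visible to other readers.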
static void copyRunTimePoolInfos(const RunTimePoolInfo& srcPool, const RunTimePoolInfo& dstPool) {
    CHECK(srcPool.getBuffer() != nullptr);
    CHECK(dstPool.getBuffer() != nullptr);
    CHECK(srcPool.getSize() == dstPool.getSize());
    std::copy(srcPool.getBuffer(), srcPool.getBuffer() + srcPool.getSize(), dstPool.getBuffer());
    dstPool.flush();
}

hardware::Return<V1_3::ErrorStatus> SampleBuffer::copyTo(const hardware::hidl_memory& dst) {
    const auto dstPool = RunTimePoolInfo::createFromMemory(uncheckedConvert(dst));
    if (!dstPool.has_value()) {
        LOG(ERROR) << "SampleBuffer::copyTo -- unable to map dst memory.";
        return V1_3::ErrorStatus::GENERAL_FAILURE;
    }
    const V1_3::ErrorStatus validationStatus =
            convertToV1_3(kBuffer->validateCopyTo(dstPool->getSize()));
    if (validationStatus != V1_3::ErrorStatus::NONE) {
        return validationStatus;
    }
    const auto srcPool = kBuffer->createRunTimePoolInfo();
    copyRunTimePoolInfos(srcPool, dstPool.value());
    return V1_3::ErrorStatus::NONE;
}

static V1_3::ErrorStatus copyFromInternal(const hardware::hidl_memory& src,
                                          const hardware::hidl_vec<uint32_t>& dimensions,
                                          const std::shared_ptr<HalManagedBuffer>& bufferWrapper) {
    CHECK(bufferWrapper != nullptr);
    const auto srcPool = RunTimePoolInfo::createFromMemory(uncheckedConvert(src));
    if (!srcPool.has_value()) {
        LOG(ERROR) << "SampleBuffer::copyFrom -- unable to map src memory.";
        return V1_3::ErrorStatus::GENERAL_FAILURE;
    }
    const V1_3::ErrorStatus validationStatus =
            convertToV1_3(bufferWrapper->validateCopyFrom(dimensions, srcPool->getSize()));
    if (validationStatus != V1_3::ErrorStatus::NONE) {
        return validationStatus;
    }
    const auto dstPool = bufferWrapper->createRunTimePoolInfo();
    copyRunTimePoolInfos(srcPool.value(), dstPool);
    return V1_3::ErrorStatus::NONE;
}

hardware::Return<V1_3::ErrorStatus> SampleBuffer::copyFrom(
        const hardware::hidl_memory& src, const hardware::hidl_vec<uint32_t>& dimensions) {
    const auto status = copyFromInternal(src, dimensions, kBuffer);
    if (status == V1_3::ErrorStatus::NONE) {
        kBuffer->updateDimensions(dimensions);
        kBuffer->setInitialized(true);
    } else {
        kBuffer->setInitialized(false);
    }
    return status;
}

bool SamplePreparedModel::initialize() {
    return setRunTimePoolInfosFromCanonicalMemories(&mPoolInfos, uncheckedConvert(mModel.pools));
}

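// Maps the memory pools of a request. Plain hidlMemory pools are mapped
// directly; token pools name driver-managed buffers, which are looked up in
// the HalBufferTracker and validated against the request before use.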
static std::tuple<V1_3::ErrorStatus, std::vector<RunTimePoolInfo>,
                  std::vector<std::shared_ptr<HalManagedBuffer>>>
createRunTimePoolInfos(const V1_3::Request& request, const SampleDriver& driver,
                       const SamplePreparedModel* preparedModel) {
    std::vector<RunTimePoolInfo> requestPoolInfos;
    std::vector<std::shared_ptr<HalManagedBuffer>> bufferWrappers;
    requestPoolInfos.reserve(request.pools.size());
    bufferWrappers.reserve(request.pools.size());
    for (uint32_t i = 0; i < request.pools.size(); i++) {
        auto& pool = request.pools[i];
        switch (pool.getDiscriminator()) {
            case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory: {
                auto buffer =
                        RunTimePoolInfo::createFromMemory(uncheckedConvert(pool.hidlMemory()));
                if (!buffer.has_value()) {
                    LOG(ERROR) << "createRunTimePoolInfos -- could not map pools";
                    return {V1_3::ErrorStatus::GENERAL_FAILURE, {}, {}};
                }
                requestPoolInfos.push_back(std::move(*buffer));
                bufferWrappers.push_back(nullptr);
            } break;
            case V1_3::Request::MemoryPool::hidl_discriminator::token: {
                auto bufferWrapper = driver.getHalBufferTracker()->get(pool.token());
                if (bufferWrapper == nullptr) {
                    return {V1_3::ErrorStatus::INVALID_ARGUMENT, {}, {}};
                }
                const auto validationStatus = convertToV1_3(bufferWrapper->validateRequest(
                        i, uncheckedConvert(request), preparedModel));
                if (validationStatus != V1_3::ErrorStatus::NONE) {
                    return {validationStatus, {}, {}};
                }
                requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
                bufferWrappers.push_back(std::move(bufferWrapper));
            } break;
        }
    }
    return {V1_3::ErrorStatus::NONE, std::move(requestPoolInfos), std::move(bufferWrappers)};
}

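// After an execution, updates the metadata of driver-managed output memories:
// on success, records the output dimensions and marks the buffers initialized;
// on OUTPUT_INSUFFICIENT_SIZE caused by a device memory, returns
// GENERAL_FAILURE because the memory's dimensions must have been misspecified.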
static V1_3::ErrorStatus updateDeviceMemories(
        V1_3::ErrorStatus status, const V1_3::Request& request,
        const std::vector<std::shared_ptr<HalManagedBuffer>>& bufferWrappers,
        const hardware::hidl_vec<V1_2::OutputShape>& outputShapes) {
    if (status == V1_3::ErrorStatus::NONE) {
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
                    return V1_3::ErrorStatus::GENERAL_FAILURE;
                }
            }
        }
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                bufferWrappers[poolIndex]->setInitialized(true);
            }
        }
    } else if (status == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        // If CpuExecutor reports OUTPUT_INSUFFICIENT_SIZE on a device memory, this is because the
        // dimensions of the device memory are incorrectly specified. The driver should return
        // GENERAL_FAILURE instead in this case.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                if (!outputShapes[i].isSufficient) {
                    LOG(ERROR) << "Invalid dimensions for output " << i
                               << ": actual shape = " << toString(outputShapes[i].dimensions);
                    return V1_3::ErrorStatus::GENERAL_FAILURE;
                }
            }
        }
    }
    return V1_3::ErrorStatus::NONE;
}

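// Body of one asynchronous execution: maps the request pools, runs the model
// on the CpuExecutor (honoring any deadline and loop timeout), updates
// driver-managed memories, and reports the result through the callback,
// including timing information when requested and the execution succeeded.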
template <typename T_IExecutionCallback>
void asyncExecute(const V1_3::Request& request, V1_2::MeasureTiming measure, TimePoint driverStart,
                  const V1_3::Model& model, const SampleDriver& driver,
                  const SamplePreparedModel* preparedModel,
                  const std::vector<RunTimePoolInfo>& poolInfos, const OptionalTimePoint& deadline,
                  const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                  const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                 "SampleDriver::asyncExecute");

    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, driver, preparedModel);
    if (poolStatus != V1_3::ErrorStatus::NONE) {
        notify(callback, poolStatus, {}, kNoTiming);
        return;
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SampleDriver::asyncExecute");
    CpuExecutor executor = driver.getExecutor();
    if (loopTimeoutDuration.getDiscriminator() !=
        V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }
    TimePoint driverEnd, deviceStart, deviceEnd;
    if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(uncheckedConvert(model), uncheckedConvert(request), poolInfos,
                         requestPoolInfos);
    if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
    hardware::hidl_vec<V1_2::OutputShape> outputShapes = convertToV1_2(executor.getOutputShapes());

    // Update device memory metadata.
    const V1_3::ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != V1_3::ErrorStatus::NONE) {
        notify(callback, updateStatus, {}, kNoTiming);
        return;
    }

    if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_3::ErrorStatus::NONE) {
        driverEnd = Clock::now();
        V1_2::Timing timing = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        VLOG(DRIVER) << "SampleDriver::asyncExecute timing = " << toString(timing);
        notify(callback, executionStatus, outputShapes, timing);
    } else {
        notify(callback, executionStatus, outputShapes, kNoTiming);
    }
}

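// Common entry point for the asynchronous execute* methods: validates the
// arguments, then launches asyncExecute on a detached thread. An
// already-expired deadline is reported through the callback while the method
// itself still returns NONE.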
template <typename T_IExecutionCallback>
V1_3::ErrorStatus executeBase(const V1_3::Request& request, V1_2::MeasureTiming measure,
                              const V1_3::Model& model, const SampleDriver& driver,
                              const SamplePreparedModel* preparedModel,
                              const std::vector<RunTimePoolInfo>& poolInfos,
                              const V1_3::OptionalTimePoint& halDeadline,
                              const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                              const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
    VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";

    TimePoint driverStart;
    if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

    if (callback.get() == nullptr) {
        LOG(ERROR) << "invalid callback passed to executeBase";
        return V1_3::ErrorStatus::INVALID_ARGUMENT;
    }
    if (!validateRequest(request, model)) {
        notify(callback, V1_3::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
        return V1_3::ErrorStatus::INVALID_ARGUMENT;
    }
    const auto deadline = convert(halDeadline).value();
    if (hasDeadlinePassed(deadline)) {
        notify(callback, V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming);
        return V1_3::ErrorStatus::NONE;
    }

    // This thread is intentionally detached because the sample driver service
    // is expected to live forever.
    std::thread([&model, &driver, preparedModel, &poolInfos, request, measure, driverStart,
                 deadline, loopTimeoutDuration, callback] {
        asyncExecute(request, measure, driverStart, model, driver, preparedModel, poolInfos,
                     deadline, loopTimeoutDuration, callback);
    }).detach();

    return V1_3::ErrorStatus::NONE;
}

hardware::Return<V1_0::ErrorStatus> SamplePreparedModel::execute(
        const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) {
    const V1_3::ErrorStatus status =
            executeBase(convertToV1_3(request), V1_2::MeasureTiming::NO, mModel, *mDriver, this,
                        mPoolInfos, {}, {}, callback);
    return convertToV1_0(status);
}

hardware::Return<V1_0::ErrorStatus> SamplePreparedModel::execute_1_2(
        const V1_0::Request& request, V1_2::MeasureTiming measure,
        const sp<V1_2::IExecutionCallback>& callback) {
    const V1_3::ErrorStatus status = executeBase(convertToV1_3(request), measure, mModel, *mDriver,
                                                 this, mPoolInfos, {}, {}, callback);
    return convertToV1_0(status);
}

hardware::Return<V1_3::ErrorStatus> SamplePreparedModel::execute_1_3(
        const V1_3::Request& request, V1_2::MeasureTiming measure,
        const V1_3::OptionalTimePoint& deadline,
        const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
        const sp<V1_3::IExecutionCallback>& callback) {
    return executeBase(request, measure, mModel, *mDriver, this, mPoolInfos, deadline,
                       loopTimeoutDuration, callback);
}

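// Synchronous counterpart of executeBase/asyncExecute: performs the same
// validation, pool mapping, and CpuExecutor run, but on the calling thread,
// returning the status, output shapes, and timing directly to the caller.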
static std::tuple<V1_3::ErrorStatus, hardware::hidl_vec<V1_2::OutputShape>, V1_2::Timing>
executeSynchronouslyBase(const V1_3::Request& request, V1_2::MeasureTiming measure,
                         const V1_3::Model& model, const SampleDriver& driver,
                         const SamplePreparedModel* preparedModel,
                         const std::vector<RunTimePoolInfo>& poolInfos,
                         const V1_3::OptionalTimePoint& halDeadline,
                         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                 "SampleDriver::executeSynchronouslyBase");
    VLOG(DRIVER) << "executeSynchronouslyBase(" << SHOW_IF_DEBUG(toString(request)) << ")";

    TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
    if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

    if (!validateRequest(request, model)) {
        return {V1_3::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
    }
    const auto deadline = convert(halDeadline).value();
    if (hasDeadlinePassed(deadline)) {
        return {V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming};
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                        "SampleDriver::executeSynchronouslyBase");
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, driver, preparedModel);
    if (poolStatus != V1_3::ErrorStatus::NONE) {
        return {poolStatus, {}, kNoTiming};
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SampleDriver::executeSynchronouslyBase");
    CpuExecutor executor = driver.getExecutor();
    if (loopTimeoutDuration.getDiscriminator() !=
        V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }
    if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(uncheckedConvert(model), uncheckedConvert(request), poolInfos,
                         requestPoolInfos);
    if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
    hardware::hidl_vec<V1_2::OutputShape> outputShapes = convertToV1_2(executor.getOutputShapes());

    // Update device memory metadata.
    const V1_3::ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != V1_3::ErrorStatus::NONE) {
        return {updateStatus, {}, kNoTiming};
    }

    if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_3::ErrorStatus::NONE) {
        driverEnd = Clock::now();
        V1_2::Timing timing = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        VLOG(DRIVER) << "executeSynchronouslyBase timing = " << toString(timing);
        return {executionStatus, std::move(outputShapes), timing};
    }
    return {executionStatus, std::move(outputShapes), kNoTiming};
}

hardware::Return<void> SamplePreparedModel::executeSynchronously(const V1_0::Request& request,
                                                                 V1_2::MeasureTiming measure,
                                                                 executeSynchronously_cb cb) {
    auto [status, outputShapes, timing] = executeSynchronouslyBase(
            convertToV1_3(request), measure, mModel, *mDriver, this, mPoolInfos, {}, {});
    cb(convertToV1_0(status), std::move(outputShapes), timing);
    return hardware::Void();
}

hardware::Return<void> SamplePreparedModel::executeSynchronously_1_3(
        const V1_3::Request& request, V1_2::MeasureTiming measure,
        const V1_3::OptionalTimePoint& deadline,
        const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, executeSynchronously_1_3_cb cb) {
    auto [status, outputShapes, timing] = executeSynchronouslyBase(
            request, measure, mModel, *mDriver, this, mPoolInfos, deadline, loopTimeoutDuration);
    cb(status, std::move(outputShapes), timing);
    return hardware::Void();
}

// The sample driver executes the request synchronously: it first waits on all
// fences in waitFor, then runs the execution, and only returns once the
// execution completes. Because no work remains pending afterwards, it returns
// an empty sync fence handle rather than a real fence.
hardware::Return<void> SamplePreparedModel::executeFenced(
        const V1_3::Request& request, const hardware::hidl_vec<hardware::hidl_handle>& waitFor,
        V1_2::MeasureTiming measure, const V1_3::OptionalTimePoint& halDeadline,
        const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
        const V1_3::OptionalTimeoutDuration& duration, executeFenced_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                 "SamplePreparedModel::executeFenced");
    VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(toString(request)) << ")";

    TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
    if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

    if (!validateRequest(request, mModel, /*allowUnspecifiedOutput=*/false)) {
        cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hardware::hidl_handle(nullptr), nullptr);
        return hardware::Void();
    }
    const auto deadline = convert(halDeadline).value();
    if (hasDeadlinePassed(deadline)) {
        cb(V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, hardware::hidl_handle(nullptr), nullptr);
        return hardware::Void();
    }

    // Wait for the dependent events to signal.
    for (const auto& fenceHandle : waitFor) {
        if (!fenceHandle.getNativeHandle()) {
            cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hardware::hidl_handle(nullptr), nullptr);
            return hardware::Void();
        }
        int syncFenceFd = fenceHandle.getNativeHandle()->data[0];
        // A timeout of -1 blocks until the fence signals.
        if (syncWait(syncFenceFd, -1) != FenceState::SIGNALED) {
            LOG(ERROR) << "syncWait failed";
            cb(V1_3::ErrorStatus::GENERAL_FAILURE, hardware::hidl_handle(nullptr), nullptr);
            return hardware::Void();
        }
    }

    // Update deadline if the timeout duration is closer than the deadline.
    auto closestDeadline = deadline;
    if (duration.getDiscriminator() != V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        const auto timeoutDurationDeadline = makeDeadline(duration.nanoseconds());
        if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
            closestDeadline = timeoutDurationDeadline;
        }
    }

    TimePoint driverStartAfterFence;
    if (measure == V1_2::MeasureTiming::YES) driverStartAfterFence = Clock::now();

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                        "SamplePreparedModel::executeFenced");
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, *mDriver, this);
    if (poolStatus != V1_3::ErrorStatus::NONE) {
        cb(poolStatus, hardware::hidl_handle(nullptr), nullptr);
        return hardware::Void();
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SamplePreparedModel::executeFenced");
    CpuExecutor executor = mDriver->getExecutor();
    if (loopTimeoutDuration.getDiscriminator() !=
        V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (closestDeadline.has_value()) {
        executor.setDeadline(*closestDeadline);
    }
    if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(uncheckedConvert(mModel), uncheckedConvert(request), mPoolInfos,
                         requestPoolInfos);
    if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
    if (executionStatus != V1_3::ErrorStatus::NONE) {
        cb(executionStatus, hardware::hidl_handle(nullptr), nullptr);
        return hardware::Void();
    }

    // Set output memories to the initialized state.
    if (executionStatus == V1_3::ErrorStatus::NONE) {
        for (const auto& output : request.outputs) {
            const uint32_t poolIndex = output.location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                bufferWrappers[poolIndex]->setInitialized(true);
            }
        }
    }

    V1_2::Timing timingSinceLaunch = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
    V1_2::Timing timingAfterFence = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
    if (measure == V1_2::MeasureTiming::YES) {
        driverEnd = Clock::now();
        timingSinceLaunch = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        timingAfterFence = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStartAfterFence))};
        VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << toString(timingSinceLaunch);
        VLOG(DRIVER) << "executeFenced timingAfterFence = " << toString(timingAfterFence);
    }
    sp<SampleFencedExecutionCallback> fencedExecutionCallback =
            new SampleFencedExecutionCallback(timingSinceLaunch, timingAfterFence, executionStatus);
    cb(executionStatus, hardware::hidl_handle(nullptr), fencedExecutionCallback);
    return hardware::Void();
}

// BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
// the mapping until either (1) the memory is freed in the runtime, or (2) the
// burst object is destroyed. This allows subsequent executions operating on
// pools that have been used before to reuse the mapping instead of mapping and
// unmapping the memory on each execution.
class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
   public:
    BurstExecutorWithCache(const V1_3::Model& model, const SampleDriver* driver,
                           const std::vector<RunTimePoolInfo>& poolInfos)
        : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}

    bool isCacheEntryPresent(int32_t slot) const override {
        const auto it = mMemoryCache.find(slot);
        return (it != mMemoryCache.end()) && it->second.has_value();
    }

    void addCacheEntry(const hardware::hidl_memory& memory, int32_t slot) override {
        mMemoryCache[slot] = RunTimePoolInfo::createFromMemory(uncheckedConvert(memory));
    }

    void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }

    std::tuple<V1_0::ErrorStatus, hardware::hidl_vec<V1_2::OutputShape>, V1_2::Timing> execute(
            const V1_0::Request& request, const std::vector<int32_t>& slots,
            V1_2::MeasureTiming measure) override {
        NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                     "BurstExecutorWithCache::execute");

        TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
        if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

        // Ensure all relevant pools are valid.
        if (!std::all_of(slots.begin(), slots.end(),
                         [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
            return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
        }

        // Finish the request object (for validation).
        hardware::hidl_vec<V1_3::Request::MemoryPool> pools(slots.size());
        std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
            V1_3::Request::MemoryPool pool;
            pool.hidlMemory(convertToV1_0(mMemoryCache[slot]->getMemory()));
            return pool;
        });
        V1_3::Request fullRequest = {.inputs = request.inputs, .outputs = request.outputs};
        fullRequest.pools = std::move(pools);

        // Validate the request object against the model.
        if (!validateRequest(fullRequest, mModel)) {
            return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
        }

        // Select the relevant entries from the cache.
        std::vector<RunTimePoolInfo> requestPoolInfos;
        requestPoolInfos.reserve(slots.size());
        std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
                       [this](int32_t slot) { return *mMemoryCache[slot]; });

        // Execute.
        // Configuring the loop timeout duration is not supported. This is OK
        // because burst does not support HAL 1.3 and hence does not support
        // WHILE loops.
        CpuExecutor executor = mDriver->getExecutor();
        if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
        int n = executor.run(uncheckedConvert(mModel), uncheckedConvert(fullRequest),
                             mModelPoolInfos, requestPoolInfos);
        if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
        VLOG(DRIVER) << "executor.run returned " << n;
        V1_0::ErrorStatus executionStatus = convertToV1_0(convertResultCodeToHalErrorStatus(n));
        hardware::hidl_vec<V1_2::OutputShape> outputShapes =
                convertToV1_2(executor.getOutputShapes());
        if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_0::ErrorStatus::NONE) {
            driverEnd = Clock::now();
            V1_2::Timing timing = {
                    .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                    .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
            VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
            return std::make_tuple(executionStatus, outputShapes, timing);
        } else {
            return std::make_tuple(executionStatus, outputShapes, kNoTiming);
        }
    }

   private:
    const V1_3::Model mModel;
    const SampleDriver* const mDriver;
    const std::vector<RunTimePoolInfo> mModelPoolInfos;
    std::map<int32_t, std::optional<RunTimePoolInfo>> mMemoryCache; // cached requestPoolInfos
};

// This is the amount of time the ExecutionBurstServer should spend polling the
// FMQ to see if it has data available before it should fall back to waiting on
// the futex.
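// On debuggable builds, the window can be tuned at runtime through the
// debug.nn.sample-driver-burst-polling-window system property, e.g.
// "adb shell setprop debug.nn.sample-driver-burst-polling-window 100".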
static std::chrono::microseconds getPollingTimeWindow() {
    constexpr int32_t defaultPollingTimeWindow = 50;
#ifdef NN_DEBUGGABLE
    constexpr int32_t minPollingTimeWindow = 0;
    const int32_t selectedPollingTimeWindow =
            base::GetIntProperty("debug.nn.sample-driver-burst-polling-window",
                                 defaultPollingTimeWindow, minPollingTimeWindow);
    return std::chrono::microseconds{selectedPollingTimeWindow};
#else
    return std::chrono::microseconds{defaultPollingTimeWindow};
#endif // NN_DEBUGGABLE
}

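// Configures a burst execution session served over fast message queues. A
// LOW_POWER execution preference disables FMQ polling (a zero-length polling
// window) so the server always waits on the futex instead of spinning.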
hardware::Return<void> SamplePreparedModel::configureExecutionBurst(
        const sp<V1_2::IBurstCallback>& callback,
        const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
        const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
        configureExecutionBurst_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                 "SampleDriver::configureExecutionBurst");

    const bool preferPowerOverLatency = (kPreference == V1_1::ExecutionPreference::LOW_POWER);
    const auto pollingTimeWindow =
            (preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());

    // Alternatively, the burst could be configured via:
    //     const sp<V1_2::IBurstContext> burst =
    //             ExecutionBurstServer::create(callback, requestChannel,
    //                                          resultChannel, this,
    //                                          pollingTimeWindow);
    //
    // However, this alternative representation does not include a memory map
    // caching optimization, and adds overhead.
    const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
            std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
    const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
            callback, requestChannel, resultChannel, executorWithCache, pollingTimeWindow);

    if (burst == nullptr) {
        cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
    } else {
        cb(V1_0::ErrorStatus::NONE, burst);
    }

    return hardware::Void();
}

} // namespace sample_driver
} // namespace nn
} // namespace android