1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "SampleDriver"
18
19 #include "SampleDriver.h"
20
21 #include <android-base/logging.h>
22 #include <android-base/properties.h>
23 #include <android/sync.h>
24 #include <hidl/LegacySupport.h>
25
26 #include <algorithm>
27 #include <chrono>
28 #include <map>
29 #include <memory>
30 #include <optional>
31 #include <set>
32 #include <thread>
33 #include <tuple>
34 #include <utility>
35 #include <vector>
36
37 #include "BufferTracker.h"
38 #include "CpuExecutor.h"
39 #include "ExecutionBurstServer.h"
40 #include "HalInterfaces.h"
41 #include "SampleDriverUtils.h"
42 #include "Tracing.h"
43 #include "ValidateHal.h"
44
45 namespace android {
46 namespace nn {
47 namespace sample_driver {
48
49 namespace {
50
51 using namespace hal;
52
53 using time_point = std::chrono::steady_clock::time_point;
54
// Returns the current reading of std::chrono::steady_clock, the clock used for
// all timing measurements in this file.
std::chrono::steady_clock::time_point now() {
    using Clock = std::chrono::steady_clock;
    return Clock::now();
}
58
// Returns the signed number of whole microseconds from `start` to `end`.
// The result is negative when `end` precedes `start`; any sub-microsecond
// remainder is dropped by duration_cast.
auto microsecondsDuration(std::chrono::steady_clock::time_point end,
                          std::chrono::steady_clock::time_point start) {
    const auto elapsed = end - start;
    return std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();
}
62
63 } // namespace
64
65 static const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
66
getCapabilities(getCapabilities_cb cb)67 Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
68 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
69 "SampleDriver::getCapabilities");
70 return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
71 // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
72 cb(convertToV1_0(error), convertToV1_0(capabilities));
73 });
74 }
75
getCapabilities_1_1(getCapabilities_1_1_cb cb)76 Return<void> SampleDriver::getCapabilities_1_1(getCapabilities_1_1_cb cb) {
77 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
78 "SampleDriver::getCapabilities_1_1");
79 return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
80 // TODO(dgross): Do we need to check compliantWithV1_1(capabilities)?
81 cb(convertToV1_0(error), convertToV1_1(capabilities));
82 });
83 }
84
getCapabilities_1_2(getCapabilities_1_2_cb cb)85 Return<void> SampleDriver::getCapabilities_1_2(getCapabilities_1_2_cb cb) {
86 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
87 "SampleDriver::getCapabilities_1_2");
88 return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
89 // TODO(dgross): Do we need to check compliantWithV1_2(capabilities)?
90 cb(convertToV1_0(error), convertToV1_2(capabilities));
91 });
92 }
93
getVersionString(getVersionString_cb cb)94 Return<void> SampleDriver::getVersionString(getVersionString_cb cb) {
95 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
96 "SampleDriver::getVersionString");
97 cb(V1_0::ErrorStatus::NONE, "JUST_AN_EXAMPLE");
98 return Void();
99 }
100
getType(getType_cb cb)101 Return<void> SampleDriver::getType(getType_cb cb) {
102 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION, "SampleDriver::getType");
103 cb(V1_0::ErrorStatus::NONE, V1_2::DeviceType::CPU);
104 return Void();
105 }
106
getSupportedExtensions(getSupportedExtensions_cb cb)107 Return<void> SampleDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
108 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
109 "SampleDriver::getSupportedExtensions");
110 cb(V1_0::ErrorStatus::NONE, {/* No extensions. */});
111 return Void();
112 }
113
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb cb)114 Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
115 getSupportedOperations_cb cb) {
116 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
117 "SampleDriver::getSupportedOperations");
118 if (!validateModel(model)) {
119 VLOG(DRIVER) << "getSupportedOperations";
120 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
121 return Void();
122 }
123 return getSupportedOperations_1_3(convertToV1_3(model),
124 [&](ErrorStatus status, const hidl_vec<bool>& supported) {
125 cb(convertToV1_0(status), supported);
126 });
127 }
128
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb cb)129 Return<void> SampleDriver::getSupportedOperations_1_1(const V1_1::Model& model,
130 getSupportedOperations_1_1_cb cb) {
131 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
132 "SampleDriver::getSupportedOperations_1_1");
133 if (!validateModel(model)) {
134 VLOG(DRIVER) << "getSupportedOperations_1_1";
135 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
136 return Void();
137 }
138 return getSupportedOperations_1_3(convertToV1_3(model),
139 [&](ErrorStatus status, const hidl_vec<bool>& supported) {
140 cb(convertToV1_0(status), supported);
141 });
142 }
143
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb cb)144 Return<void> SampleDriver::getSupportedOperations_1_2(const V1_2::Model& model,
145 getSupportedOperations_1_2_cb cb) {
146 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
147 "SampleDriver::getSupportedOperations_1_2");
148 if (!validateModel(model)) {
149 VLOG(DRIVER) << "getSupportedOperations_1_2";
150 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
151 return Void();
152 }
153 return getSupportedOperations_1_3(convertToV1_3(model),
154 [&](ErrorStatus status, const hidl_vec<bool>& supported) {
155 cb(convertToV1_0(status), supported);
156 });
157 }
158
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)159 Return<void> SampleDriver::getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) {
160 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
161 "SampleDriver::getNumberOfCacheFilesNeeded");
162 // Set both numbers to be 0 for cache not supported.
163 cb(V1_0::ErrorStatus::NONE, /*numModelCache=*/0, /*numDataCache=*/0);
164 return Void();
165 }
166
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & callback)167 Return<V1_0::ErrorStatus> SampleDriver::prepareModel(
168 const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) {
169 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel");
170 const ErrorStatus status = prepareModelBase(
171 model, this, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, {}, callback);
172 return convertToV1_0(status);
173 }
174
prepareModel_1_1(const V1_1::Model & model,ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & callback)175 Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_1(
176 const V1_1::Model& model, ExecutionPreference preference,
177 const sp<V1_0::IPreparedModelCallback>& callback) {
178 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_1");
179 const ErrorStatus status =
180 prepareModelBase(model, this, preference, kDefaultPriority, {}, callback);
181 return convertToV1_0(status);
182 }
183
prepareModel_1_2(const V1_2::Model & model,ExecutionPreference preference,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_2::IPreparedModelCallback> & callback)184 Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_2(
185 const V1_2::Model& model, ExecutionPreference preference, const hidl_vec<hidl_handle>&,
186 const hidl_vec<hidl_handle>&, const CacheToken&,
187 const sp<V1_2::IPreparedModelCallback>& callback) {
188 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_2");
189 const ErrorStatus status =
190 prepareModelBase(model, this, preference, kDefaultPriority, {}, callback);
191 return convertToV1_0(status);
192 }
193
prepareModel_1_3(const V1_3::Model & model,ExecutionPreference preference,Priority priority,const OptionalTimePoint & deadline,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)194 Return<V1_3::ErrorStatus> SampleDriver::prepareModel_1_3(
195 const V1_3::Model& model, ExecutionPreference preference, Priority priority,
196 const OptionalTimePoint& deadline, const hidl_vec<hidl_handle>&,
197 const hidl_vec<hidl_handle>&, const CacheToken&,
198 const sp<V1_3::IPreparedModelCallback>& callback) {
199 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_3");
200 return prepareModelBase(model, this, preference, priority, deadline, callback);
201 }
202
prepareModelFromCache(const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_2::IPreparedModelCallback> & callback)203 Return<V1_0::ErrorStatus> SampleDriver::prepareModelFromCache(
204 const hidl_vec<hidl_handle>&, const hidl_vec<hidl_handle>&, const CacheToken&,
205 const sp<V1_2::IPreparedModelCallback>& callback) {
206 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
207 "SampleDriver::prepareModelFromCache");
208 notify(callback, ErrorStatus::GENERAL_FAILURE, nullptr);
209 return V1_0::ErrorStatus::GENERAL_FAILURE;
210 }
211
prepareModelFromCache_1_3(const OptionalTimePoint &,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)212 Return<ErrorStatus> SampleDriver::prepareModelFromCache_1_3(
213 const OptionalTimePoint& /*deadline*/, const hidl_vec<hidl_handle>&,
214 const hidl_vec<hidl_handle>&, const CacheToken&,
215 const sp<V1_3::IPreparedModelCallback>& callback) {
216 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
217 "SampleDriver::prepareModelFromCache_1_3");
218 notify(callback, ErrorStatus::GENERAL_FAILURE, nullptr);
219 return ErrorStatus::GENERAL_FAILURE;
220 }
221
getStatus()222 Return<DeviceStatus> SampleDriver::getStatus() {
223 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, "SampleDriver::getStatus");
224 VLOG(DRIVER) << "getStatus()";
225 return DeviceStatus::AVAILABLE;
226 }
227
228 // Safely downcast an IPreparedModel object to SamplePreparedModel.
229 // This function will return nullptr if the IPreparedModel object is not originated from the sample
230 // driver process.
castToSamplePreparedModel(const sp<IPreparedModel> & preparedModel)231 static const SamplePreparedModel* castToSamplePreparedModel(
232 const sp<IPreparedModel>& preparedModel) {
233 if (preparedModel->isRemote()) {
234 return nullptr;
235 } else {
236 // This static_cast is safe because SamplePreparedModel is the only class that implements
237 // the IPreparedModel interface in the sample driver process.
238 return static_cast<const SamplePreparedModel*>(preparedModel.get());
239 }
240 }
241
allocate(const V1_3::BufferDesc & desc,const hidl_vec<sp<V1_3::IPreparedModel>> & preparedModels,const hidl_vec<V1_3::BufferRole> & inputRoles,const hidl_vec<V1_3::BufferRole> & outputRoles,allocate_cb cb)242 Return<void> SampleDriver::allocate(const V1_3::BufferDesc& desc,
243 const hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
244 const hidl_vec<V1_3::BufferRole>& inputRoles,
245 const hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) {
246 constexpr uint32_t kInvalidBufferToken = 0;
247
248 VLOG(DRIVER) << "SampleDriver::allocate";
249 std::set<PreparedModelRole> roles;
250 V1_3::Operand operand;
251 auto getModel = [](const sp<V1_3::IPreparedModel>& preparedModel) -> const V1_3::Model* {
252 const auto* samplePreparedModel = castToSamplePreparedModel(preparedModel);
253 if (samplePreparedModel == nullptr) {
254 LOG(ERROR) << "SampleDriver::allocate -- unknown remote IPreparedModel.";
255 return nullptr;
256 }
257 return samplePreparedModel->getModel();
258 };
259 if (!validateMemoryDesc(desc, preparedModels, inputRoles, outputRoles, getModel, &roles,
260 &operand)) {
261 LOG(ERROR) << "SampleDriver::allocate -- validation failed.";
262 cb(ErrorStatus::INVALID_ARGUMENT, nullptr, kInvalidBufferToken);
263 return Void();
264 }
265
266 if (isExtensionOperandType(operand.type)) {
267 LOG(ERROR) << "SampleDriver::allocate -- does not support extension type.";
268 cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
269 return Void();
270 }
271
272 // TODO(xusongw): Support allocating buffers with unknown dimensions or rank.
273 uint32_t size = nonExtensionOperandSizeOfData(operand.type, operand.dimensions);
274 VLOG(DRIVER) << "SampleDriver::allocate -- type = " << toString(operand.type)
275 << ", dimensions = " << toString(operand.dimensions) << ", size = " << size;
276 if (size == 0) {
277 LOG(ERROR) << "SampleDriver::allocate -- does not support dynamic output shape.";
278 cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
279 return Void();
280 }
281
282 auto bufferWrapper = ManagedBuffer::create(size, std::move(roles), std::move(operand));
283 if (bufferWrapper == nullptr) {
284 LOG(ERROR) << "SampleDriver::allocate -- not enough memory.";
285 cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
286 return Void();
287 }
288
289 auto token = mBufferTracker->add(bufferWrapper);
290 if (token == nullptr) {
291 LOG(ERROR) << "SampleDriver::allocate -- BufferTracker returned invalid token.";
292 cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
293 return Void();
294 }
295
296 const uint32_t tokenValue = token->get();
297 sp<SampleBuffer> sampleBuffer = new SampleBuffer(std::move(bufferWrapper), std::move(token));
298 VLOG(DRIVER) << "SampleDriver::allocate -- successfully allocates the requested memory";
299 cb(ErrorStatus::NONE, std::move(sampleBuffer), tokenValue);
300 return Void();
301 }
302
run()303 int SampleDriver::run() {
304 android::hardware::configureRpcThreadpool(4, true);
305 if (registerAsService(mName) != android::OK) {
306 LOG(ERROR) << "Could not register service";
307 return 1;
308 }
309 android::hardware::joinRpcThreadpool();
310 LOG(ERROR) << "Service exited!";
311 return 1;
312 }
313
copyRunTimePoolInfos(const RunTimePoolInfo & srcPool,const RunTimePoolInfo & dstPool)314 static void copyRunTimePoolInfos(const RunTimePoolInfo& srcPool, const RunTimePoolInfo& dstPool) {
315 CHECK(srcPool.getBuffer() != nullptr);
316 CHECK(dstPool.getBuffer() != nullptr);
317 CHECK(srcPool.getSize() == dstPool.getSize());
318 std::copy(srcPool.getBuffer(), srcPool.getBuffer() + srcPool.getSize(), dstPool.getBuffer());
319 dstPool.flush();
320 }
321
copyTo(const hidl_memory & dst)322 Return<ErrorStatus> SampleBuffer::copyTo(const hidl_memory& dst) {
323 const auto dstPool = RunTimePoolInfo::createFromHidlMemory(dst);
324 if (!dstPool.has_value()) {
325 LOG(ERROR) << "SampleBuffer::copyTo -- unable to map dst memory.";
326 return ErrorStatus::GENERAL_FAILURE;
327 }
328 const ErrorStatus validationStatus = kBuffer->validateCopyTo(dstPool->getSize());
329 if (validationStatus != ErrorStatus::NONE) {
330 return validationStatus;
331 }
332 const auto srcPool = kBuffer->createRunTimePoolInfo();
333 copyRunTimePoolInfos(srcPool, dstPool.value());
334 return ErrorStatus::NONE;
335 }
336
copyFromInternal(const hidl_memory & src,const hidl_vec<uint32_t> & dimensions,const std::shared_ptr<ManagedBuffer> & bufferWrapper)337 static ErrorStatus copyFromInternal(const hidl_memory& src, const hidl_vec<uint32_t>& dimensions,
338 const std::shared_ptr<ManagedBuffer>& bufferWrapper) {
339 CHECK(bufferWrapper != nullptr);
340 const auto srcPool = RunTimePoolInfo::createFromHidlMemory(src);
341 if (!srcPool.has_value()) {
342 LOG(ERROR) << "SampleBuffer::copyFrom -- unable to map src memory.";
343 return ErrorStatus::GENERAL_FAILURE;
344 }
345 const ErrorStatus validationStatus =
346 bufferWrapper->validateCopyFrom(dimensions, srcPool->getSize());
347 if (validationStatus != ErrorStatus::NONE) {
348 return validationStatus;
349 }
350 const auto dstPool = bufferWrapper->createRunTimePoolInfo();
351 copyRunTimePoolInfos(srcPool.value(), dstPool);
352 return ErrorStatus::NONE;
353 }
354
copyFrom(const hidl_memory & src,const hidl_vec<uint32_t> & dimensions)355 Return<ErrorStatus> SampleBuffer::copyFrom(const hidl_memory& src,
356 const hidl_vec<uint32_t>& dimensions) {
357 const auto status = copyFromInternal(src, dimensions, kBuffer);
358 if (status == ErrorStatus::NONE) {
359 kBuffer->updateDimensions(dimensions);
360 kBuffer->setInitialized(true);
361 } else {
362 kBuffer->setInitialized(false);
363 }
364 return status;
365 }
366
initialize()367 bool SamplePreparedModel::initialize() {
368 return setRunTimePoolInfosFromHidlMemories(&mPoolInfos, mModel.pools);
369 }
370
371 static std::tuple<ErrorStatus, std::vector<RunTimePoolInfo>,
372 std::vector<std::shared_ptr<ManagedBuffer>>>
createRunTimePoolInfos(const Request & request,const SampleDriver & driver,const SamplePreparedModel * preparedModel)373 createRunTimePoolInfos(const Request& request, const SampleDriver& driver,
374 const SamplePreparedModel* preparedModel) {
375 std::vector<RunTimePoolInfo> requestPoolInfos;
376 std::vector<std::shared_ptr<ManagedBuffer>> bufferWrappers;
377 requestPoolInfos.reserve(request.pools.size());
378 bufferWrappers.reserve(request.pools.size());
379 for (uint32_t i = 0; i < request.pools.size(); i++) {
380 auto& pool = request.pools[i];
381 switch (pool.getDiscriminator()) {
382 case Request::MemoryPool::hidl_discriminator::hidlMemory: {
383 auto buffer = RunTimePoolInfo::createFromHidlMemory(pool.hidlMemory());
384 if (!buffer.has_value()) {
385 LOG(ERROR) << "createRuntimeMemoriesFromMemoryPools -- could not map pools";
386 return {ErrorStatus::GENERAL_FAILURE, {}, {}};
387 }
388 requestPoolInfos.push_back(std::move(*buffer));
389 bufferWrappers.push_back(nullptr);
390 } break;
391 case Request::MemoryPool::hidl_discriminator::token: {
392 auto bufferWrapper = driver.getBufferTracker()->get(pool.token());
393 if (bufferWrapper == nullptr) {
394 return {ErrorStatus::INVALID_ARGUMENT, {}, {}};
395 }
396 const auto validationStatus =
397 bufferWrapper->validateRequest(i, request, preparedModel);
398 if (validationStatus != ErrorStatus::NONE) {
399 return {validationStatus, {}, {}};
400 }
401 requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
402 bufferWrappers.push_back(std::move(bufferWrapper));
403 } break;
404 }
405 }
406 return {ErrorStatus::NONE, std::move(requestPoolInfos), std::move(bufferWrappers)};
407 }
408
updateDeviceMemories(ErrorStatus status,const Request & request,const std::vector<std::shared_ptr<ManagedBuffer>> & bufferWrappers,const hidl_vec<OutputShape> & outputShapes)409 static ErrorStatus updateDeviceMemories(
410 ErrorStatus status, const Request& request,
411 const std::vector<std::shared_ptr<ManagedBuffer>>& bufferWrappers,
412 const hidl_vec<OutputShape>& outputShapes) {
413 if (status == ErrorStatus::NONE) {
414 for (uint32_t i = 0; i < request.outputs.size(); i++) {
415 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
416 const auto& pool = request.pools[poolIndex];
417 if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
418 if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
419 return ErrorStatus::GENERAL_FAILURE;
420 }
421 }
422 }
423 for (uint32_t i = 0; i < request.outputs.size(); i++) {
424 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
425 const auto& pool = request.pools[poolIndex];
426 if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
427 bufferWrappers[poolIndex]->setInitialized(true);
428 }
429 }
430 } else if (status == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
431 // If CpuExecutor reports OUTPUT_INSUFFCIENT_SIZE on a device memory, this is because the
432 // dimensions of the device memory are incorrectly specified. The driver should return
433 // GENERAL_FAILURE instead in this case.
434 for (uint32_t i = 0; i < request.outputs.size(); i++) {
435 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
436 const auto& pool = request.pools[poolIndex];
437 if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
438 if (!outputShapes[i].isSufficient) {
439 LOG(ERROR) << "Invalid dimensions for output " << i
440 << ": actual shape = " << toString(outputShapes[i].dimensions);
441 return ErrorStatus::GENERAL_FAILURE;
442 }
443 }
444 }
445 }
446 return ErrorStatus::NONE;
447 }
448
449 template <typename T_IExecutionCallback>
asyncExecute(const Request & request,MeasureTiming measure,time_point driverStart,const Model & model,const SampleDriver & driver,const SamplePreparedModel * preparedModel,const std::vector<RunTimePoolInfo> & poolInfos,const std::optional<Deadline> & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,const sp<T_IExecutionCallback> & callback)450 void asyncExecute(const Request& request, MeasureTiming measure, time_point driverStart,
451 const Model& model, const SampleDriver& driver,
452 const SamplePreparedModel* preparedModel,
453 const std::vector<RunTimePoolInfo>& poolInfos,
454 const std::optional<Deadline>& deadline,
455 const OptionalTimeoutDuration& loopTimeoutDuration,
456 const sp<T_IExecutionCallback>& callback) {
457 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
458 "SampleDriver::asyncExecute");
459
460 const auto [poolStatus, requestPoolInfos, bufferWrappers] =
461 createRunTimePoolInfos(request, driver, preparedModel);
462 if (poolStatus != ErrorStatus::NONE) {
463 notify(callback, poolStatus, {}, kNoTiming);
464 return;
465 }
466
467 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
468 "SampleDriver::asyncExecute");
469 CpuExecutor executor = driver.getExecutor();
470 if (loopTimeoutDuration.getDiscriminator() !=
471 OptionalTimeoutDuration::hidl_discriminator::none) {
472 executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
473 }
474 if (deadline.has_value()) {
475 executor.setDeadline(*deadline);
476 }
477 time_point driverEnd, deviceStart, deviceEnd;
478 if (measure == MeasureTiming::YES) deviceStart = now();
479 int n = executor.run(model, request, poolInfos, requestPoolInfos);
480 if (measure == MeasureTiming::YES) deviceEnd = now();
481 VLOG(DRIVER) << "executor.run returned " << n;
482 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
483 hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
484
485 // Update device memory metadata.
486 const ErrorStatus updateStatus =
487 updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
488 if (updateStatus != ErrorStatus::NONE) {
489 notify(callback, updateStatus, {}, kNoTiming);
490 return;
491 }
492
493 if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
494 driverEnd = now();
495 Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
496 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
497 VLOG(DRIVER) << "SampleDriver::asyncExecute timing = " << toString(timing);
498 notify(callback, executionStatus, outputShapes, timing);
499 } else {
500 notify(callback, executionStatus, outputShapes, kNoTiming);
501 }
502 }
503
504 template <typename T_IExecutionCallback>
executeBase(const Request & request,MeasureTiming measure,const Model & model,const SampleDriver & driver,const SamplePreparedModel * preparedModel,const std::vector<RunTimePoolInfo> & poolInfos,const OptionalTimePoint & halDeadline,const OptionalTimeoutDuration & loopTimeoutDuration,const sp<T_IExecutionCallback> & callback)505 ErrorStatus executeBase(const Request& request, MeasureTiming measure, const Model& model,
506 const SampleDriver& driver, const SamplePreparedModel* preparedModel,
507 const std::vector<RunTimePoolInfo>& poolInfos,
508 const OptionalTimePoint& halDeadline,
509 const OptionalTimeoutDuration& loopTimeoutDuration,
510 const sp<T_IExecutionCallback>& callback) {
511 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
512 VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";
513
514 time_point driverStart;
515 if (measure == MeasureTiming::YES) driverStart = now();
516
517 if (callback.get() == nullptr) {
518 LOG(ERROR) << "invalid callback passed to executeBase";
519 return ErrorStatus::INVALID_ARGUMENT;
520 }
521 if (!validateRequest(request, model)) {
522 notify(callback, ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
523 return ErrorStatus::INVALID_ARGUMENT;
524 }
525 const auto deadline = makeDeadline(halDeadline);
526 if (hasDeadlinePassed(deadline)) {
527 notify(callback, ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming);
528 return ErrorStatus::NONE;
529 }
530
531 // This thread is intentionally detached because the sample driver service
532 // is expected to live forever.
533 std::thread([&model, &driver, preparedModel, &poolInfos, request, measure, driverStart,
534 deadline, loopTimeoutDuration, callback] {
535 asyncExecute(request, measure, driverStart, model, driver, preparedModel, poolInfos,
536 deadline, loopTimeoutDuration, callback);
537 }).detach();
538
539 return ErrorStatus::NONE;
540 }
541
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)542 Return<V1_0::ErrorStatus> SamplePreparedModel::execute(
543 const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) {
544 const ErrorStatus status = executeBase(convertToV1_3(request), MeasureTiming::NO, mModel,
545 *mDriver, this, mPoolInfos, {}, {}, callback);
546 return convertToV1_0(status);
547 }
548
execute_1_2(const V1_0::Request & request,MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)549 Return<V1_0::ErrorStatus> SamplePreparedModel::execute_1_2(
550 const V1_0::Request& request, MeasureTiming measure,
551 const sp<V1_2::IExecutionCallback>& callback) {
552 const ErrorStatus status = executeBase(convertToV1_3(request), measure, mModel, *mDriver, this,
553 mPoolInfos, {}, {}, callback);
554 return convertToV1_0(status);
555 }
556
execute_1_3(const V1_3::Request & request,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,const sp<V1_3::IExecutionCallback> & callback)557 Return<V1_3::ErrorStatus> SamplePreparedModel::execute_1_3(
558 const V1_3::Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
559 const OptionalTimeoutDuration& loopTimeoutDuration,
560 const sp<V1_3::IExecutionCallback>& callback) {
561 return executeBase(request, measure, mModel, *mDriver, this, mPoolInfos, deadline,
562 loopTimeoutDuration, callback);
563 }
564
executeSynchronouslyBase(const Request & request,MeasureTiming measure,const Model & model,const SampleDriver & driver,const SamplePreparedModel * preparedModel,const std::vector<RunTimePoolInfo> & poolInfos,const OptionalTimePoint & halDeadline,const OptionalTimeoutDuration & loopTimeoutDuration)565 static std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> executeSynchronouslyBase(
566 const Request& request, MeasureTiming measure, const Model& model,
567 const SampleDriver& driver, const SamplePreparedModel* preparedModel,
568 const std::vector<RunTimePoolInfo>& poolInfos, const OptionalTimePoint& halDeadline,
569 const OptionalTimeoutDuration& loopTimeoutDuration) {
570 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
571 "SampleDriver::executeSynchronouslyBase");
572 VLOG(DRIVER) << "executeSynchronouslyBase(" << SHOW_IF_DEBUG(toString(request)) << ")";
573
574 time_point driverStart, driverEnd, deviceStart, deviceEnd;
575 if (measure == MeasureTiming::YES) driverStart = now();
576
577 if (!validateRequest(request, model)) {
578 return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
579 }
580 const auto deadline = makeDeadline(halDeadline);
581 if (hasDeadlinePassed(deadline)) {
582 return {ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming};
583 }
584
585 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
586 "SampleDriver::executeSynchronouslyBase");
587 const auto [poolStatus, requestPoolInfos, bufferWrappers] =
588 createRunTimePoolInfos(request, driver, preparedModel);
589 if (poolStatus != ErrorStatus::NONE) {
590 return {poolStatus, {}, kNoTiming};
591 }
592
593 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
594 "SampleDriver::executeSynchronouslyBase");
595 CpuExecutor executor = driver.getExecutor();
596 if (loopTimeoutDuration.getDiscriminator() !=
597 OptionalTimeoutDuration::hidl_discriminator::none) {
598 executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
599 }
600 if (deadline.has_value()) {
601 executor.setDeadline(*deadline);
602 }
603 if (measure == MeasureTiming::YES) deviceStart = now();
604 int n = executor.run(model, request, poolInfos, requestPoolInfos);
605 if (measure == MeasureTiming::YES) deviceEnd = now();
606 VLOG(DRIVER) << "executor.run returned " << n;
607 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
608 hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
609
610 // Update device memory metadata.
611 const ErrorStatus updateStatus =
612 updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
613 if (updateStatus != ErrorStatus::NONE) {
614 return {updateStatus, {}, kNoTiming};
615 }
616
617 if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
618 driverEnd = now();
619 Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
620 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
621 VLOG(DRIVER) << "executeSynchronouslyBase timing = " << toString(timing);
622 return {executionStatus, std::move(outputShapes), timing};
623 }
624 return {executionStatus, std::move(outputShapes), kNoTiming};
625 }
626
executeSynchronously(const V1_0::Request & request,MeasureTiming measure,executeSynchronously_cb cb)627 Return<void> SamplePreparedModel::executeSynchronously(const V1_0::Request& request,
628 MeasureTiming measure,
629 executeSynchronously_cb cb) {
630 auto [status, outputShapes, timing] = executeSynchronouslyBase(
631 convertToV1_3(request), measure, mModel, *mDriver, this, mPoolInfos, {}, {});
632 cb(convertToV1_0(status), std::move(outputShapes), timing);
633 return Void();
634 }
635
executeSynchronously_1_3(const V1_3::Request & request,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)636 Return<void> SamplePreparedModel::executeSynchronously_1_3(
637 const V1_3::Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
638 const OptionalTimeoutDuration& loopTimeoutDuration, executeSynchronously_1_3_cb cb) {
639 auto [status, outputShapes, timing] = executeSynchronouslyBase(
640 request, measure, mModel, *mDriver, this, mPoolInfos, deadline, loopTimeoutDuration);
641 cb(status, std::move(outputShapes), timing);
642 return Void();
643 }
644
645 // The sample driver will finish the execution and then return.
executeFenced(const hal::Request & request,const hidl_vec<hidl_handle> & waitFor,MeasureTiming measure,const OptionalTimePoint & halDeadline,const OptionalTimeoutDuration & loopTimeoutDuration,const OptionalTimeoutDuration & duration,executeFenced_cb cb)646 Return<void> SamplePreparedModel::executeFenced(
647 const hal::Request& request, const hidl_vec<hidl_handle>& waitFor, MeasureTiming measure,
648 const OptionalTimePoint& halDeadline, const OptionalTimeoutDuration& loopTimeoutDuration,
649 const OptionalTimeoutDuration& duration, executeFenced_cb cb) {
650 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
651 "SamplePreparedModel::executeFenced");
652 VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(toString(request)) << ")";
653
654 time_point driverStart, driverEnd, deviceStart, deviceEnd;
655 if (measure == MeasureTiming::YES) driverStart = now();
656
657 if (!validateRequest(request, mModel, /*allowUnspecifiedOutput=*/false)) {
658 cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
659 return Void();
660 }
661 const auto deadline = makeDeadline(halDeadline);
662 if (hasDeadlinePassed(deadline)) {
663 cb(ErrorStatus::MISSED_DEADLINE_PERSISTENT, hidl_handle(nullptr), nullptr);
664 return Void();
665 }
666
667 // Wait for the dependent events to signal
668 for (const auto& fenceHandle : waitFor) {
669 if (!fenceHandle.getNativeHandle()) {
670 cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
671 return Void();
672 }
673 int syncFenceFd = fenceHandle.getNativeHandle()->data[0];
674 if (syncWait(syncFenceFd, -1) != FenceState::SIGNALED) {
675 LOG(ERROR) << "syncWait failed";
676 cb(ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
677 return Void();
678 }
679 }
680
681 // Update deadline if the timeout duration is closer than the deadline.
682 auto closestDeadline = deadline;
683 if (duration.getDiscriminator() != OptionalTimeoutDuration::hidl_discriminator::none) {
684 const auto timeoutDurationDeadline = makeDeadline(duration.nanoseconds());
685 if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
686 closestDeadline = timeoutDurationDeadline;
687 }
688 }
689
690 time_point driverStartAfterFence;
691 if (measure == MeasureTiming::YES) driverStartAfterFence = now();
692
693 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
694 "SamplePreparedModel::executeFenced");
695 const auto [poolStatus, requestPoolInfos, bufferWrappers] =
696 createRunTimePoolInfos(request, *mDriver, this);
697 if (poolStatus != ErrorStatus::NONE) {
698 cb(poolStatus, hidl_handle(nullptr), nullptr);
699 return Void();
700 }
701
702 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
703 "SamplePreparedModel::executeFenced");
704 CpuExecutor executor = mDriver->getExecutor();
705 if (loopTimeoutDuration.getDiscriminator() !=
706 OptionalTimeoutDuration::hidl_discriminator::none) {
707 executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
708 }
709 if (closestDeadline.has_value()) {
710 executor.setDeadline(*closestDeadline);
711 }
712 if (measure == MeasureTiming::YES) deviceStart = now();
713 int n = executor.run(mModel, request, mPoolInfos, requestPoolInfos);
714 if (measure == MeasureTiming::YES) deviceEnd = now();
715 VLOG(DRIVER) << "executor.run returned " << n;
716 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
717 if (executionStatus != ErrorStatus::NONE) {
718 cb(executionStatus, hidl_handle(nullptr), nullptr);
719 return Void();
720 }
721
722 // Set output memories to the initialized state.
723 if (executionStatus == ErrorStatus::NONE) {
724 for (const auto& output : request.outputs) {
725 const uint32_t poolIndex = output.location.poolIndex;
726 const auto& pool = request.pools[poolIndex];
727 if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
728 bufferWrappers[poolIndex]->setInitialized(true);
729 }
730 }
731 }
732
733 Timing timingSinceLaunch = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
734 Timing timingAfterFence = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
735 if (measure == MeasureTiming::YES) {
736 driverEnd = now();
737 timingSinceLaunch = {
738 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
739 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
740 timingAfterFence = {
741 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
742 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStartAfterFence))};
743 VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << toString(timingSinceLaunch);
744 VLOG(DRIVER) << "executeFenced timingAfterFence = " << toString(timingAfterFence);
745 }
746 sp<SampleFencedExecutionCallback> fencedExecutionCallback =
747 new SampleFencedExecutionCallback(timingSinceLaunch, timingAfterFence, executionStatus);
748 cb(executionStatus, hidl_handle(nullptr), fencedExecutionCallback);
749 return Void();
750 }
751
752 // BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
753 // the mapping until either (1) the memory is freed in the runtime, or (2) the
754 // burst object is destroyed. This allows for subsequent executions operating on
755 // pools that have been used before to reuse the mapping instead of mapping and
756 // unmapping the memory on each execution.
757 class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
758 public:
BurstExecutorWithCache(const Model & model,const SampleDriver * driver,const std::vector<RunTimePoolInfo> & poolInfos)759 BurstExecutorWithCache(const Model& model, const SampleDriver* driver,
760 const std::vector<RunTimePoolInfo>& poolInfos)
761 : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}
762
isCacheEntryPresent(int32_t slot) const763 bool isCacheEntryPresent(int32_t slot) const override {
764 const auto it = mMemoryCache.find(slot);
765 return (it != mMemoryCache.end()) && it->second.has_value();
766 }
767
addCacheEntry(const hidl_memory & memory,int32_t slot)768 void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
769 mMemoryCache[slot] = RunTimePoolInfo::createFromHidlMemory(memory);
770 }
771
removeCacheEntry(int32_t slot)772 void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
773
execute(const V1_0::Request & request,const std::vector<int32_t> & slots,MeasureTiming measure)774 std::tuple<V1_0::ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
775 const V1_0::Request& request, const std::vector<int32_t>& slots,
776 MeasureTiming measure) override {
777 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
778 "BurstExecutorWithCache::execute");
779
780 time_point driverStart, driverEnd, deviceStart, deviceEnd;
781 if (measure == MeasureTiming::YES) driverStart = now();
782
783 // ensure all relevant pools are valid
784 if (!std::all_of(slots.begin(), slots.end(),
785 [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
786 return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
787 }
788
789 // finish the request object (for validation)
790 hidl_vec<Request::MemoryPool> pools(slots.size());
791 std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
792 Request::MemoryPool pool;
793 pool.hidlMemory(mMemoryCache[slot]->getHidlMemory());
794 return pool;
795 });
796 Request fullRequest = {.inputs = request.inputs, .outputs = request.outputs};
797 fullRequest.pools = std::move(pools);
798
799 // validate request object against the model
800 if (!validateRequest(fullRequest, mModel)) {
801 return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
802 }
803
804 // select relevant entries from cache
805 std::vector<RunTimePoolInfo> requestPoolInfos;
806 requestPoolInfos.reserve(slots.size());
807 std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
808 [this](int32_t slot) { return *mMemoryCache[slot]; });
809
810 // execution
811 // Configuring the loop timeout duration is not supported. This is OK
812 // because burst does not support HAL 1.3 and hence does not support
813 // WHILE loops.
814 CpuExecutor executor = mDriver->getExecutor();
815 if (measure == MeasureTiming::YES) deviceStart = now();
816 int n = executor.run(mModel, fullRequest, mModelPoolInfos, requestPoolInfos);
817 if (measure == MeasureTiming::YES) deviceEnd = now();
818 VLOG(DRIVER) << "executor.run returned " << n;
819 V1_0::ErrorStatus executionStatus = convertToV1_0(convertResultCodeToErrorStatus(n));
820 hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
821 if (measure == MeasureTiming::YES && executionStatus == V1_0::ErrorStatus::NONE) {
822 driverEnd = now();
823 Timing timing = {
824 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
825 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
826 VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
827 return std::make_tuple(executionStatus, outputShapes, timing);
828 } else {
829 return std::make_tuple(executionStatus, outputShapes, kNoTiming);
830 }
831 }
832
833 private:
834 const Model mModel;
835 const SampleDriver* const mDriver;
836 const std::vector<RunTimePoolInfo> mModelPoolInfos;
837 std::map<int32_t, std::optional<RunTimePoolInfo>> mMemoryCache; // cached requestPoolInfos
838 };
839
// Returns how long the ExecutionBurstServer should busy-poll the FMQ for
// incoming data before it falls back to waiting on the futex.
//
// The window is 50 microseconds. In NN_DEBUGGABLE builds it can be overridden
// through the "debug.nn.sample-driver-burst-polling-window" property, with 0
// passed to GetIntProperty as the minimum accepted value.
static std::chrono::microseconds getPollingTimeWindow() {
    constexpr int32_t kDefaultWindowUs = 50;
    int32_t windowUs = kDefaultWindowUs;
#ifdef NN_DEBUGGABLE
    constexpr int32_t kMinWindowUs = 0;
    windowUs = base::GetIntProperty("debug.nn.sample-driver-burst-polling-window",
                                    kDefaultWindowUs, kMinWindowUs);
#endif  // NN_DEBUGGABLE
    return std::chrono::microseconds{windowUs};
}
855
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)856 Return<void> SamplePreparedModel::configureExecutionBurst(
857 const sp<V1_2::IBurstCallback>& callback,
858 const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
859 const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
860 configureExecutionBurst_cb cb) {
861 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
862 "SampleDriver::configureExecutionBurst");
863
864 const bool preferPowerOverLatency = (kPreference == ExecutionPreference::LOW_POWER);
865 const auto pollingTimeWindow =
866 (preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());
867
868 // Alternatively, the burst could be configured via:
869 // const sp<V1_2::IBurstContext> burst =
870 // ExecutionBurstServer::create(callback, requestChannel,
871 // resultChannel, this,
872 // pollingTimeWindow);
873 //
874 // However, this alternative representation does not include a memory map
875 // caching optimization, and adds overhead.
876 const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
877 std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
878 const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
879 callback, requestChannel, resultChannel, executorWithCache, pollingTimeWindow);
880
881 if (burst == nullptr) {
882 cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
883 } else {
884 cb(V1_0::ErrorStatus::NONE, burst);
885 }
886
887 return Void();
888 }
889
890 } // namespace sample_driver
891 } // namespace nn
892 } // namespace android
893