1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "SampleDriver"
18 
19 #include "SampleDriver.h"
20 
21 #include <CpuExecutor.h>
22 #include <ExecutionBurstServer.h>
23 #include <HalBufferTracker.h>
24 #include <HalInterfaces.h>
25 #include <Tracing.h>
26 #include <ValidateHal.h>
27 #include <android-base/logging.h>
28 #include <android-base/properties.h>
29 #include <hidl/LegacySupport.h>
30 #include <nnapi/Types.h>
31 #include <nnapi/hal/1.3/Conversions.h>
32 
33 #include <algorithm>
34 #include <chrono>
35 #include <map>
36 #include <memory>
37 #include <optional>
38 #include <set>
39 #include <thread>
40 #include <tuple>
41 #include <utility>
42 #include <vector>
43 
44 #include "SampleDriverUtils.h"
45 
46 namespace android {
47 namespace nn {
48 namespace sample_driver {
49 
50 namespace {
51 
microsecondsDuration(TimePoint end,TimePoint start)52 uint64_t microsecondsDuration(TimePoint end, TimePoint start) {
53     using Microseconds = std::chrono::duration<uint64_t, std::micro>;
54     return std::chrono::duration_cast<Microseconds>(end - start).count();
55 };
56 
57 }  // namespace
58 
// Sentinel Timing value meaning "no timing information available": both fields
// set to UINT64_MAX. Reported whenever timing was not measured or an error
// occurred before/during execution.
static const V1_2::Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
60 
getCapabilities(getCapabilities_cb cb)61 hardware::Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
62     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
63                  "SampleDriver::getCapabilities");
64     return getCapabilities_1_3(
65             [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
66                 // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
67                 cb(convertToV1_0(error), convertToV1_0(capabilities));
68             });
69 }
70 
getCapabilities_1_1(getCapabilities_1_1_cb cb)71 hardware::Return<void> SampleDriver::getCapabilities_1_1(getCapabilities_1_1_cb cb) {
72     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
73                  "SampleDriver::getCapabilities_1_1");
74     return getCapabilities_1_3(
75             [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
76                 // TODO(dgross): Do we need to check compliantWithV1_1(capabilities)?
77                 cb(convertToV1_0(error), convertToV1_1(capabilities));
78             });
79 }
80 
getCapabilities_1_2(getCapabilities_1_2_cb cb)81 hardware::Return<void> SampleDriver::getCapabilities_1_2(getCapabilities_1_2_cb cb) {
82     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
83                  "SampleDriver::getCapabilities_1_2");
84     return getCapabilities_1_3(
85             [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
86                 // TODO(dgross): Do we need to check compliantWithV1_2(capabilities)?
87                 cb(convertToV1_0(error), convertToV1_2(capabilities));
88             });
89 }
90 
getVersionString(getVersionString_cb cb)91 hardware::Return<void> SampleDriver::getVersionString(getVersionString_cb cb) {
92     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
93                  "SampleDriver::getVersionString");
94     cb(V1_0::ErrorStatus::NONE, "JUST_AN_EXAMPLE");
95     return hardware::Void();
96 }
97 
getType(getType_cb cb)98 hardware::Return<void> SampleDriver::getType(getType_cb cb) {
99     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION, "SampleDriver::getType");
100     cb(V1_0::ErrorStatus::NONE, V1_2::DeviceType::CPU);
101     return hardware::Void();
102 }
103 
getSupportedExtensions(getSupportedExtensions_cb cb)104 hardware::Return<void> SampleDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
105     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
106                  "SampleDriver::getSupportedExtensions");
107     cb(V1_0::ErrorStatus::NONE, {/* No extensions. */});
108     return hardware::Void();
109 }
110 
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb cb)111 hardware::Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
112                                                             getSupportedOperations_cb cb) {
113     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
114                  "SampleDriver::getSupportedOperations");
115     if (!validateModel(model)) {
116         VLOG(DRIVER) << "getSupportedOperations";
117         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
118         return hardware::Void();
119     }
120     return getSupportedOperations_1_3(
121             convertToV1_3(model),
122             [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
123                 cb(convertToV1_0(status), supported);
124             });
125 }
126 
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb cb)127 hardware::Return<void> SampleDriver::getSupportedOperations_1_1(const V1_1::Model& model,
128                                                                 getSupportedOperations_1_1_cb cb) {
129     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
130                  "SampleDriver::getSupportedOperations_1_1");
131     if (!validateModel(model)) {
132         VLOG(DRIVER) << "getSupportedOperations_1_1";
133         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
134         return hardware::Void();
135     }
136     return getSupportedOperations_1_3(
137             convertToV1_3(model),
138             [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
139                 cb(convertToV1_0(status), supported);
140             });
141 }
142 
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb cb)143 hardware::Return<void> SampleDriver::getSupportedOperations_1_2(const V1_2::Model& model,
144                                                                 getSupportedOperations_1_2_cb cb) {
145     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
146                  "SampleDriver::getSupportedOperations_1_2");
147     if (!validateModel(model)) {
148         VLOG(DRIVER) << "getSupportedOperations_1_2";
149         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
150         return hardware::Void();
151     }
152     return getSupportedOperations_1_3(
153             convertToV1_3(model),
154             [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
155                 cb(convertToV1_0(status), supported);
156             });
157 }
158 
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)159 hardware::Return<void> SampleDriver::getNumberOfCacheFilesNeeded(
160         getNumberOfCacheFilesNeeded_cb cb) {
161     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
162                  "SampleDriver::getNumberOfCacheFilesNeeded");
163     // Set both numbers to be 0 for cache not supported.
164     cb(V1_0::ErrorStatus::NONE, /*numModelCache=*/0, /*numDataCache=*/0);
165     return hardware::Void();
166 }
167 
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & callback)168 hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel(
169         const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) {
170     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel");
171     const V1_3::ErrorStatus status =
172             prepareModelBase(model, this, V1_1::ExecutionPreference::FAST_SINGLE_ANSWER,
173                              kDefaultPriority13, {}, callback);
174     return convertToV1_0(status);
175 }
176 
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & callback)177 hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_1(
178         const V1_1::Model& model, V1_1::ExecutionPreference preference,
179         const sp<V1_0::IPreparedModelCallback>& callback) {
180     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_1");
181     const V1_3::ErrorStatus status =
182             prepareModelBase(model, this, preference, kDefaultPriority13, {}, callback);
183     return convertToV1_0(status);
184 }
185 
prepareModel_1_2(const V1_2::Model & model,V1_1::ExecutionPreference preference,const hardware::hidl_vec<hardware::hidl_handle> &,const hardware::hidl_vec<hardware::hidl_handle> &,const HalCacheToken &,const sp<V1_2::IPreparedModelCallback> & callback)186 hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_2(
187         const V1_2::Model& model, V1_1::ExecutionPreference preference,
188         const hardware::hidl_vec<hardware::hidl_handle>&,
189         const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
190         const sp<V1_2::IPreparedModelCallback>& callback) {
191     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_2");
192     const V1_3::ErrorStatus status =
193             prepareModelBase(model, this, preference, kDefaultPriority13, {}, callback);
194     return convertToV1_0(status);
195 }
196 
prepareModel_1_3(const V1_3::Model & model,V1_1::ExecutionPreference preference,V1_3::Priority priority,const V1_3::OptionalTimePoint & deadline,const hardware::hidl_vec<hardware::hidl_handle> &,const hardware::hidl_vec<hardware::hidl_handle> &,const HalCacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)197 hardware::Return<V1_3::ErrorStatus> SampleDriver::prepareModel_1_3(
198         const V1_3::Model& model, V1_1::ExecutionPreference preference, V1_3::Priority priority,
199         const V1_3::OptionalTimePoint& deadline, const hardware::hidl_vec<hardware::hidl_handle>&,
200         const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
201         const sp<V1_3::IPreparedModelCallback>& callback) {
202     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_3");
203     return prepareModelBase(model, this, preference, priority, deadline, callback);
204 }
205 
hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModelFromCache(
        const hardware::hidl_vec<hardware::hidl_handle>&,
        const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
        const sp<V1_2::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
                 "SampleDriver::prepareModelFromCache");
    // Caching is unsupported; report failure both via the callback and the
    // return value, as the HAL contract requires.
    notify(callback, V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
    return V1_0::ErrorStatus::GENERAL_FAILURE;
}
215 
prepareModelFromCache_1_3(const V1_3::OptionalTimePoint &,const hardware::hidl_vec<hardware::hidl_handle> &,const hardware::hidl_vec<hardware::hidl_handle> &,const HalCacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)216 hardware::Return<V1_3::ErrorStatus> SampleDriver::prepareModelFromCache_1_3(
217         const V1_3::OptionalTimePoint& /*deadline*/,
218         const hardware::hidl_vec<hardware::hidl_handle>&,
219         const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
220         const sp<V1_3::IPreparedModelCallback>& callback) {
221     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
222                  "SampleDriver::prepareModelFromCache_1_3");
223     notify(callback, V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
224     return V1_3::ErrorStatus::GENERAL_FAILURE;
225 }
226 
getStatus()227 hardware::Return<V1_0::DeviceStatus> SampleDriver::getStatus() {
228     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, "SampleDriver::getStatus");
229     VLOG(DRIVER) << "getStatus()";
230     return V1_0::DeviceStatus::AVAILABLE;
231 }
232 
233 // Safely downcast an IPreparedModel object to SamplePreparedModel.
234 // This function will return nullptr if the IPreparedModel object is not originated from the sample
235 // driver process.
castToSamplePreparedModel(const sp<V1_3::IPreparedModel> & preparedModel)236 static const SamplePreparedModel* castToSamplePreparedModel(
237         const sp<V1_3::IPreparedModel>& preparedModel) {
238     if (preparedModel->isRemote()) {
239         return nullptr;
240     } else {
241         // This static_cast is safe because SamplePreparedModel is the only class that implements
242         // the IPreparedModel interface in the sample driver process.
243         return static_cast<const SamplePreparedModel*>(preparedModel.get());
244     }
245 }
246 
// Allocates a driver-managed buffer for the given usage roles. On success the
// callback receives a SampleBuffer plus a token that later requests can use to
// reference this memory; on failure it receives a null buffer and token 0.
hardware::Return<void> SampleDriver::allocate(
        const V1_3::BufferDesc& desc,
        const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
        const hardware::hidl_vec<V1_3::BufferRole>& inputRoles,
        const hardware::hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) {
    // Token reported on every failure path. Presumably 0 is never handed out as
    // a real token by HalBufferTracker — TODO(review): confirm.
    constexpr uint32_t kInvalidBufferToken = 0;

    VLOG(DRIVER) << "SampleDriver::allocate";
    std::set<HalPreparedModelRole> roles;
    V1_3::Operand operand;
    // Resolves an IPreparedModel to its model, but only when it originated in
    // this process; a remote prepared model yields nullptr.
    auto getModel = [](const sp<V1_3::IPreparedModel>& preparedModel) -> const V1_3::Model* {
        const auto* samplePreparedModel = castToSamplePreparedModel(preparedModel);
        if (samplePreparedModel == nullptr) {
            LOG(ERROR) << "SampleDriver::allocate -- unknown remote IPreparedModel.";
            return nullptr;
        }
        return samplePreparedModel->getModel();
    };
    // Cross-validate the descriptor against all roles; also produces the merged
    // role set and the operand this buffer will back.
    if (!validateMemoryDesc(desc, preparedModels, inputRoles, outputRoles, getModel, &roles,
                            &operand)) {
        LOG(ERROR) << "SampleDriver::allocate -- validation failed.";
        cb(V1_3::ErrorStatus::INVALID_ARGUMENT, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    // Extension operand types have driver-defined layouts this sample does not
    // implement.
    if (isExtensionOperandType(operand.type)) {
        LOG(ERROR) << "SampleDriver::allocate -- does not support extension type.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    // TODO(xusongw): Support allocating buffers with unknown dimensions or rank.
    uint32_t size = nonExtensionOperandSizeOfData(operand.type, operand.dimensions);
    VLOG(DRIVER) << "SampleDriver::allocate -- type = " << toString(operand.type)
                 << ", dimensions = " << toString(operand.dimensions) << ", size = " << size;
    // A computed size of 0 indicates dimensions are not fully specified.
    if (size == 0) {
        LOG(ERROR) << "SampleDriver::allocate -- does not support dynamic output shape.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    auto bufferWrapper =
            HalManagedBuffer::create(size, std::move(roles), uncheckedConvert(operand));
    if (bufferWrapper == nullptr) {
        LOG(ERROR) << "SampleDriver::allocate -- not enough memory.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    // Register the buffer with the tracker, which issues the token.
    auto token = mHalBufferTracker->add(bufferWrapper);
    if (token == nullptr) {
        LOG(ERROR) << "SampleDriver::allocate -- HalBufferTracker returned invalid token.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    // Read the token value before std::move(token) transfers ownership to the
    // SampleBuffer below.
    const uint32_t tokenValue = token->get();
    sp<SampleBuffer> sampleBuffer = new SampleBuffer(std::move(bufferWrapper), std::move(token));
    VLOG(DRIVER) << "SampleDriver::allocate -- successfully allocates the requested memory";
    cb(V1_3::ErrorStatus::NONE, std::move(sampleBuffer), tokenValue);
    return hardware::Void();
}
309 
copyRunTimePoolInfos(const RunTimePoolInfo & srcPool,const RunTimePoolInfo & dstPool)310 static void copyRunTimePoolInfos(const RunTimePoolInfo& srcPool, const RunTimePoolInfo& dstPool) {
311     CHECK(srcPool.getBuffer() != nullptr);
312     CHECK(dstPool.getBuffer() != nullptr);
313     CHECK(srcPool.getSize() == dstPool.getSize());
314     std::copy(srcPool.getBuffer(), srcPool.getBuffer() + srcPool.getSize(), dstPool.getBuffer());
315     dstPool.flush();
316 }
317 
hardware::Return<V1_3::ErrorStatus> SampleBuffer::copyTo(const hardware::hidl_memory& dst) {
    // Map the destination memory; bail out if mapping fails.
    const auto dstPool = RunTimePoolInfo::createFromMemory(uncheckedConvert(dst));
    if (!dstPool.has_value()) {
        LOG(ERROR) << "SampleBuffer::copyTo -- unable to map dst memory.";
        return V1_3::ErrorStatus::GENERAL_FAILURE;
    }
    // Check that copying this managed buffer out into a region of that size is
    // permitted.
    if (const auto validationStatus = convertToV1_3(kBuffer->validateCopyTo(dstPool->getSize()));
        validationStatus != V1_3::ErrorStatus::NONE) {
        return validationStatus;
    }
    copyRunTimePoolInfos(kBuffer->createRunTimePoolInfo(), dstPool.value());
    return V1_3::ErrorStatus::NONE;
}
333 
copyFromInternal(const hardware::hidl_memory & src,const hardware::hidl_vec<uint32_t> & dimensions,const std::shared_ptr<HalManagedBuffer> & bufferWrapper)334 static V1_3::ErrorStatus copyFromInternal(const hardware::hidl_memory& src,
335                                           const hardware::hidl_vec<uint32_t>& dimensions,
336                                           const std::shared_ptr<HalManagedBuffer>& bufferWrapper) {
337     CHECK(bufferWrapper != nullptr);
338     const auto srcPool = RunTimePoolInfo::createFromMemory(uncheckedConvert(src));
339     if (!srcPool.has_value()) {
340         LOG(ERROR) << "SampleBuffer::copyFrom -- unable to map src memory.";
341         return V1_3::ErrorStatus::GENERAL_FAILURE;
342     }
343     const V1_3::ErrorStatus validationStatus =
344             convertToV1_3(bufferWrapper->validateCopyFrom(dimensions, srcPool->getSize()));
345     if (validationStatus != V1_3::ErrorStatus::NONE) {
346         return validationStatus;
347     }
348     const auto dstPool = bufferWrapper->createRunTimePoolInfo();
349     copyRunTimePoolInfos(srcPool.value(), dstPool);
350     return V1_3::ErrorStatus::NONE;
351 }
352 
hardware::Return<V1_3::ErrorStatus> SampleBuffer::copyFrom(
        const hardware::hidl_memory& src, const hardware::hidl_vec<uint32_t>& dimensions) {
    const auto status = copyFromInternal(src, dimensions, kBuffer);
    const bool succeeded = (status == V1_3::ErrorStatus::NONE);
    // Only a successful copy leaves the buffer with meaningful contents; record
    // the new dimensions in that case, and track initialization either way.
    if (succeeded) {
        kBuffer->updateDimensions(dimensions);
    }
    kBuffer->setInitialized(succeeded);
    return status;
}
364 
initialize()365 bool SamplePreparedModel::initialize() {
366     return setRunTimePoolInfosFromCanonicalMemories(&mPoolInfos, uncheckedConvert(mModel.pools));
367 }
368 
// Maps every memory pool in `request` to a RunTimePoolInfo. For plain
// hidlMemory pools the corresponding wrapper entry is nullptr; for
// driver-managed pools (referenced by token) the tracked HalManagedBuffer is
// validated against the request and returned alongside, so the caller can
// update its metadata after execution. On any failure both returned vectors
// are empty.
static std::tuple<V1_3::ErrorStatus, std::vector<RunTimePoolInfo>,
                  std::vector<std::shared_ptr<HalManagedBuffer>>>
createRunTimePoolInfos(const V1_3::Request& request, const SampleDriver& driver,
                       const SamplePreparedModel* preparedModel) {
    std::vector<RunTimePoolInfo> requestPoolInfos;
    std::vector<std::shared_ptr<HalManagedBuffer>> bufferWrappers;
    requestPoolInfos.reserve(request.pools.size());
    bufferWrappers.reserve(request.pools.size());
    for (uint32_t i = 0; i < request.pools.size(); i++) {
        auto& pool = request.pools[i];
        switch (pool.getDiscriminator()) {
            case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory: {
                // Ordinary shared memory: just map it.
                auto buffer =
                        RunTimePoolInfo::createFromMemory(uncheckedConvert(pool.hidlMemory()));
                if (!buffer.has_value()) {
                    LOG(ERROR) << "createRuntimeMemoriesFromMemoryPools -- could not map pools";
                    return {V1_3::ErrorStatus::GENERAL_FAILURE, {}, {}};
                }
                requestPoolInfos.push_back(std::move(*buffer));
                // Keep the two vectors index-aligned with request.pools.
                bufferWrappers.push_back(nullptr);
            } break;
            case V1_3::Request::MemoryPool::hidl_discriminator::token: {
                // Driver-managed memory: look the token up in the tracker.
                auto bufferWrapper = driver.getHalBufferTracker()->get(pool.token());
                if (bufferWrapper == nullptr) {
                    return {V1_3::ErrorStatus::INVALID_ARGUMENT, {}, {}};
                }
                // Verify this buffer may serve as pool `i` of this request.
                const auto validationStatus = convertToV1_3(bufferWrapper->validateRequest(
                        i, uncheckedConvert(request), preparedModel));
                if (validationStatus != V1_3::ErrorStatus::NONE) {
                    return {validationStatus, {}, {}};
                }
                requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
                bufferWrappers.push_back(std::move(bufferWrapper));
            } break;
        }
    }
    return {V1_3::ErrorStatus::NONE, std::move(requestPoolInfos), std::move(bufferWrappers)};
}
407 
// Post-execution bookkeeping for driver-managed (token-based) output memories.
// On successful execution, first updates the dimensions of every token-backed
// output, and only if all updates succeed marks those buffers initialized
// (two separate passes, so no buffer is marked initialized when any dimension
// update fails). On OUTPUT_INSUFFICIENT_SIZE, converts the error to
// GENERAL_FAILURE when a device memory was too small (see comment below).
static V1_3::ErrorStatus updateDeviceMemories(
        V1_3::ErrorStatus status, const V1_3::Request& request,
        const std::vector<std::shared_ptr<HalManagedBuffer>>& bufferWrappers,
        const hardware::hidl_vec<V1_2::OutputShape>& outputShapes) {
    if (status == V1_3::ErrorStatus::NONE) {
        // Pass 1: record the actual output dimensions on each device memory.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
                    return V1_3::ErrorStatus::GENERAL_FAILURE;
                }
            }
        }
        // Pass 2: every update succeeded; now mark the buffers initialized.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                bufferWrappers[poolIndex]->setInitialized(true);
            }
        }
    } else if (status == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        // If CpuExecutor reports OUTPUT_INSUFFCIENT_SIZE on a device memory, this is because the
        // dimensions of the device memory are incorrectly specified. The driver should return
        // GENERAL_FAILURE instead in this case.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                if (!outputShapes[i].isSufficient) {
                    LOG(ERROR) << "Invalid dimensions for output " << i
                               << ": actual shape = " << toString(outputShapes[i].dimensions);
                    return V1_3::ErrorStatus::GENERAL_FAILURE;
                }
            }
        }
    }
    return V1_3::ErrorStatus::NONE;
}
447 
// Runs one inference on the CpuExecutor and delivers status, output shapes,
// and (optionally) timing through `callback`. Invoked on the detached thread
// spawned by executeBase. `driverStart` is the time the driver received the
// request, used for timeInDriver when measuring.
template <typename T_IExecutionCallback>
void asyncExecute(const V1_3::Request& request, V1_2::MeasureTiming measure, TimePoint driverStart,
                  const V1_3::Model& model, const SampleDriver& driver,
                  const SamplePreparedModel* preparedModel,
                  const std::vector<RunTimePoolInfo>& poolInfos, const OptionalTimePoint& deadline,
                  const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                  const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                 "SampleDriver::asyncExecute");

    // Map all request pools, including driver-managed ones referenced by token.
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, driver, preparedModel);
    if (poolStatus != V1_3::ErrorStatus::NONE) {
        notify(callback, poolStatus, {}, kNoTiming);
        return;
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SampleDriver::asyncExecute");
    CpuExecutor executor = driver.getExecutor();
    // Propagate the optional loop timeout and execution deadline to the executor.
    if (loopTimeoutDuration.getDiscriminator() !=
        V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }
    TimePoint driverEnd, deviceStart, deviceEnd;
    if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(uncheckedConvert(model), uncheckedConvert(request), poolInfos,
                         requestPoolInfos);
    if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
    hardware::hidl_vec<V1_2::OutputShape> outputShapes = convertToV1_2(executor.getOutputShapes());

    // Update device memory metadata.
    const V1_3::ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != V1_3::ErrorStatus::NONE) {
        notify(callback, updateStatus, {}, kNoTiming);
        return;
    }

    // Report timing only for successful, measured executions; otherwise send
    // the "no timing" sentinel.
    if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_3::ErrorStatus::NONE) {
        driverEnd = Clock::now();
        V1_2::Timing timing = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        VLOG(DRIVER) << "SampleDriver::asyncExecute timing = " << toString(timing);
        notify(callback, executionStatus, outputShapes, timing);
    } else {
        notify(callback, executionStatus, outputShapes, kNoTiming);
    }
}
503 
// Validates the request and launches asynchronous execution on a detached
// thread. The return value is only the "launch" status; the execution result
// itself is delivered later through `callback` (see asyncExecute).
template <typename T_IExecutionCallback>
V1_3::ErrorStatus executeBase(const V1_3::Request& request, V1_2::MeasureTiming measure,
                              const V1_3::Model& model, const SampleDriver& driver,
                              const SamplePreparedModel* preparedModel,
                              const std::vector<RunTimePoolInfo>& poolInfos,
                              const V1_3::OptionalTimePoint& halDeadline,
                              const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                              const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
    VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";

    // Capture the driver-side start time as early as possible when measuring.
    TimePoint driverStart;
    if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

    if (callback.get() == nullptr) {
        LOG(ERROR) << "invalid callback passed to executeBase";
        return V1_3::ErrorStatus::INVALID_ARGUMENT;
    }
    if (!validateRequest(request, model)) {
        notify(callback, V1_3::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
        return V1_3::ErrorStatus::INVALID_ARGUMENT;
    }
    const auto deadline = convert(halDeadline).value();
    // An already-expired deadline is reported through the callback, but the
    // launch itself still returns NONE.
    if (hasDeadlinePassed(deadline)) {
        notify(callback, V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming);
        return V1_3::ErrorStatus::NONE;
    }

    // This thread is intentionally detached because the sample driver service
    // is expected to live forever. Note that model/driver/poolInfos are
    // captured by reference while request and the scalars are copied.
    std::thread([&model, &driver, preparedModel, &poolInfos, request, measure, driverStart,
                 deadline, loopTimeoutDuration, callback] {
        asyncExecute(request, measure, driverStart, model, driver, preparedModel, poolInfos,
                     deadline, loopTimeoutDuration, callback);
    }).detach();

    return V1_3::ErrorStatus::NONE;
}
542 
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)543 hardware::Return<V1_0::ErrorStatus> SamplePreparedModel::execute(
544         const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) {
545     const V1_3::ErrorStatus status =
546             executeBase(convertToV1_3(request), V1_2::MeasureTiming::NO, mModel, *mDriver, this,
547                         mPoolInfos, {}, {}, callback);
548     return convertToV1_0(status);
549 }
550 
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)551 hardware::Return<V1_0::ErrorStatus> SamplePreparedModel::execute_1_2(
552         const V1_0::Request& request, V1_2::MeasureTiming measure,
553         const sp<V1_2::IExecutionCallback>& callback) {
554     const V1_3::ErrorStatus status = executeBase(convertToV1_3(request), measure, mModel, *mDriver,
555                                                  this, mPoolInfos, {}, {}, callback);
556     return convertToV1_0(status);
557 }
558 
execute_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const sp<V1_3::IExecutionCallback> & callback)559 hardware::Return<V1_3::ErrorStatus> SamplePreparedModel::execute_1_3(
560         const V1_3::Request& request, V1_2::MeasureTiming measure,
561         const V1_3::OptionalTimePoint& deadline,
562         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
563         const sp<V1_3::IExecutionCallback>& callback) {
564     return executeBase(request, measure, mModel, *mDriver, this, mPoolInfos, deadline,
565                        loopTimeoutDuration, callback);
566 }
567 
// Runs |request| to completion on the calling thread and returns
// {status, output shapes, timing}. Shared implementation behind the
// executeSynchronously and executeSynchronously_1_3 HAL entry points.
static std::tuple<V1_3::ErrorStatus, hardware::hidl_vec<V1_2::OutputShape>, V1_2::Timing>
executeSynchronouslyBase(const V1_3::Request& request, V1_2::MeasureTiming measure,
                         const V1_3::Model& model, const SampleDriver& driver,
                         const SamplePreparedModel* preparedModel,
                         const std::vector<RunTimePoolInfo>& poolInfos,
                         const V1_3::OptionalTimePoint& halDeadline,
                         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                 "SampleDriver::executeSynchronouslyBase");
    VLOG(DRIVER) << "executeSynchronouslyBase(" << SHOW_IF_DEBUG(toString(request)) << ")";

    // Timing points are only assigned when measurement was requested; they are
    // only read in the measure==YES branch below.
    TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
    if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

    if (!validateRequest(request, model)) {
        return {V1_3::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
    }
    // NOTE(review): .value() aborts if the deadline fails to convert --
    // presumably the HAL guarantees a convertible OptionalTimePoint; confirm.
    const auto deadline = convert(halDeadline).value();
    if (hasDeadlinePassed(deadline)) {
        return {V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming};
    }

    // Map the request pools (including driver-managed device memories) before
    // switching the trace phase back to execution.
    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                        "SampleDriver::executeSynchronouslyBase");
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, driver, preparedModel);
    if (poolStatus != V1_3::ErrorStatus::NONE) {
        return {poolStatus, {}, kNoTiming};
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SampleDriver::executeSynchronouslyBase");
    // Each execution gets a fresh executor configured with the optional WHILE
    // loop timeout and the effective deadline.
    CpuExecutor executor = driver.getExecutor();
    if (loopTimeoutDuration.getDiscriminator() !=
        V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }
    if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(uncheckedConvert(model), uncheckedConvert(request), poolInfos,
                         requestPoolInfos);
    if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
    // Output shapes are reported even on failure so callers can grow buffers.
    hardware::hidl_vec<V1_2::OutputShape> outputShapes = convertToV1_2(executor.getOutputShapes());

    // Update device memory metadata.
    const V1_3::ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != V1_3::ErrorStatus::NONE) {
        return {updateStatus, {}, kNoTiming};
    }

    // Real timing is reported only for successful, measured executions;
    // everything else gets the UINT64_MAX "no timing" sentinel.
    if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_3::ErrorStatus::NONE) {
        driverEnd = Clock::now();
        V1_2::Timing timing = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        VLOG(DRIVER) << "executeSynchronouslyBase timing = " << toString(timing);
        return {executionStatus, std::move(outputShapes), timing};
    }
    return {executionStatus, std::move(outputShapes), kNoTiming};
}
633 
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measure,executeSynchronously_cb cb)634 hardware::Return<void> SamplePreparedModel::executeSynchronously(const V1_0::Request& request,
635                                                                  V1_2::MeasureTiming measure,
636                                                                  executeSynchronously_cb cb) {
637     auto [status, outputShapes, timing] = executeSynchronouslyBase(
638             convertToV1_3(request), measure, mModel, *mDriver, this, mPoolInfos, {}, {});
639     cb(convertToV1_0(status), std::move(outputShapes), timing);
640     return hardware::Void();
641 }
642 
executeSynchronously_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)643 hardware::Return<void> SamplePreparedModel::executeSynchronously_1_3(
644         const V1_3::Request& request, V1_2::MeasureTiming measure,
645         const V1_3::OptionalTimePoint& deadline,
646         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, executeSynchronously_1_3_cb cb) {
647     auto [status, outputShapes, timing] = executeSynchronouslyBase(
648             request, measure, mModel, *mDriver, this, mPoolInfos, deadline, loopTimeoutDuration);
649     cb(status, std::move(outputShapes), timing);
650     return hardware::Void();
651 }
652 
653 // The sample driver will finish the execution and then return.
executeFenced(const V1_3::Request & request,const hardware::hidl_vec<hardware::hidl_handle> & waitFor,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & halDeadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const V1_3::OptionalTimeoutDuration & duration,executeFenced_cb cb)654 hardware::Return<void> SamplePreparedModel::executeFenced(
655         const V1_3::Request& request, const hardware::hidl_vec<hardware::hidl_handle>& waitFor,
656         V1_2::MeasureTiming measure, const V1_3::OptionalTimePoint& halDeadline,
657         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
658         const V1_3::OptionalTimeoutDuration& duration, executeFenced_cb cb) {
659     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
660                  "SamplePreparedModel::executeFenced");
661     VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(toString(request)) << ")";
662 
663     TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
664     if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();
665 
666     if (!validateRequest(request, mModel, /*allowUnspecifiedOutput=*/false)) {
667         cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hardware::hidl_handle(nullptr), nullptr);
668         return hardware::Void();
669     }
670     const auto deadline = convert(halDeadline).value();
671     if (hasDeadlinePassed(deadline)) {
672         cb(V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, hardware::hidl_handle(nullptr), nullptr);
673         return hardware::Void();
674     }
675 
676     // Wait for the dependent events to signal
677     for (const auto& fenceHandle : waitFor) {
678         if (!fenceHandle.getNativeHandle()) {
679             cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hardware::hidl_handle(nullptr), nullptr);
680             return hardware::Void();
681         }
682         int syncFenceFd = fenceHandle.getNativeHandle()->data[0];
683         if (syncWait(syncFenceFd, -1) != FenceState::SIGNALED) {
684             LOG(ERROR) << "syncWait failed";
685             cb(V1_3::ErrorStatus::GENERAL_FAILURE, hardware::hidl_handle(nullptr), nullptr);
686             return hardware::Void();
687         }
688     }
689 
690     // Update deadline if the timeout duration is closer than the deadline.
691     auto closestDeadline = deadline;
692     if (duration.getDiscriminator() != V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
693         const auto timeoutDurationDeadline = makeDeadline(duration.nanoseconds());
694         if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
695             closestDeadline = timeoutDurationDeadline;
696         }
697     }
698 
699     TimePoint driverStartAfterFence;
700     if (measure == V1_2::MeasureTiming::YES) driverStartAfterFence = Clock::now();
701 
702     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
703                         "SamplePreparedModel::executeFenced");
704     const auto [poolStatus, requestPoolInfos, bufferWrappers] =
705             createRunTimePoolInfos(request, *mDriver, this);
706     if (poolStatus != V1_3::ErrorStatus::NONE) {
707         cb(poolStatus, hardware::hidl_handle(nullptr), nullptr);
708         return hardware::Void();
709     }
710 
711     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
712                         "SamplePreparedModel::executeFenced");
713     CpuExecutor executor = mDriver->getExecutor();
714     if (loopTimeoutDuration.getDiscriminator() !=
715         V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
716         executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
717     }
718     if (closestDeadline.has_value()) {
719         executor.setDeadline(*closestDeadline);
720     }
721     if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
722     int n = executor.run(uncheckedConvert(mModel), uncheckedConvert(request), mPoolInfos,
723                          requestPoolInfos);
724     if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
725     VLOG(DRIVER) << "executor.run returned " << n;
726     V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
727     if (executionStatus != V1_3::ErrorStatus::NONE) {
728         cb(executionStatus, hardware::hidl_handle(nullptr), nullptr);
729         return hardware::Void();
730     }
731 
732     // Set output memories to the initialized state.
733     if (executionStatus == V1_3::ErrorStatus::NONE) {
734         for (const auto& output : request.outputs) {
735             const uint32_t poolIndex = output.location.poolIndex;
736             const auto& pool = request.pools[poolIndex];
737             if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
738                 bufferWrappers[poolIndex]->setInitialized(true);
739             }
740         }
741     }
742 
743     V1_2::Timing timingSinceLaunch = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
744     V1_2::Timing timingAfterFence = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
745     if (measure == V1_2::MeasureTiming::YES) {
746         driverEnd = Clock::now();
747         timingSinceLaunch = {
748                 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
749                 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
750         timingAfterFence = {
751                 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
752                 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStartAfterFence))};
753         VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << toString(timingSinceLaunch);
754         VLOG(DRIVER) << "executeFenced timingAfterFence = " << toString(timingAfterFence);
755     }
756     sp<SampleFencedExecutionCallback> fencedExecutionCallback =
757             new SampleFencedExecutionCallback(timingSinceLaunch, timingAfterFence, executionStatus);
758     cb(executionStatus, hardware::hidl_handle(nullptr), fencedExecutionCallback);
759     return hardware::Void();
760 }
761 
762 // BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
763 // the mapping until either (1) the memory is freed in the runtime, or (2) the
764 // burst object is destroyed. This allows for subsequent executions operating on
765 // pools that have been used before to reuse the mapping instead of mapping and
766 // unmapping the memory on each execution.
class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
   public:
    // Keeps copies of the model, driver pointer, and model pool infos for the
    // lifetime of the burst object.
    BurstExecutorWithCache(const V1_3::Model& model, const SampleDriver* driver,
                           const std::vector<RunTimePoolInfo>& poolInfos)
        : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}

    // Returns true iff |slot| holds a successfully mapped memory pool.
    bool isCacheEntryPresent(int32_t slot) const override {
        const auto it = mMemoryCache.find(slot);
        return (it != mMemoryCache.end()) && it->second.has_value();
    }

    // Maps |memory| and caches the mapping under |slot|. The cached optional
    // is empty if the mapping failed.
    void addCacheEntry(const hardware::hidl_memory& memory, int32_t slot) override {
        mMemoryCache[slot] = RunTimePoolInfo::createFromMemory(uncheckedConvert(memory));
    }

    // Drops the cached mapping for |slot| (unmapping happens via RAII).
    void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }

    // Executes |request|, whose pools are referenced indirectly through cache
    // |slots|. Returns {status, output shapes, timing}.
    std::tuple<V1_0::ErrorStatus, hardware::hidl_vec<V1_2::OutputShape>, V1_2::Timing> execute(
            const V1_0::Request& request, const std::vector<int32_t>& slots,
            V1_2::MeasureTiming measure) override {
        NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                     "BurstExecutorWithCache::execute");

        // Timing points are only assigned/read when measurement is requested.
        TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
        if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

        // ensure all relevant pools are valid
        if (!std::all_of(slots.begin(), slots.end(),
                         [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
            return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
        }

        // finish the request object (for validation)
        // Rebuild a full V1_3::Request by materializing each slot's cached
        // memory as a hidlMemory pool.
        hardware::hidl_vec<V1_3::Request::MemoryPool> pools(slots.size());
        std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
            V1_3::Request::MemoryPool pool;
            pool.hidlMemory(convertToV1_0(mMemoryCache[slot]->getMemory()));
            return pool;
        });
        V1_3::Request fullRequest = {.inputs = request.inputs, .outputs = request.outputs};
        fullRequest.pools = std::move(pools);

        // validate request object against the model
        if (!validateRequest(fullRequest, mModel)) {
            return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
        }

        // select relevant entries from cache
        std::vector<RunTimePoolInfo> requestPoolInfos;
        requestPoolInfos.reserve(slots.size());
        std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
                       [this](int32_t slot) { return *mMemoryCache[slot]; });

        // execution
        // Configuring the loop timeout duration is not supported. This is OK
        // because burst does not support HAL 1.3 and hence does not support
        // WHILE loops.
        CpuExecutor executor = mDriver->getExecutor();
        if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
        int n = executor.run(uncheckedConvert(mModel), uncheckedConvert(fullRequest),
                             mModelPoolInfos, requestPoolInfos);
        if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
        VLOG(DRIVER) << "executor.run returned " << n;
        V1_0::ErrorStatus executionStatus = convertToV1_0(convertResultCodeToHalErrorStatus(n));
        hardware::hidl_vec<V1_2::OutputShape> outputShapes =
                convertToV1_2(executor.getOutputShapes());
        // Real timing only for successful, measured runs; otherwise the
        // UINT64_MAX "no timing" sentinel.
        if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_0::ErrorStatus::NONE) {
            driverEnd = Clock::now();
            V1_2::Timing timing = {
                    .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                    .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
            VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
            return std::make_tuple(executionStatus, outputShapes, timing);
        } else {
            return std::make_tuple(executionStatus, outputShapes, kNoTiming);
        }
    }

   private:
    const V1_3::Model mModel;
    const SampleDriver* const mDriver;
    std::map<int32_t, std::optional<RunTimePoolInfo>> mMemoryCache;  // cached requestPoolInfos
};
851 
852 // This is the amount of time the ExecutionBurstServer should spend polling the
853 // FMQ to see if it has data available before it should fall back to waiting on
854 // the futex.
// Returns how long the burst server should poll the FMQ before falling back
// to blocking on the futex.
static std::chrono::microseconds getPollingTimeWindow() {
    constexpr int32_t kDefaultPollingTimeWindow = 50;
#ifdef NN_DEBUGGABLE
    // On debuggable builds the window may be overridden via a system property
    // (clamped so it can never go below zero).
    constexpr int32_t kMinPollingTimeWindow = 0;
    const int32_t window =
            base::GetIntProperty("debug.nn.sample-driver-burst-polling-window",
                                 kDefaultPollingTimeWindow, kMinPollingTimeWindow);
    return std::chrono::microseconds{window};
#else
    return std::chrono::microseconds{kDefaultPollingTimeWindow};
#endif  // NN_DEBUGGABLE
}
867 
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)868 hardware::Return<void> SamplePreparedModel::configureExecutionBurst(
869         const sp<V1_2::IBurstCallback>& callback,
870         const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
871         const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
872         configureExecutionBurst_cb cb) {
873     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
874                  "SampleDriver::configureExecutionBurst");
875 
876     const bool preferPowerOverLatency = (kPreference == V1_1::ExecutionPreference::LOW_POWER);
877     const auto pollingTimeWindow =
878             (preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());
879 
880     // Alternatively, the burst could be configured via:
881     // const sp<V1_2::IBurstContext> burst =
882     //         ExecutionBurstServer::create(callback, requestChannel,
883     //                                      resultChannel, this,
884     //                                      pollingTimeWindow);
885     //
886     // However, this alternative representation does not include a memory map
887     // caching optimization, and adds overhead.
888     const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
889             std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
890     const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
891             callback, requestChannel, resultChannel, executorWithCache, pollingTimeWindow);
892 
893     if (burst == nullptr) {
894         cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
895     } else {
896         cb(V1_0::ErrorStatus::NONE, burst);
897     }
898 
899     return hardware::Void();
900 }
901 
902 }  // namespace sample_driver
903 }  // namespace nn
904 }  // namespace android
905