/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "VersionedInterfaces"

#include "VersionedInterfaces.h"

#include <android-base/logging.h>
#include <android-base/properties.h>
#include <android-base/scopeguard.h>
#include <android-base/thread_annotations.h>
#include <android/sync.h>
#include <cutils/native_handle.h>

#include <algorithm>
#include <chrono>
#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "Callbacks.h"
#include "ExecutionBurstController.h"
#include "MetaModel.h"
#include "Tracing.h"
#include "Utils.h"

/*
 * Some notes about HIDL interface objects and lifetimes across processes:
 *
 * All HIDL interface objects inherit from IBase, which itself inherits from
 * ::android::RefBase. As such, all HIDL interface objects are reference counted
 * and must be owned through ::android::sp (or referenced through ::android::wp).
 * Allocating RefBase objects on the stack will log errors and may result in
 * crashes, and deleting a RefBase object through another means (e.g., "delete",
 * "free", or RAII-cleanup through std::unique_ptr or some equivalent) will
 * result in double-free and/or use-after-free undefined behavior.
 *
 * HIDL/Binder manages the reference count of HIDL interface objects
 * automatically across processes. If a process that references (but did not
 * create) the HIDL interface object dies, HIDL/Binder ensures any reference
 * count it held is properly released. (Caveat: it might be possible that
 * HIDL/Binder behaves strangely with ::android::wp references.)
 *
 * If the process which created the HIDL interface object dies, any call on this
 * object from another process will result in a HIDL transport error with the
 * code DEAD_OBJECT.
 */
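
/*
 * As an illustrative sketch of the ownership rules above (IFoo is a
 * hypothetical HIDL interface, not part of this file):
 *
 *     sp<IFoo> foo = IFoo::getService();  // correct: owned through sp<>
 *     foo->someMethod();
 *     foo = nullptr;                      // correct: releases the reference
 *
 *     // IFoo foo;          // wrong: allocates a RefBase object on the stack
 *     // delete foo.get();  // wrong: double-free when the last sp<> releases
 *     //                    // its reference
 */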
65
66 /*
67 * Some notes about asynchronous calls across HIDL:
68 *
69 * For synchronous calls across HIDL, if an error occurs after the function was
70 * called but before it returns, HIDL will return a transport error. For
71 * example, if the message cannot be delivered to the server process or if the
72 * server process dies before returning a result, HIDL will return from the
73 * function with the appropriate transport error in the Return<> object which
74 * can be queried with Return<>::isOk(), Return<>::isDeadObject(),
75 * Return<>::description(), etc.
76 *
77 * However, HIDL offers no such error management in the case of asynchronous
78 * calls. By default, if the client launches an asynchronous task and the server
79 * fails to return a result through the callback, the client will be left
80 * waiting indefinitely for a result it will never receive.
81 *
82 * In the NNAPI, IDevice::prepareModel* and IPreparedModel::execute* (but not
83 * IPreparedModel::executeSynchronously*) are asynchronous calls across HIDL.
84 * Specifically, these asynchronous functions are called with a HIDL interface
85 * callback object (IPrepareModelCallback for IDevice::prepareModel* and
86 * IExecutionCallback for IPreparedModel::execute*) and are expected to quickly
87 * return, and the results are returned at a later time through these callback
88 * objects.
89 *
90 * To protect against the case when the server dies after the asynchronous task
91 * was called successfully but before the results could be returned, HIDL
92 * provides an object called a "hidl_death_recipient", which can be used to
93 * detect when an interface object (and more generally, the server process) has
94 * died. VersionedInterfaces uses hidl_death_recipients to detect when the
95 * driver process has died, and VersionedInterfaces will unblock any thread
96 * waiting on the results of a callback object that may otherwise not be
97 * signaled.
98 */
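
/*
 * As a rough sketch of the asynchronous pattern used throughout this file
 * (assuming the callback wrappers declared in Callbacks.h):
 *
 *     sp<PreparedModelCallback> callback = new PreparedModelCallback();
 *     Return<ErrorStatus> launched = device->prepareModel_1_3(..., callback);
 *     if (!launched.isOk()) { ... }  // transport error launching the task
 *     callback->wait();              // unblocked by the driver's notify call,
 *                                    // or by the death recipient if the
 *                                    // driver process dies
 *     ErrorStatus status = callback->getStatus();
 */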

namespace android {
namespace nn {

// anonymous namespace
namespace {

using namespace hal;

const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};

void sendFailureMessage(IPreparedModelCallback* cb) {
    CHECK(cb != nullptr);
    cb->notify_1_3(ErrorStatus::GENERAL_FAILURE, nullptr);
}

// This class is thread safe
template <typename Callback>
class DeathHandler : public hidl_death_recipient {
   public:
    void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
        LOG(ERROR) << "DeathHandler::serviceDied -- service unexpectedly died!";
        std::lock_guard<std::mutex> hold(mMutex);
        std::for_each(mCallbacks.begin(), mCallbacks.end(),
                      [](const auto& cb) { cb->notifyAsDeadObject(); });
    }

    [[nodiscard]] base::ScopeGuard<std::function<void()>> protectCallback(
            const sp<Callback>& callback) {
        registerCallback(callback);
        return ::android::base::make_scope_guard(
                [this, callback] { unregisterCallback(callback); });
    }

   private:
    void registerCallback(const sp<Callback>& callback) {
        std::lock_guard<std::mutex> hold(mMutex);
        mCallbacks.push_back(callback);
    }

    void unregisterCallback(const sp<Callback>& callback) {
        std::lock_guard<std::mutex> hold(mMutex);
        mCallbacks.erase(std::remove(mCallbacks.begin(), mCallbacks.end(), callback),
                         mCallbacks.end());
    }

    std::mutex mMutex;
    std::vector<sp<Callback>> mCallbacks GUARDED_BY(mMutex);
};
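
// Typical usage (see, e.g., executeAsynchronously below): hold the returned
// scope guard for the duration of the asynchronous call, so the callback stays
// registered while the call is in flight and is unregistered on every exit
// path:
//
//     const auto scoped = mDeathHandler->protectCallback(callback);
//     ... launch the asynchronous call, then wait on the callback ...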

}  // anonymous namespace

class IDeviceDeathHandler : public DeathHandler<PreparedModelCallback> {};
class IPreparedModelDeathHandler : public DeathHandler<ExecutionCallback> {};

static std::pair<int, std::shared_ptr<VersionedIPreparedModel>> makeVersionedIPreparedModel(
        sp<V1_0::IPreparedModel> preparedModel) {
    CHECK(preparedModel != nullptr)
            << "makeVersionedIPreparedModel passed invalid preparedModel object.";

    // create death handler object
    sp<IPreparedModelDeathHandler> deathHandler = new IPreparedModelDeathHandler();

    // linkToDeath registers a callback that will be invoked on service death to
    // proactively handle service crashes. If the linkToDeath call fails,
    // asynchronous calls are susceptible to hangs if the service crashes before
    // providing the response.
    const Return<bool> ret = preparedModel->linkToDeath(deathHandler, 0);
    if (ret.isDeadObject()) {
        LOG(ERROR) << "makeVersionedIPreparedModel failed to register a death recipient for the "
                      "IPreparedModel object because the IPreparedModel object is dead.";
        return {ANEURALNETWORKS_DEAD_OBJECT, nullptr};
    }
    if (!ret.isOk()) {
        LOG(ERROR) << "makeVersionedIPreparedModel failed to register a death recipient for the "
                      "IPreparedModel object because of failure: "
                   << ret.description();
        return {ANEURALNETWORKS_OP_FAILED, nullptr};
    }
    if (ret != true) {
        LOG(ERROR) << "makeVersionedIPreparedModel failed to register a death recipient for the "
                      "IPreparedModel object.";
        return {ANEURALNETWORKS_OP_FAILED, nullptr};
    }

    // return a valid VersionedIPreparedModel object
    return {ANEURALNETWORKS_NO_ERROR, std::make_shared<VersionedIPreparedModel>(
                                              std::move(preparedModel), std::move(deathHandler))};
}

VersionedIPreparedModel::VersionedIPreparedModel(sp<V1_0::IPreparedModel> preparedModel,
                                                 sp<IPreparedModelDeathHandler> deathHandler)
    : mPreparedModelV1_0(std::move(preparedModel)),
      mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(mPreparedModelV1_0).withDefault(nullptr)),
      mPreparedModelV1_3(V1_3::IPreparedModel::castFrom(mPreparedModelV1_0).withDefault(nullptr)),
      mDeathHandler(std::move(deathHandler)) {}

VersionedIPreparedModel::~VersionedIPreparedModel() {
    // It is safe to ignore any errors resulting from this unlinkToDeath call
    // because the VersionedIPreparedModel object is already being destroyed and
    // its underlying IPreparedModel object is no longer being used by the NN
    // runtime.
    mPreparedModelV1_0->unlinkToDeath(mDeathHandler).isOk();
}

std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::executeAsynchronously(
        const Request& request, MeasureTiming measure, const std::optional<Deadline>& deadline,
        const OptionalTimeoutDuration& loopTimeoutDuration) const {
    const auto failDeadObject = []() -> std::tuple<int, std::vector<OutputShape>, Timing> {
        return {ANEURALNETWORKS_DEAD_OBJECT, {}, kNoTiming};
    };
    const auto failWithStatus = [](ErrorStatus status) {
        return getExecutionResult(status, {}, kNoTiming);
    };
    const auto getResults = [failDeadObject](const ExecutionCallback& cb) {
        if (cb.isDeadObject()) {
            return failDeadObject();
        }
        return getExecutionResult(cb.getStatus(), cb.getOutputShapes(), cb.getTiming());
    };

    const sp<ExecutionCallback> callback = new ExecutionCallback();
    const auto scoped = mDeathHandler->protectCallback(callback);

    // version 1.3+ HAL
    if (mPreparedModelV1_3 != nullptr) {
        const auto otp = makeTimePoint(deadline);
        Return<ErrorStatus> ret = mPreparedModelV1_3->execute_1_3(request, measure, otp,
                                                                  loopTimeoutDuration, callback);
        if (ret.isDeadObject()) {
            LOG(ERROR) << "execute_1_3 failure: " << ret.description();
            return failDeadObject();
        }
        if (!ret.isOk()) {
            LOG(ERROR) << "execute_1_3 failure: " << ret.description();
            return failWithStatus(ErrorStatus::GENERAL_FAILURE);
        }
        if (ret != ErrorStatus::NONE) {
            LOG(ERROR) << "execute_1_3 returned " << toString(static_cast<ErrorStatus>(ret));
            return failWithStatus(ret);
        }
        callback->wait();
        return getResults(*callback);
    }

    // version 1.2 HAL
    if (mPreparedModelV1_2 != nullptr) {
        const bool compliant = compliantWithV1_2(request);
        if (!compliant) {
            LOG(ERROR) << "Could not handle execute_1_2!";
            return failWithStatus(ErrorStatus::GENERAL_FAILURE);
        }
        const V1_0::Request request12 = convertToV1_2(request);
        Return<V1_0::ErrorStatus> ret =
                mPreparedModelV1_2->execute_1_2(request12, measure, callback);
        if (ret.isDeadObject()) {
            LOG(ERROR) << "execute_1_2 failure: " << ret.description();
            return failDeadObject();
        }
        if (!ret.isOk()) {
            LOG(ERROR) << "execute_1_2 failure: " << ret.description();
            return failWithStatus(ErrorStatus::GENERAL_FAILURE);
        }
        const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
        if (status != V1_0::ErrorStatus::NONE) {
            LOG(ERROR) << "execute_1_2 returned " << toString(status);
            return failWithStatus(convertToV1_3(status));
        }
        callback->wait();
        return getResults(*callback);
    }

    // version 1.0 HAL
    if (mPreparedModelV1_0 != nullptr) {
        const bool compliant = compliantWithV1_0(request);
        if (!compliant) {
            LOG(ERROR) << "Could not handle execute!";
            return failWithStatus(ErrorStatus::GENERAL_FAILURE);
        }
        const V1_0::Request request10 = convertToV1_0(request);
        Return<V1_0::ErrorStatus> ret = mPreparedModelV1_0->execute(request10, callback);
        if (ret.isDeadObject()) {
            LOG(ERROR) << "execute failure: " << ret.description();
            return failDeadObject();
        }
        if (!ret.isOk()) {
            LOG(ERROR) << "execute failure: " << ret.description();
            return failWithStatus(ErrorStatus::GENERAL_FAILURE);
        }
        const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
        if (status != V1_0::ErrorStatus::NONE) {
            LOG(ERROR) << "execute returned " << toString(status);
            return failWithStatus(convertToV1_3(status));
        }
        callback->wait();
        return getResults(*callback);
    }

    // No prepared model available
    LOG(ERROR) << "executeAsynchronously called with no preparedModel";
    return failWithStatus(ErrorStatus::GENERAL_FAILURE);
}

std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::executeSynchronously(
        const Request& request, MeasureTiming measure, const std::optional<Deadline>& deadline,
        const OptionalTimeoutDuration& loopTimeoutDuration) const {
    const std::tuple<int, std::vector<OutputShape>, Timing> kDeadObject = {
            ANEURALNETWORKS_DEAD_OBJECT, {}, kNoTiming};
    const auto kFailure = getExecutionResult(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);

    // version 1.3+ HAL
    if (mPreparedModelV1_3 != nullptr) {
        std::tuple<int, std::vector<OutputShape>, Timing> result;
        const auto otp = makeTimePoint(deadline);
        Return<void> ret = mPreparedModelV1_3->executeSynchronously_1_3(
                request, measure, otp, loopTimeoutDuration,
                [&result](ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
                          const Timing& timing) {
                    result = getExecutionResult(error, outputShapes, timing);
                });
        if (ret.isDeadObject()) {
            LOG(ERROR) << "executeSynchronously_1_3 failure: " << ret.description();
            return kDeadObject;
        }
        if (!ret.isOk()) {
            LOG(ERROR) << "executeSynchronously_1_3 failure: " << ret.description();
            return kFailure;
        }
        return result;
    }

    // version 1.2 HAL
    if (mPreparedModelV1_2 != nullptr) {
        const bool compliant = compliantWithV1_2(request);
        if (!compliant) {
            LOG(ERROR) << "Could not handle executeSynchronously!";
            return kFailure;
        }
        const V1_0::Request request12 = convertToV1_2(request);

        std::tuple<int, std::vector<OutputShape>, Timing> result;
        Return<void> ret = mPreparedModelV1_2->executeSynchronously(
                request12, measure,
                [&result](V1_0::ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
                          const Timing& timing) {
                    result = getExecutionResult(convertToV1_3(error), outputShapes, timing);
                });
        if (ret.isDeadObject()) {
            LOG(ERROR) << "executeSynchronously failure: " << ret.description();
            return kDeadObject;
        }
        if (!ret.isOk()) {
            LOG(ERROR) << "executeSynchronously failure: " << ret.description();
            return kFailure;
        }
        return result;
    }

    // Fall back to asynchronous execution.
    return executeAsynchronously(request, measure, deadline, loopTimeoutDuration);
}

std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::execute(
        const Request& request, MeasureTiming measure, const std::optional<Deadline>& deadline,
        const OptionalTimeoutDuration& loopTimeoutDuration, bool preferSynchronous) const {
    if (preferSynchronous) {
        VLOG(EXECUTION) << "Before executeSynchronously() " << SHOW_IF_DEBUG(toString(request));
        return executeSynchronously(request, measure, deadline, loopTimeoutDuration);
    }

    VLOG(EXECUTION) << "Before executeAsynchronously() " << SHOW_IF_DEBUG(toString(request));
    return executeAsynchronously(request, measure, deadline, loopTimeoutDuration);
}

// This is the amount of time the ExecutionBurstController should spend polling
// the FMQ to see if it has data available before it should fall back to
// waiting on the futex.
static std::chrono::microseconds getPollingTimeWindow() {
    constexpr int32_t defaultPollingTimeWindow = 50;
#ifdef NN_DEBUGGABLE
    constexpr int32_t minPollingTimeWindow = 0;
    const int32_t selectedPollingTimeWindow =
            base::GetIntProperty("debug.nn.burst-controller-polling-window",
                                 defaultPollingTimeWindow, minPollingTimeWindow);
    return std::chrono::microseconds{selectedPollingTimeWindow};
#else
    return std::chrono::microseconds{defaultPollingTimeWindow};
#endif  // NN_DEBUGGABLE
}
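
// Note for debugging: on builds compiled with NN_DEBUGGABLE, the polling
// window read above can be tuned at runtime through the system property, e.g.
// (illustrative):
//
//     adb shell setprop debug.nn.burst-controller-polling-window 100
//
// The value is interpreted as microseconds; a negative value is rejected in
// favor of the 50-microsecond default.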

std::shared_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExecutionBurst(
        bool preferPowerOverLatency) const {
    if (mPreparedModelV1_2 == nullptr) {
        return nullptr;
    }
    const auto pollingTimeWindow =
            (preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());
    return ExecutionBurstController::create(mPreparedModelV1_2, pollingTimeWindow);
}

static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_3::IDevice* device) {
    CHECK(device != nullptr);
    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_3");
    const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
    std::pair<ErrorStatus, Capabilities> result = kFailure;
    const Return<void> ret = device->getCapabilities_1_3(
            [&result](ErrorStatus error, const Capabilities& capabilities) {
                result = std::make_pair(error, capabilities);
            });
    if (!ret.isOk()) {
        LOG(ERROR) << "getCapabilities_1_3 failure: " << ret.description();
        return kFailure;
    }
    return result;
}

std::tuple<int, hal::hidl_handle, sp<hal::IFencedExecutionCallback>, hal::Timing>
VersionedIPreparedModel::executeFenced(
        const hal::Request& request, const hal::hidl_vec<hal::hidl_handle>& waitFor,
        MeasureTiming measure, const std::optional<Deadline>& deadline,
        const OptionalTimeoutDuration& loopTimeoutDuration,
        const hal::OptionalTimeoutDuration& timeoutDurationAfterFence) {
    // version 1.3+ HAL
    hal::hidl_handle syncFence;
    sp<hal::IFencedExecutionCallback> dispatchCallback;
    hal::Timing timing = {UINT64_MAX, UINT64_MAX};
    if (mPreparedModelV1_3 != nullptr) {
        ErrorStatus errorStatus;
        const auto otp = makeTimePoint(deadline);
        Return<void> ret = mPreparedModelV1_3->executeFenced(
                request, waitFor, measure, otp, loopTimeoutDuration, timeoutDurationAfterFence,
                [&syncFence, &errorStatus, &dispatchCallback](
                        ErrorStatus error, const hidl_handle& handle,
                        const sp<hal::IFencedExecutionCallback>& callback) {
                    syncFence = handle;
                    errorStatus = error;
                    dispatchCallback = callback;
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "executeFenced failure: " << ret.description();
            return std::make_tuple(ANEURALNETWORKS_OP_FAILED, hal::hidl_handle(nullptr), nullptr,
                                   timing);
        }
        if (errorStatus != ErrorStatus::NONE) {
            LOG(ERROR) << "executeFenced returned "
                       << toString(static_cast<ErrorStatus>(errorStatus));
            return std::make_tuple(convertErrorStatusToResultCode(errorStatus),
                                   hal::hidl_handle(nullptr), nullptr, timing);
        }
        return std::make_tuple(ANEURALNETWORKS_NO_ERROR, syncFence, dispatchCallback, timing);
    }

    // Fall back to synchronous execution if sync fences are not supported;
    // first wait for all sync fences to be ready.
    LOG(INFO) << "No drivers able to handle sync fences, falling back to regular execution";
    for (const auto& fenceHandle : waitFor) {
        if (!fenceHandle.getNativeHandle()) {
            return std::make_tuple(ANEURALNETWORKS_BAD_DATA, hal::hidl_handle(nullptr), nullptr,
                                   timing);
        }
        int syncFd = fenceHandle.getNativeHandle()->data[0];
        if (syncFd <= 0) {
            return std::make_tuple(ANEURALNETWORKS_BAD_DATA, hal::hidl_handle(nullptr), nullptr,
                                   timing);
        }
        auto r = syncWait(syncFd, -1);
        if (r != FenceState::SIGNALED) {
            LOG(ERROR) << "syncWait failed, fd: " << syncFd;
            return std::make_tuple(ANEURALNETWORKS_OP_FAILED, hal::hidl_handle(nullptr), nullptr,
                                   timing);
        }
    }
    int errorCode;
    std::tie(errorCode, std::ignore, timing) =
            executeSynchronously(request, measure, deadline, loopTimeoutDuration);
    return std::make_tuple(errorCode, hal::hidl_handle(nullptr), nullptr, timing);
}

static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_2::IDevice* device) {
    CHECK(device != nullptr);
    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_2");
    const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
    std::pair<ErrorStatus, Capabilities> result = kFailure;
    const Return<void> ret = device->getCapabilities_1_2(
            [&result](V1_0::ErrorStatus error, const V1_2::Capabilities& capabilities) {
                result = std::make_pair(convertToV1_3(error), convertToV1_3(capabilities));
            });
    if (!ret.isOk()) {
        LOG(ERROR) << "getCapabilities_1_2 failure: " << ret.description();
        return kFailure;
    }
    return result;
}

static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_1::IDevice* device) {
    CHECK(device != nullptr);
    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_1");
    const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
    std::pair<ErrorStatus, Capabilities> result = kFailure;
    const Return<void> ret = device->getCapabilities_1_1(
            [&result](V1_0::ErrorStatus error, const V1_1::Capabilities& capabilities) {
                // Time taken to convert capabilities is trivial
                result = std::make_pair(convertToV1_3(error), convertToV1_3(capabilities));
            });
    if (!ret.isOk()) {
        LOG(ERROR) << "getCapabilities_1_1 failure: " << ret.description();
        return kFailure;
    }
    return result;
}

static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_0::IDevice* device) {
    CHECK(device != nullptr);
    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities");
    const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
    std::pair<ErrorStatus, Capabilities> result = kFailure;
    const Return<void> ret = device->getCapabilities(
            [&result](V1_0::ErrorStatus error, const V1_0::Capabilities& capabilities) {
                // Time taken to convert capabilities is trivial
                result = std::make_pair(convertToV1_3(error), convertToV1_3(capabilities));
            });
    if (!ret.isOk()) {
        LOG(ERROR) << "getCapabilities failure: " << ret.description();
        return kFailure;
    }
    return result;
}

static std::pair<ErrorStatus, hidl_vec<Extension>> getSupportedExtensionsFunction(
        V1_2::IDevice* device) {
    CHECK(device != nullptr);
    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getSupportedExtensions");
    const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
    std::pair<ErrorStatus, hidl_vec<Extension>> result = kFailure;
    const Return<void> ret = device->getSupportedExtensions(
            [&result](V1_0::ErrorStatus error, const hidl_vec<Extension>& extensions) {
                result = std::make_pair(convertToV1_3(error), extensions);
            });
    if (!ret.isOk()) {
        LOG(ERROR) << "getSupportedExtensions failure: " << ret.description();
        return kFailure;
    }
    return result;
}

static std::pair<ErrorStatus, hidl_vec<Extension>> getSupportedExtensionsFunction(
        V1_0::IDevice* device) {
    CHECK(device != nullptr);
    return {ErrorStatus::NONE, {/* No extensions. */}};
}

static int32_t getTypeFunction(V1_2::IDevice* device) {
    CHECK(device != nullptr);
    constexpr int32_t kFailure = -1;
    int32_t result = kFailure;
    const Return<void> ret =
            device->getType([&result](V1_0::ErrorStatus error, DeviceType deviceType) {
                if (error == V1_0::ErrorStatus::NONE) {
                    result = static_cast<int32_t>(deviceType);
                }
            });
    if (!ret.isOk()) {
        LOG(ERROR) << "getType failure: " << ret.description();
        return kFailure;
    }
    return result;
}

static int32_t getTypeFunction(V1_0::IDevice* device) {
    CHECK(device != nullptr);
    return ANEURALNETWORKS_DEVICE_UNKNOWN;
}

static std::pair<ErrorStatus, hidl_string> getVersionStringFunction(V1_2::IDevice* device) {
    CHECK(device != nullptr);
    const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""};
    std::pair<ErrorStatus, hidl_string> result = kFailure;
    const Return<void> ret = device->getVersionString(
            [&result](V1_0::ErrorStatus error, const hidl_string& version) {
                result = std::make_pair(convertToV1_3(error), version);
            });
    if (!ret.isOk()) {
        LOG(ERROR) << "getVersionString failure: " << ret.description();
        return kFailure;
    }
    return result;
}

static std::pair<ErrorStatus, hidl_string> getVersionStringFunction(V1_0::IDevice* device) {
    CHECK(device != nullptr);
    return {ErrorStatus::NONE, "UNKNOWN"};
}

static std::tuple<ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeededFunction(
        V1_2::IDevice* device) {
    CHECK(device != nullptr);
    constexpr std::tuple<ErrorStatus, uint32_t, uint32_t> kFailure = {ErrorStatus::GENERAL_FAILURE,
                                                                      0, 0};
    std::tuple<ErrorStatus, uint32_t, uint32_t> result = kFailure;
    const Return<void> ret = device->getNumberOfCacheFilesNeeded(
            [&result](V1_0::ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
                result = {convertToV1_3(error), numModelCache, numDataCache};
            });
    if (!ret.isOk()) {
        LOG(ERROR) << "getNumberOfCacheFilesNeeded failure: " << ret.description();
        return kFailure;
    }
    return result;
}

static std::tuple<ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeededFunction(
        V1_0::IDevice* device) {
    CHECK(device != nullptr);
    return {ErrorStatus::NONE, 0, 0};
}

struct InitialData {
    hal::Capabilities capabilities;
    hal::hidl_vec<hal::Extension> supportedExtensions;
    int32_t type;
    hal::hidl_string versionString;
    std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded;
};

template <typename Device>
static std::optional<InitialData> initializeFunction(Device* device) {
    CHECK(device != nullptr);

    auto [capabilitiesStatus, capabilities] = getCapabilitiesFunction(device);
    if (capabilitiesStatus != ErrorStatus::NONE) {
        LOG(ERROR) << "IDevice::getCapabilities* returned the error "
                   << toString(capabilitiesStatus);
        return std::nullopt;
    }
    VLOG(MANAGER) << "Capab " << toString(capabilities);

    auto [versionStatus, versionString] = getVersionStringFunction(device);
    if (versionStatus != ErrorStatus::NONE) {
        LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(versionStatus);
        return std::nullopt;
    }

    const int32_t type = getTypeFunction(device);
    if (type == -1) {
        LOG(ERROR) << "IDevice::getType returned an error";
        return std::nullopt;
    }

    auto [extensionsStatus, supportedExtensions] = getSupportedExtensionsFunction(device);
    if (extensionsStatus != ErrorStatus::NONE) {
        LOG(ERROR) << "IDevice::getSupportedExtensions returned the error "
                   << toString(extensionsStatus);
        return std::nullopt;
    }

    const auto [cacheFilesStatus, numModelCacheFiles, numDataCacheFiles] =
            getNumberOfCacheFilesNeededFunction(device);
    if (cacheFilesStatus != ErrorStatus::NONE) {
        LOG(ERROR) << "IDevice::getNumberOfCacheFilesNeeded returned the error "
                   << toString(cacheFilesStatus);
        return std::nullopt;
    }

    // The following limit is enforced by VTS
    constexpr uint32_t maxNumCacheFiles =
            static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES);
    if (numModelCacheFiles > maxNumCacheFiles || numDataCacheFiles > maxNumCacheFiles) {
        LOG(ERROR)
                << "IDevice::getNumberOfCacheFilesNeeded returned invalid number of cache files: "
                   "numModelCacheFiles = "
                << numModelCacheFiles << ", numDataCacheFiles = " << numDataCacheFiles
                << ", maxNumCacheFiles = " << maxNumCacheFiles;
        return std::nullopt;
    }

    return InitialData{
            /*.capabilities=*/std::move(capabilities),
            /*.supportedExtensions=*/std::move(supportedExtensions),
            /*.type=*/type,
            /*.versionString=*/std::move(versionString),
            /*.numberOfCacheFilesNeeded=*/{numModelCacheFiles, numDataCacheFiles},
    };
}

template <typename Core>
std::optional<InitialData> initialize(const Core& core) {
    // version 1.3+ HAL
    if (const auto device = core.template getDevice<V1_3::IDevice>()) {
        return initializeFunction(device.get());
    }

    // version 1.2 HAL
    if (const auto device = core.template getDevice<V1_2::IDevice>()) {
        return initializeFunction(device.get());
    }

    // version 1.1 HAL
    if (const auto device = core.template getDevice<V1_1::IDevice>()) {
        return initializeFunction(device.get());
    }

    // version 1.0 HAL
    if (const auto device = core.template getDevice<V1_0::IDevice>()) {
        return initializeFunction(device.get());
    }

    // No device available
    LOG(ERROR) << "Device not available!";
    return std::nullopt;
}

std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName,
                                                           const DeviceFactory& makeDevice) {
    CHECK(makeDevice != nullptr)
            << "VersionedIDevice::create passed invalid device factory object.";

    // get handle to IDevice object
    sp<V1_0::IDevice> device = makeDevice(/*blocking=*/true);
    if (device == nullptr) {
        VLOG(DRIVER) << "VersionedIDevice::create got a null IDevice for " << serviceName;
        return nullptr;
    }

    auto core = Core::create(std::move(device));
    if (!core.has_value()) {
        LOG(ERROR) << "VersionedIDevice::create failed to create Core.";
        return nullptr;
    }

    auto initialData = initialize(*core);
    if (!initialData.has_value()) {
        LOG(ERROR) << "VersionedIDevice::create failed to initialize.";
        return nullptr;
    }

    auto [capabilities, supportedExtensions, type, versionString, numberOfCacheFilesNeeded] =
            std::move(*initialData);
    return std::make_shared<VersionedIDevice>(
            std::move(capabilities), std::move(supportedExtensions), type, std::move(versionString),
            numberOfCacheFilesNeeded, std::move(serviceName), makeDevice, std::move(core.value()));
}

VersionedIDevice::VersionedIDevice(hal::Capabilities capabilities,
                                   std::vector<hal::Extension> supportedExtensions, int32_t type,
                                   std::string versionString,
                                   std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded,
                                   std::string serviceName, const DeviceFactory& makeDevice,
                                   Core core)
    : kCapabilities(std::move(capabilities)),
      kSupportedExtensions(std::move(supportedExtensions)),
      kType(type),
      kVersionString(std::move(versionString)),
      kNumberOfCacheFilesNeeded(numberOfCacheFilesNeeded),
      kServiceName(std::move(serviceName)),
      kMakeDevice(makeDevice),
      mCore(std::move(core)) {}

std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) {
    CHECK(device != nullptr) << "VersionedIDevice::Core::create passed invalid device object.";

    // create death handler object
    sp<IDeviceDeathHandler> deathHandler = new IDeviceDeathHandler();

    // linkToDeath registers a callback that will be invoked on service death to
    // proactively handle service crashes. If the linkToDeath call fails,
    // asynchronous calls are susceptible to hangs if the service crashes before
    // providing the response.
    const Return<bool> ret = device->linkToDeath(deathHandler, 0);
    if (!ret.isOk()) {
        LOG(ERROR) << "VersionedIDevice::Core::create failed to register a death recipient for the "
                      "IDevice object because of failure: "
                   << ret.description();
        return {};
    }
    if (ret != true) {
        LOG(ERROR) << "VersionedIDevice::Core::create failed to register a death recipient for the "
                      "IDevice object.";
        return {};
    }

    // return a valid Core object
    return Core(std::move(device), std::move(deathHandler));
}

// HIDL guarantees all V1_1 interfaces inherit from their corresponding V1_0 interfaces.
VersionedIDevice::Core::Core(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler)
    : mDeviceV1_0(std::move(device)),
      mDeviceV1_1(V1_1::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeviceV1_2(V1_2::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeviceV1_3(V1_3::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeathHandler(std::move(deathHandler)) {}

VersionedIDevice::Core::~Core() {
    if (mDeathHandler != nullptr) {
        CHECK(mDeviceV1_0 != nullptr);
        // It is safe to ignore any errors resulting from this unlinkToDeath call
        // because the VersionedIDevice::Core object is already being destroyed and
        // its underlying IDevice object is no longer being used by the NN runtime.
        mDeviceV1_0->unlinkToDeath(mDeathHandler).isOk();
    }
}

VersionedIDevice::Core::Core(Core&& other) noexcept
    : mDeviceV1_0(std::move(other.mDeviceV1_0)),
      mDeviceV1_1(std::move(other.mDeviceV1_1)),
      mDeviceV1_2(std::move(other.mDeviceV1_2)),
      mDeviceV1_3(std::move(other.mDeviceV1_3)),
      mDeathHandler(std::move(other.mDeathHandler)) {
    other.mDeathHandler = nullptr;
}

VersionedIDevice::Core& VersionedIDevice::Core::operator=(Core&& other) noexcept {
    if (this != &other) {
        mDeviceV1_0 = std::move(other.mDeviceV1_0);
        mDeviceV1_1 = std::move(other.mDeviceV1_1);
        mDeviceV1_2 = std::move(other.mDeviceV1_2);
        mDeviceV1_3 = std::move(other.mDeviceV1_3);
        mDeathHandler = std::move(other.mDeathHandler);
        other.mDeathHandler = nullptr;
    }
    return *this;
}

template <typename T_IDevice>
std::pair<sp<T_IDevice>, sp<IDeviceDeathHandler>> VersionedIDevice::Core::getDeviceAndDeathHandler()
        const {
    return {getDevice<T_IDevice>(), mDeathHandler};
}

template <typename T_Return, typename T_IDevice, typename T_Callback>
Return<T_Return> callProtected(const char* context,
                               const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
                               const sp<T_IDevice>& device, const sp<T_Callback>& callback,
                               const sp<IDeviceDeathHandler>& deathHandler) {
    const auto scoped = deathHandler->protectCallback(callback);
    Return<T_Return> ret = fn(device);
    // Suppose there was a transport error. We have the following cases:
    // 1. Either not due to a dead device, or due to a device that was
    //    already dead at the time of the call to protectCallback(). In
    //    this case, the callback was never signalled.
    // 2. Due to a device that died after the call to protectCallback() but
    //    before fn() completed. In this case, the callback was (or will
    //    be) signalled by the deathHandler.
    // Furthermore, what if there was no transport error, but the ErrorStatus is
    // other than NONE? We'll conservatively signal the callback anyway, just in
    // case the driver was sloppy and failed to do so.
    if (!ret.isOk() || ret != T_Return::NONE) {
        // What if the deathHandler has signalled or will signal the callback?
        // This is fine -- we're permitted to signal multiple times; and we're
        // sending the same signal that the deathHandler does.
        //
        // What if the driver signalled the callback? Then this signal is
        // ignored.

        if (ret.isOk()) {
            LOG(ERROR) << context << " returned " << toString(static_cast<T_Return>(ret));
        } else {
            LOG(ERROR) << context << " failure: " << ret.description();
        }
        sendFailureMessage(callback.get());
    }
    callback->wait();
    return ret;
}

template <typename T_Return, typename T_IDevice>
Return<T_Return> callProtected(const char*,
                               const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
                               const sp<T_IDevice>& device, const std::nullptr_t&,
                               const sp<IDeviceDeathHandler>&) {
    return fn(device);
}

template <typename T_Return, typename T_IDevice, typename T_Callback>
Return<T_Return> VersionedIDevice::recoverable(
        const char* context, const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
        const T_Callback& callback) const EXCLUDES(mMutex) {
    CHECK_EQ(callback == nullptr, (std::is_same_v<T_Callback, std::nullptr_t>));

    sp<T_IDevice> device;
    sp<IDeviceDeathHandler> deathHandler;
    std::tie(device, deathHandler) = getDeviceAndDeathHandler<T_IDevice>();

    Return<T_Return> ret = callProtected(context, fn, device, callback, deathHandler);

    if (ret.isDeadObject()) {
        {
            std::unique_lock lock(mMutex);
            // It's possible that another thread has already done the recovery.
            // It's harmless but wasteful for us to do so in this case.
            auto pingReturn = mCore.getDevice<T_IDevice>()->ping();
            if (pingReturn.isDeadObject()) {
                VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering "
                             << kServiceName;
                sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/false);
                if (recoveredDevice == nullptr) {
                    VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDevice for "
                                 << kServiceName;
                    return ret;
                }

                auto core = Core::create(std::move(recoveredDevice));
                if (!core.has_value()) {
                    LOG(ERROR) << "VersionedIDevice::recoverable failed to create Core.";
                    return ret;
                }

                mCore = std::move(core.value());
            } else {
                VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context
                             << ") -- Someone else recovered " << kServiceName;
                // Might still have a transport error, which we need to check
                // before pingReturn goes out of scope.
                (void)pingReturn.isOk();
            }
            std::tie(device, deathHandler) = mCore.getDeviceAndDeathHandler<T_IDevice>();
        }
        ret = callProtected(context, fn, device, callback, deathHandler);
        // It's possible that the device died again, but we're only going to
        // attempt recovery once per call to recoverable().
    }
    return ret;
}

int VersionedIDevice::wait() const {
    std::unique_lock lock(mMutex);
    // It's possible that another thread has already done the recovery.
    // It's harmless but wasteful for us to do so in this case.
    auto pingReturn = mCore.getDevice<V1_0::IDevice>()->ping();
    if (pingReturn.isDeadObject()) {
        VLOG(DRIVER) << "VersionedIDevice::wait -- Recovering " << kServiceName;
        sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/true);
        if (recoveredDevice == nullptr) {
            LOG(ERROR) << "VersionedIDevice::wait got a null IDevice for " << kServiceName;
            return ANEURALNETWORKS_OP_FAILED;
        }

        auto core = Core::create(std::move(recoveredDevice));
        if (!core.has_value()) {
            LOG(ERROR) << "VersionedIDevice::wait failed to create Core.";
            return ANEURALNETWORKS_OP_FAILED;
        }

        mCore = std::move(core.value());
    } else if (!pingReturn.isOk()) {
        LOG(ERROR) << "VersionedIDevice::wait failed -- IDevice::ping returned "
                   << pingReturn.description();
        return ANEURALNETWORKS_OP_FAILED;
    }

    return ANEURALNETWORKS_NO_ERROR;
}

const Capabilities& VersionedIDevice::getCapabilities() const {
    return kCapabilities;
}

const std::vector<Extension>& VersionedIDevice::getSupportedExtensions() const {
    return kSupportedExtensions;
}

std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
        const MetaModel& metaModel) const {
    const std::pair<ErrorStatus, hidl_vec<bool>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
    std::pair<ErrorStatus, hidl_vec<bool>> result;

    const Model& model = metaModel.getModel();

    auto noneSupported = [&model] {
        hidl_vec<bool> supported(model.main.operations.size());
        std::fill(supported.begin(), supported.end(), false);
        return std::make_pair(ErrorStatus::NONE, std::move(supported));
    };

    auto remappedResult = [&model](const std::pair<ErrorStatus, hidl_vec<bool>>& result,
                                   const std::function<uint32_t(uint32_t)>&
                                           slicedModelOperationIndexToModelOperationIndex) {
        const ErrorStatus status = result.first;
        const hidl_vec<bool>& supported = result.second;
        hidl_vec<bool> remappedSupported(model.main.operations.size());
        std::fill(remappedSupported.begin(), remappedSupported.end(), false);
        for (size_t i = 0; i < supported.size(); ++i) {
            if (supported[i]) {
                remappedSupported[slicedModelOperationIndexToModelOperationIndex(i)] = true;
            }
        }
        return std::make_pair(status, std::move(remappedSupported));
    };

    // version 1.3+ HAL
    if (getDevice<V1_3::IDevice>() != nullptr) {
        NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_3");
        Return<void> ret = recoverable<void, V1_3::IDevice>(
                __FUNCTION__, [&model, &result](const sp<V1_3::IDevice>& device) {
                    return device->getSupportedOperations_1_3(
                            model, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
                                result = std::make_pair(error, supported);
                            });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getSupportedOperations_1_3 failure: " << ret.description();
            return kFailure;
        }
        return result;
    }

    // version 1.2 HAL
    if (getDevice<V1_2::IDevice>() != nullptr) {
        const bool compliant = compliantWithV1_2(model);
        V1_2::Model model12;
        std::function<uint32_t(uint32_t)> slicedModelOperationIndexToModelOperationIndex;
        if (compliant) {
            model12 = convertToV1_2(model);
        } else {
            const auto slice12 = metaModel.getSliceV1_2();
            if (!slice12.has_value()) {
                return noneSupported();
            }
            std::tie(model12, slicedModelOperationIndexToModelOperationIndex) = *slice12;
        }
        NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_2");
        Return<void> ret = recoverable<void, V1_2::IDevice>(
                __FUNCTION__, [&model12, &result](const sp<V1_2::IDevice>& device) {
                    return device->getSupportedOperations_1_2(
                            model12,
                            [&result](V1_0::ErrorStatus error, const hidl_vec<bool>& supported) {
                                result = std::make_pair(convertToV1_3(error), supported);
                            });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getSupportedOperations_1_2 failure: " << ret.description();
            return kFailure;
        }
        if (!compliant) {
            return remappedResult(result, slicedModelOperationIndexToModelOperationIndex);
        }
        return result;
    }

    // version 1.1 HAL
    if (getDevice<V1_1::IDevice>() != nullptr) {
        const bool compliant = compliantWithV1_1(model);
        V1_1::Model model11;
        std::function<uint32_t(uint32_t)> slicedModelOperationIndexToModelOperationIndex;
        if (compliant) {
            model11 = convertToV1_1(model);
        } else {
            const auto slice11 = metaModel.getSliceV1_1();
            if (!slice11.has_value()) {
                return noneSupported();
            }
            std::tie(model11, slicedModelOperationIndexToModelOperationIndex) = *slice11;
        }
        NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_1");
        Return<void> ret = recoverable<void, V1_1::IDevice>(
                __FUNCTION__, [&model11, &result](const sp<V1_1::IDevice>& device) {
                    return device->getSupportedOperations_1_1(
                            model11,
                            [&result](V1_0::ErrorStatus error, const hidl_vec<bool>& supported) {
                                result = std::make_pair(convertToV1_3(error), supported);
                            });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getSupportedOperations_1_1 failure: " << ret.description();
            return kFailure;
        }
        if (!compliant) {
            return remappedResult(result, slicedModelOperationIndexToModelOperationIndex);
        }
        return result;
    }

    // version 1.0 HAL
    if (getDevice<V1_0::IDevice>() != nullptr) {
        const bool compliant = compliantWithV1_0(model);
        V1_0::Model model10;
        std::function<uint32_t(uint32_t)> slicedModelOperationIndexToModelOperationIndex;
        if (compliant) {
            model10 = convertToV1_0(model);
        } else {
            const auto slice10 = metaModel.getSliceV1_0();
            if (!slice10.has_value()) {
                return noneSupported();
            }
            std::tie(model10, slicedModelOperationIndexToModelOperationIndex) = *slice10;
        }
        NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations");
        Return<void> ret = recoverable<void, V1_0::IDevice>(
                __FUNCTION__, [&model10, &result](const sp<V1_0::IDevice>& device) {
                    return device->getSupportedOperations(
                            model10,
                            [&result](V1_0::ErrorStatus error, const hidl_vec<bool>& supported) {
                                result = std::make_pair(convertToV1_3(error), supported);
                            });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "getSupportedOperations failure: " << ret.description();
            return kFailure;
        }
        if (!compliant) {
            return remappedResult(result, slicedModelOperationIndexToModelOperationIndex);
        }
        return result;
    }

    // No device available
    LOG(ERROR) << "Device not available!";
    return kFailure;
}

// Opens the cache file specified by filename and stores the opened fd in the
// handle. Returns false on failure. The handle is expected to come in as empty,
// and is only set to a valid fd when the function returns true. The file
// descriptor is always opened with both read and write permission.
static bool createCacheHandle(const std::string& cache, bool createIfNotExist,
                              hidl_handle* handle) {
    CHECK(handle->getNativeHandle() == nullptr);
    int fd = open(cache.c_str(), createIfNotExist ? (O_RDWR | O_CREAT) : O_RDWR, S_IRUSR | S_IWUSR);
    NN_RET_CHECK_GE(fd, 0);
    native_handle_t* cacheNativeHandle = native_handle_create(1, 0);
    if (cacheNativeHandle == nullptr) {
        close(fd);
        return false;
    }
    cacheNativeHandle->data[0] = fd;
    handle->setTo(cacheNativeHandle, /*shouldOwn=*/true);
    return true;
}

// Opens a list of cache files and returns a vector of handles. Returns an empty
// vector on failure. The file descriptors are always opened with both read and
// write permission.
static hidl_vec<hidl_handle> createCacheHandleVec(uint32_t numCacheFiles,
                                                  const std::string& baseFileName,
                                                  bool createIfNotExist) {
    CHECK(numCacheFiles <= static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES));
    hidl_vec<hidl_handle> handles(numCacheFiles);
    for (uint32_t i = 0; i < numCacheFiles; i++) {
        std::string filename = baseFileName + std::to_string(i);
        VLOG(COMPILATION) << "Cache " << i << ": " << filename;
        if (!createCacheHandle(filename, createIfNotExist, &handles[i])) {
            return hidl_vec<hidl_handle>();
        }
    }
    return handles;
}

// Maps the token to cache file names and sets the handle vectors to the opened
// fds. Returns false on failure and leaves the vectors empty. Each vector is
// expected to come in as empty.
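//
// For example (illustrative): a token whose first byte is 0x4B contributes the
// two filename characters 'A' + (0x4B & 0x0F) = 'L' and 'A' + (0x4B >> 4) =
// 'E'. The character appended after the encoded token selects the cache type:
// '1' for the model cache and '2' for the data cache (see below).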
static bool getCacheHandles(const std::string& cacheDir, const CacheToken& token,
                            const std::pair<uint32_t, uint32_t>& numCacheFiles,
                            bool createIfNotExist, hidl_vec<hidl_handle>* modelCache,
                            hidl_vec<hidl_handle>* dataCache) {
    // The filename includes ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN * 2 characters for token,
    // and 1 character for model/data cache identifier.
    std::string filename(ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN * 2 + 1, '0');
    for (uint32_t i = 0; i < ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN; i++) {
        filename[i * 2] = 'A' + (token[i] & 0x0F);
        filename[i * 2 + 1] = 'A' + (token[i] >> 4);
    }
    CHECK(cacheDir.empty() || cacheDir.back() == '/');
    std::string cacheFileName = cacheDir + filename;

    cacheFileName[ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN * 2] = '1';
    *modelCache = createCacheHandleVec(numCacheFiles.first, cacheFileName, createIfNotExist);
    if (modelCache->size() != numCacheFiles.first) {
        return false;
    }
    cacheFileName[ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN * 2] = '2';
    *dataCache = createCacheHandleVec(numCacheFiles.second, cacheFileName, createIfNotExist);
    if (dataCache->size() != numCacheFiles.second) {
        modelCache->resize(0);
        return false;
    }
    return true;
}

static std::pair<int, std::shared_ptr<VersionedIPreparedModel>> prepareModelFailure(
        ErrorStatus status = ErrorStatus::GENERAL_FAILURE) {
    return {convertErrorStatusToResultCode(status), nullptr};
}

static std::pair<int, std::shared_ptr<VersionedIPreparedModel>> prepareModelResult(
        const PreparedModelCallback& callback, const char* prepareName,
        const std::string& serviceName) {
    callback.wait();
    if (callback.isDeadObject()) {
        LOG(ERROR) << prepareName << " on " << serviceName
                   << " failed because the PreparedModel object is dead";
        return {ANEURALNETWORKS_DEAD_OBJECT, nullptr};
    }
    const ErrorStatus status = callback.getStatus();
    const sp<V1_0::IPreparedModel> preparedModel = callback.getPreparedModel();

    if (status != ErrorStatus::NONE) {
        LOG(ERROR) << prepareName << " on " << serviceName << " failed: "
                   << "prepareReturnStatus=" << toString(status);
        return prepareModelFailure(status);
    }
    if (preparedModel == nullptr) {
        LOG(ERROR) << prepareName << " on " << serviceName << " failed: preparedModel is nullptr";
        return prepareModelFailure();
    }

    return makeVersionedIPreparedModel(preparedModel);
}
1202
prepareModelInternal(const Model & model,ExecutionPreference preference,Priority priority,const std::optional<Deadline> & deadline,const std::string & cacheDir,const std::optional<CacheToken> & maybeToken) const1203 std::pair<int, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevice::prepareModelInternal(
1204 const Model& model, ExecutionPreference preference, Priority priority,
1205 const std::optional<Deadline>& deadline, const std::string& cacheDir,
1206 const std::optional<CacheToken>& maybeToken) const {
1207 // Note that some work within VersionedIDevice will be subtracted from the IPC layer
1208 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "prepareModel");
1209 const std::pair<int, std::shared_ptr<VersionedIPreparedModel>> kDeadObject = {
1210 ANEURALNETWORKS_DEAD_OBJECT, nullptr};
1211
1212 // Get cache files if they exist, otherwise create them.
1213 hidl_vec<hidl_handle> modelCache, dataCache;
1214 if (!maybeToken.has_value() ||
1215 !getCacheHandles(cacheDir, *maybeToken, kNumberOfCacheFilesNeeded,
1216 /*createIfNotExist=*/true, &modelCache, &dataCache)) {
1217 modelCache.resize(0);
1218 dataCache.resize(0);
1219 }
1220
1221 // Get the token if it exists, otherwise get a null token.
1222 static const CacheToken kNullToken{};
1223 const CacheToken token = maybeToken.value_or(kNullToken);
1224
1225 const sp<PreparedModelCallback> callback = new PreparedModelCallback();
1226
1227 // If 1.3 device, try preparing model
1228 if (getDevice<V1_3::IDevice>() != nullptr) {
1229 const auto otp = makeTimePoint(deadline);
1230 const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_3::IDevice>(
1231 __FUNCTION__,
1232 [&model, preference, priority, &otp, &modelCache, &dataCache, &token,
1233 &callback](const sp<V1_3::IDevice>& device) {
1234 return device->prepareModel_1_3(model, preference, priority, otp, modelCache,
1235 dataCache, token, callback);
1236 },
1237 callback);
1238 if (ret.isDeadObject()) {
1239 LOG(ERROR) << "prepareModel_1_3 failure: " << ret.description();
1240 return kDeadObject;
1241 }
1242 if (!ret.isOk()) {
1243 LOG(ERROR) << "prepareModel_1_3 failure: " << ret.description();
1244 return prepareModelFailure();
1245 }
1246 if (ret != ErrorStatus::NONE) {
1247 LOG(ERROR) << "prepareModel_1_3 returned " << toString(static_cast<ErrorStatus>(ret));
1248 return prepareModelFailure(ret);
1249 }
1250 return prepareModelResult(*callback, "prepareModel_1_3", kServiceName);
1251 }

    // If 1.2 device, try preparing model (requires conversion)
    if (getDevice<V1_2::IDevice>() != nullptr) {
        bool compliant = false;
        V1_2::Model model12;
        {
            // Attribute time spent in model inspection and conversion to
            // Runtime, as the time may be substantial (0.03ms for mobilenet,
            // but could be larger for other models).
            NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
                                  "VersionedIDevice::prepareModel_1_2");
            compliant = compliantWithV1_2(model);
            if (compliant) {
                model12 = convertToV1_2(model);  // copy is elided
            }
        }
        if (compliant) {
            const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_2::IDevice>(
                    __FUNCTION__,
                    [&model12, &preference, &modelCache, &dataCache, &token,
                     &callback](const sp<V1_2::IDevice>& device) {
                        return device->prepareModel_1_2(model12, preference, modelCache, dataCache,
                                                        token, callback);
                    },
                    callback);
            if (ret.isDeadObject()) {
                LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description();
                return kDeadObject;
            }
            if (!ret.isOk()) {
                LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description();
                return prepareModelFailure();
            }
            const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
            if (status != V1_0::ErrorStatus::NONE) {
                LOG(ERROR) << "prepareModel_1_2 returned " << toString(status);
                return prepareModelFailure(convertToV1_3(status));
            }
            return prepareModelResult(*callback, "prepareModel_1_2", kServiceName);
        }

        LOG(ERROR) << "Could not handle prepareModel_1_2!";
        return prepareModelFailure();
    }

    // If 1.1 device, try preparing model (requires conversion)
    if (getDevice<V1_1::IDevice>() != nullptr) {
        bool compliant = false;
        V1_1::Model model11;
        {
            // Attribute time spent in model inspection and conversion to
            // Runtime, as the time may be substantial (0.03ms for mobilenet,
            // but could be larger for other models).
            NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
                                  "VersionedIDevice::prepareModel_1_1");
            compliant = compliantWithV1_1(model);
            if (compliant) {
                model11 = convertToV1_1(model);  // copy is elided
            }
        }
        if (compliant) {
            const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_1::IDevice>(
                    __FUNCTION__,
                    [&model11, &preference, &callback](const sp<V1_1::IDevice>& device) {
                        return device->prepareModel_1_1(model11, preference, callback);
                    },
                    callback);
            if (ret.isDeadObject()) {
                LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description();
                return kDeadObject;
            }
            if (!ret.isOk()) {
                LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description();
                return prepareModelFailure();
            }
            const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
            if (status != V1_0::ErrorStatus::NONE) {
                LOG(ERROR) << "prepareModel_1_1 returned " << toString(status);
                return prepareModelFailure(convertToV1_3(status));
            }
            return prepareModelResult(*callback, "prepareModel_1_1", kServiceName);
        }

        LOG(ERROR) << "Could not handle prepareModel_1_1!";
        return prepareModelFailure();
    }

    // If 1.0 device, try preparing model (requires conversion)
    if (getDevice<V1_0::IDevice>() != nullptr) {
        bool compliant = false;
        V1_0::Model model10;
        {
            // Attribute time spent in model inspection and conversion to
            // Runtime, as the time may be substantial (0.03ms for mobilenet,
            // but could be larger for other models).
            NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
                                  "VersionedIDevice::prepareModel");
            compliant = compliantWithV1_0(model);
            if (compliant) {
                model10 = convertToV1_0(model);  // copy is elided
            }
        }
        if (compliant) {
            const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_0::IDevice>(
                    __FUNCTION__,
                    [&model10, &callback](const sp<V1_0::IDevice>& device) {
                        return device->prepareModel(model10, callback);
                    },
                    callback);
            if (ret.isDeadObject()) {
                LOG(ERROR) << "prepareModel failure: " << ret.description();
                return kDeadObject;
            }
            if (!ret.isOk()) {
                LOG(ERROR) << "prepareModel failure: " << ret.description();
                return prepareModelFailure();
            }
            const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
            if (status != V1_0::ErrorStatus::NONE) {
                LOG(ERROR) << "prepareModel returned " << toString(status);
                return prepareModelFailure(convertToV1_3(status));
            }
            return prepareModelResult(*callback, "prepareModel", kServiceName);
        }

        LOG(ERROR) << "Could not handle prepareModel!";
        return prepareModelFailure();
    }

    // Return error because there is no valid device
    LOG(ERROR) << "prepareModel called with no device";
    return prepareModelFailure();
}
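
/*
 * Note on the result shape above: every path out of prepareModelInternal
 * yields one of three outcomes --
 *   - {ANEURALNETWORKS_DEAD_OBJECT, nullptr} when the driver process died
 *     mid-call,
 *   - {an ANEURALNETWORKS_* error, nullptr} via prepareModelFailure() for any
 *     other transport or driver-reported failure, and
 *   - {ANEURALNETWORKS_NO_ERROR, preparedModel} via prepareModelResult() once
 *     the asynchronous callback delivers a prepared model.
 * (The exact error code chosen by prepareModelFailure() is presumably defined
 * earlier in this file; this is a summary, not new behavior.)
 */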

std::pair<int, std::shared_ptr<VersionedIPreparedModel>>
VersionedIDevice::prepareModelFromCacheInternal(const std::optional<Deadline>& deadline,
                                                const std::string& cacheDir,
                                                const CacheToken& token) const {
    // Note that some work within VersionedIDevice will be subtracted from the IPC layer
    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "prepareModelFromCache");
    VLOG(COMPILATION) << "prepareModelFromCache";
    const std::pair<int, std::shared_ptr<VersionedIPreparedModel>> kDeadObject = {
            ANEURALNETWORKS_DEAD_OBJECT, nullptr};

    // Get cache files if they exist, otherwise return from the function early.
    hidl_vec<hidl_handle> modelCache, dataCache;
    if (!getCacheHandles(cacheDir, token, kNumberOfCacheFilesNeeded,
                         /*createIfNotExist=*/false, &modelCache, &dataCache)) {
        return prepareModelFailure();
    }
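    // Compilation caching was introduced in NN HAL 1.2 and deadline support in
    // 1.3, so only the 1.3 path below threads the deadline through, and
    // 1.0/1.1 drivers are rejected outright further down.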

    // version 1.3+ HAL
    if (getDevice<V1_3::IDevice>() != nullptr) {
        const auto otp = makeTimePoint(deadline);
        const sp<PreparedModelCallback> callback = new PreparedModelCallback();
        const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_3::IDevice>(
                __FUNCTION__,
                [&otp, &modelCache, &dataCache, &token,
                 &callback](const sp<V1_3::IDevice>& device) {
                    return device->prepareModelFromCache_1_3(otp, modelCache, dataCache, token,
                                                             callback);
                },
                callback);
        if (ret.isDeadObject()) {
            LOG(ERROR) << "prepareModelFromCache_1_3 failure: " << ret.description();
            return kDeadObject;
        }
        if (!ret.isOk()) {
            LOG(ERROR) << "prepareModelFromCache_1_3 failure: " << ret.description();
            return prepareModelFailure();
        }
        if (ret != ErrorStatus::NONE) {
            LOG(ERROR) << "prepareModelFromCache_1_3 returned "
                       << toString(static_cast<ErrorStatus>(ret));
            return prepareModelFailure(ret);
        }
        return prepareModelResult(*callback, "prepareModelFromCache_1_3", kServiceName);
    }

    // version 1.2 HAL
    if (getDevice<V1_2::IDevice>() != nullptr) {
        const sp<PreparedModelCallback> callback = new PreparedModelCallback();
        const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_2::IDevice>(
                __FUNCTION__,
                [&modelCache, &dataCache, &token, &callback](const sp<V1_2::IDevice>& device) {
                    return device->prepareModelFromCache(modelCache, dataCache, token, callback);
                },
                callback);
        if (ret.isDeadObject()) {
            LOG(ERROR) << "prepareModelFromCache failure: " << ret.description();
            return kDeadObject;
        }
        if (!ret.isOk()) {
            LOG(ERROR) << "prepareModelFromCache failure: " << ret.description();
            return prepareModelFailure();
        }
        const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
        if (status != V1_0::ErrorStatus::NONE) {
            LOG(ERROR) << "prepareModelFromCache returned " << toString(status);
            return prepareModelFailure(convertToV1_3(status));
        }
        return prepareModelResult(*callback, "prepareModelFromCache", kServiceName);
    }

    // version too low
    if (getDevice<V1_0::IDevice>() != nullptr) {
        LOG(ERROR) << "prepareModelFromCache called on V1_1 or V1_0 device";
        return prepareModelFailure();
    }

    // No device available
    LOG(ERROR) << "prepareModelFromCache called with no device";
    return prepareModelFailure();
}

std::pair<int, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevice::prepareModel(
        const ModelFactory& makeModel, ExecutionPreference preference, Priority priority,
        const std::optional<Deadline>& deadline, const std::string& cacheDir,
        const std::optional<CacheToken>& maybeToken) const {
    // Attempt to compile from cache if token is present.
    if (maybeToken.has_value()) {
        const auto [n, preparedModel] =
                prepareModelFromCacheInternal(deadline, cacheDir, *maybeToken);
        if (n == ANEURALNETWORKS_NO_ERROR) {
            return {n, preparedModel};
        }
    }

    // Fall back to full compilation (possibly with token) if
    // prepareModelFromCache could not be used or failed.
    const Model model = makeModel();
    return prepareModelInternal(model, preference, priority, deadline, cacheDir, maybeToken);
}

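/*
 * Illustrative usage (a sketch, not part of the build): a caller that wants
 * to tolerate a driver crash during compilation might retry once on
 * ANEURALNETWORKS_DEAD_OBJECT, on the assumption -- suggested by the
 * recoverable<> helper used above -- that a fresh call can rebind to a
 * restarted service. The retry policy itself is hypothetical:
 *
 *     auto [n, prepared] = device->prepareModel(makeModel, preference, priority,
 *                                               deadline, cacheDir, maybeToken);
 *     if (n == ANEURALNETWORKS_DEAD_OBJECT) {
 *         std::tie(n, prepared) = device->prepareModel(makeModel, preference, priority,
 *                                                      deadline, cacheDir, maybeToken);
 *     }
 */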
int64_t VersionedIDevice::getFeatureLevel() const {
    constexpr int64_t kFailure = -1;

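    // Report the Android API level that introduced the newest NN HAL version
    // this driver implements: 1.0 -> O-MR1 (27), 1.1 -> P (28), 1.2 -> Q (29),
    // 1.3 -> R (30).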
    if (getDevice<V1_3::IDevice>() != nullptr) {
        return __ANDROID_API_R__;
    } else if (getDevice<V1_2::IDevice>() != nullptr) {
        return __ANDROID_API_Q__;
    } else if (getDevice<V1_1::IDevice>() != nullptr) {
        return __ANDROID_API_P__;
    } else if (getDevice<V1_0::IDevice>() != nullptr) {
        return __ANDROID_API_O_MR1__;
    } else {
        LOG(ERROR) << "Device not available!";
        return kFailure;
    }
}

int32_t VersionedIDevice::getType() const {
    return kType;
}

const std::string& VersionedIDevice::getVersionString() const {
    return kVersionString;
}

std::pair<uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() const {
    return kNumberOfCacheFilesNeeded;
}

const std::string& VersionedIDevice::getName() const {
    return kServiceName;
}

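// Device memory allocation (IDevice::allocate) first appeared in NN HAL 1.3,
// so there is no fallback path for older drivers below; they simply fail with
// GENERAL_FAILURE.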
std::tuple<ErrorStatus, sp<IBuffer>, uint32_t> VersionedIDevice::allocate(
        const BufferDesc& desc,
        const std::vector<std::shared_ptr<VersionedIPreparedModel>>& versionedPreparedModels,
        const hidl_vec<BufferRole>& inputRoles, const hidl_vec<BufferRole>& outputRoles) const {
    const auto kFailure = std::make_tuple<ErrorStatus, sp<IBuffer>, uint32_t>(
            ErrorStatus::GENERAL_FAILURE, nullptr, 0);

    // version 1.3+ HAL
    if (getDevice<V1_3::IDevice>() != nullptr) {
        hidl_vec<sp<V1_3::IPreparedModel>> preparedModels(versionedPreparedModels.size());
        std::transform(versionedPreparedModels.begin(), versionedPreparedModels.end(),
                       preparedModels.begin(),
                       [](const auto& preparedModel) { return preparedModel->getV1_3(); });

        std::tuple<ErrorStatus, sp<IBuffer>, uint32_t> result;
        const Return<void> ret = recoverable<void, V1_3::IDevice>(
                __FUNCTION__, [&](const sp<V1_3::IDevice>& device) {
                    return device->allocate(desc, preparedModels, inputRoles, outputRoles,
                                            [&result](ErrorStatus error, const sp<IBuffer>& buffer,
                                                      uint32_t token) {
                                                result = {error, buffer, token};
                                            });
                });
        if (!ret.isOk()) {
            LOG(ERROR) << "allocate failure: " << ret.description();
            return kFailure;
        }
        return result;
    }

    // version too low or no device available
    LOG(ERROR) << "Could not handle allocate";
    return kFailure;
}

}  // namespace nn
}  // namespace android