1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "VersionedInterfaces"
18 
19 #include "VersionedInterfaces.h"
20 
21 #include <android-base/logging.h>
22 #include <android-base/properties.h>
23 #include <android-base/scopeguard.h>
24 #include <android-base/thread_annotations.h>
25 #include <android/sync.h>
26 #include <cutils/native_handle.h>
27 
28 #include <algorithm>
29 #include <chrono>
30 #include <functional>
31 #include <memory>
32 #include <string>
33 #include <tuple>
34 #include <type_traits>
35 #include <utility>
36 #include <vector>
37 
38 #include "Callbacks.h"
39 #include "ExecutionBurstController.h"
40 #include "MetaModel.h"
41 #include "Tracing.h"
42 #include "Utils.h"
43 
44 /*
45  * Some notes about HIDL interface objects and lifetimes across processes:
46  *
47  * All HIDL interface objects inherit from IBase, which itself inherits from
48  * ::android::RefBase. As such, all HIDL interface objects are reference counted
49  * and must be owned through ::android::sp (or referenced through ::android::wp).
50  * Allocating RefBase objects on the stack will log errors and may result in
51  * crashes, and deleting a RefBase object through another means (e.g., "delete",
52  * "free", or RAII-cleanup through std::unique_ptr or some equivalent) will
53  * result in double-free and/or use-after-free undefined behavior.
54  *
55  * HIDL/Binder manages the reference count of HIDL interface objects
56  * automatically across processes. If a process that references (but did not
57  * create) the HIDL interface object dies, HIDL/Binder ensures any reference
58  * count it held is properly released. (Caveat: it might be possible that
59  * HIDL/Binder behave strangely with ::android::wp references.)
60  *
61  * If the process which created the HIDL interface object dies, any call on this
62  * object from another process will result in a HIDL transport error with the
63  * code DEAD_OBJECT.
64  */
65 
66 /*
67  * Some notes about asynchronous calls across HIDL:
68  *
69  * For synchronous calls across HIDL, if an error occurs after the function was
70  * called but before it returns, HIDL will return a transport error. For
71  * example, if the message cannot be delivered to the server process or if the
72  * server process dies before returning a result, HIDL will return from the
73  * function with the appropriate transport error in the Return<> object which
74  * can be queried with Return<>::isOk(), Return<>::isDeadObject(),
75  * Return<>::description(), etc.
76  *
77  * However, HIDL offers no such error management in the case of asynchronous
78  * calls. By default, if the client launches an asynchronous task and the server
79  * fails to return a result through the callback, the client will be left
80  * waiting indefinitely for a result it will never receive.
81  *
82  * In the NNAPI, IDevice::prepareModel* and IPreparedModel::execute* (but not
83  * IPreparedModel::executeSynchronously*) are asynchronous calls across HIDL.
84  * Specifically, these asynchronous functions are called with a HIDL interface
85  * callback object (IPrepareModelCallback for IDevice::prepareModel* and
86  * IExecutionCallback for IPreparedModel::execute*) and are expected to quickly
87  * return, and the results are returned at a later time through these callback
88  * objects.
89  *
90  * To protect against the case when the server dies after the asynchronous task
91  * was called successfully but before the results could be returned, HIDL
92  * provides an object called a "hidl_death_recipient", which can be used to
93  * detect when an interface object (and more generally, the server process) has
94  * died. VersionedInterfaces uses hidl_death_recipients to detect when the
95  * driver process has died, and VersionedInterfaces will unblock any thread
96  * waiting on the results of a callback object that may otherwise not be
97  * signaled.
98  */
99 
100 namespace android {
101 namespace nn {
102 
103 // anonymous namespace
104 namespace {
105 
106 using namespace hal;
107 
108 const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
109 
sendFailureMessage(IPreparedModelCallback * cb)110 void sendFailureMessage(IPreparedModelCallback* cb) {
111     CHECK(cb != nullptr);
112     cb->notify_1_3(ErrorStatus::GENERAL_FAILURE, nullptr);
113 }
114 
115 // This class is thread safe
116 template <typename Callback>
117 class DeathHandler : public hidl_death_recipient {
118    public:
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)119     void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
120         LOG(ERROR) << "DeathHandler::serviceDied -- service unexpectedly died!";
121         std::lock_guard<std::mutex> hold(mMutex);
122         std::for_each(mCallbacks.begin(), mCallbacks.end(),
123                       [](const auto& cb) { cb->notifyAsDeadObject(); });
124     }
125 
protectCallback(const sp<Callback> & callback)126     [[nodiscard]] base::ScopeGuard<std::function<void()>> protectCallback(
127             const sp<Callback>& callback) {
128         registerCallback(callback);
129         return ::android::base::make_scope_guard(
130                 [this, callback] { unregisterCallback(callback); });
131     }
132 
133    private:
registerCallback(const sp<Callback> & callback)134     void registerCallback(const sp<Callback>& callback) {
135         std::lock_guard<std::mutex> hold(mMutex);
136         mCallbacks.push_back(callback);
137     }
138 
unregisterCallback(const sp<Callback> & callback)139     void unregisterCallback(const sp<Callback>& callback) {
140         std::lock_guard<std::mutex> hold(mMutex);
141         mCallbacks.erase(std::remove(mCallbacks.begin(), mCallbacks.end(), callback),
142                          mCallbacks.end());
143     }
144 
145     std::mutex mMutex;
146     std::vector<sp<Callback>> mCallbacks GUARDED_BY(mMutex);
147 };
148 
149 }  // anonymous namespace
150 
// Concrete (non-template) death handlers. They live outside the anonymous namespace so they can
// be named beyond this file; VersionedIPreparedModel takes an sp<IPreparedModelDeathHandler>
// (presumably forward-declared in VersionedInterfaces.h — TODO confirm).
// IDeviceDeathHandler tracks pending PreparedModelCallbacks; IPreparedModelDeathHandler tracks
// pending ExecutionCallbacks.
class IDeviceDeathHandler : public DeathHandler<PreparedModelCallback> {};
class IPreparedModelDeathHandler : public DeathHandler<ExecutionCallback> {};
153 
makeVersionedIPreparedModel(sp<V1_0::IPreparedModel> preparedModel)154 static std::pair<int, std::shared_ptr<VersionedIPreparedModel>> makeVersionedIPreparedModel(
155         sp<V1_0::IPreparedModel> preparedModel) {
156     CHECK(preparedModel != nullptr)
157             << "makeVersionedIPreparedModel passed invalid preparedModel object.";
158 
159     // create death handler object
160     sp<IPreparedModelDeathHandler> deathHandler = new IPreparedModelDeathHandler();
161 
162     // linkToDeath registers a callback that will be invoked on service death to
163     // proactively handle service crashes. If the linkToDeath call fails,
164     // asynchronous calls are susceptible to hangs if the service crashes before
165     // providing the response.
166     const Return<bool> ret = preparedModel->linkToDeath(deathHandler, 0);
167     if (ret.isDeadObject()) {
168         LOG(ERROR) << "makeVersionedIPreparedModel failed to register a death recipient for the "
169                       "IPreparedModel object because the IPreparedModel object is dead.";
170         return {ANEURALNETWORKS_DEAD_OBJECT, nullptr};
171     }
172     if (!ret.isOk()) {
173         LOG(ERROR) << "makeVersionedIPreparedModel failed to register a death recipient for the "
174                       "IPreparedModel object because of failure: "
175                    << ret.description();
176         return {ANEURALNETWORKS_OP_FAILED, nullptr};
177     }
178     if (ret != true) {
179         LOG(ERROR) << "makeVersionedIPreparedModel failed to register a death recipient for the "
180                       "IPreparedModel object.";
181         return {ANEURALNETWORKS_OP_FAILED, nullptr};
182     }
183 
184     // return a valid VersionedIPreparedModel object
185     return {ANEURALNETWORKS_NO_ERROR, std::make_shared<VersionedIPreparedModel>(
186                                               std::move(preparedModel), std::move(deathHandler))};
187 }
188 
// Stores the driver's prepared model and the death handler linked to it. The 1.2 and 1.3
// interface pointers are obtained by casting the 1.0 object; they are null when the cast yields
// nothing (e.g. the driver does not implement that version) or the cast transaction fails
// (withDefault(nullptr)).
//
// NOTE(review): mPreparedModelV1_2/V1_3 are initialized from mPreparedModelV1_0 because the
// `preparedModel` parameter has already been moved from. That is only correct if
// mPreparedModelV1_0 is declared before them in VersionedInterfaces.h — confirm before
// reordering members.
VersionedIPreparedModel::VersionedIPreparedModel(sp<V1_0::IPreparedModel> preparedModel,
                                                 sp<IPreparedModelDeathHandler> deathHandler)
    : mPreparedModelV1_0(std::move(preparedModel)),
      mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(mPreparedModelV1_0).withDefault(nullptr)),
      mPreparedModelV1_3(V1_3::IPreparedModel::castFrom(mPreparedModelV1_0).withDefault(nullptr)),
      mDeathHandler(std::move(deathHandler)) {}
195 
~VersionedIPreparedModel()196 VersionedIPreparedModel::~VersionedIPreparedModel() {
197     // It is safe to ignore any errors resulting from this unlinkToDeath call
198     // because the VersionedIPreparedModel object is already being destroyed and
199     // its underlying IPreparedModel object is no longer being used by the NN
200     // runtime.
201     mPreparedModelV1_0->unlinkToDeath(mDeathHandler).isOk();
202 }
203 
executeAsynchronously(const Request & request,MeasureTiming measure,const std::optional<Deadline> & deadline,const OptionalTimeoutDuration & loopTimeoutDuration) const204 std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::executeAsynchronously(
205         const Request& request, MeasureTiming measure, const std::optional<Deadline>& deadline,
206         const OptionalTimeoutDuration& loopTimeoutDuration) const {
207     const auto failDeadObject = []() -> std::tuple<int, std::vector<OutputShape>, Timing> {
208         return {ANEURALNETWORKS_DEAD_OBJECT, {}, kNoTiming};
209     };
210     const auto failWithStatus = [](ErrorStatus status) {
211         return getExecutionResult(status, {}, kNoTiming);
212     };
213     const auto getResults = [failDeadObject](const ExecutionCallback& cb) {
214         if (cb.isDeadObject()) {
215             return failDeadObject();
216         }
217         return getExecutionResult(cb.getStatus(), cb.getOutputShapes(), cb.getTiming());
218     };
219 
220     const sp<ExecutionCallback> callback = new ExecutionCallback();
221     const auto scoped = mDeathHandler->protectCallback(callback);
222 
223     // version 1.3+ HAL
224     if (mPreparedModelV1_3 != nullptr) {
225         const auto otp = makeTimePoint(deadline);
226         Return<ErrorStatus> ret = mPreparedModelV1_3->execute_1_3(request, measure, otp,
227                                                                   loopTimeoutDuration, callback);
228         if (ret.isDeadObject()) {
229             LOG(ERROR) << "execute_1_3 failure: " << ret.description();
230             return failDeadObject();
231         }
232         if (!ret.isOk()) {
233             LOG(ERROR) << "execute_1_3 failure: " << ret.description();
234             return failWithStatus(ErrorStatus::GENERAL_FAILURE);
235         }
236         if (ret != ErrorStatus::NONE) {
237             LOG(ERROR) << "execute_1_3 returned " << toString(static_cast<ErrorStatus>(ret));
238             return failWithStatus(ret);
239         }
240         callback->wait();
241         return getResults(*callback);
242     }
243 
244     // version 1.2 HAL
245     if (mPreparedModelV1_2 != nullptr) {
246         const bool compliant = compliantWithV1_2(request);
247         if (!compliant) {
248             LOG(ERROR) << "Could not handle execute_1_2!";
249             return failWithStatus(ErrorStatus::GENERAL_FAILURE);
250         }
251         const V1_0::Request request12 = convertToV1_2(request);
252         Return<V1_0::ErrorStatus> ret =
253                 mPreparedModelV1_2->execute_1_2(request12, measure, callback);
254         if (ret.isDeadObject()) {
255             LOG(ERROR) << "execute_1_2 failure: " << ret.description();
256             return failDeadObject();
257         }
258         if (!ret.isOk()) {
259             LOG(ERROR) << "execute_1_2 failure: " << ret.description();
260             return failWithStatus(ErrorStatus::GENERAL_FAILURE);
261         }
262         const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
263         if (status != V1_0::ErrorStatus::NONE) {
264             LOG(ERROR) << "execute_1_2 returned " << toString(status);
265             return failWithStatus(convertToV1_3(status));
266         }
267         callback->wait();
268         return getResults(*callback);
269     }
270 
271     // version 1.0 HAL
272     if (mPreparedModelV1_0 != nullptr) {
273         const bool compliant = compliantWithV1_0(request);
274         if (!compliant) {
275             LOG(ERROR) << "Could not handle execute!";
276             return failWithStatus(ErrorStatus::GENERAL_FAILURE);
277         }
278         const V1_0::Request request10 = convertToV1_0(request);
279         Return<V1_0::ErrorStatus> ret = mPreparedModelV1_0->execute(request10, callback);
280         if (ret.isDeadObject()) {
281             LOG(ERROR) << "execute failure: " << ret.description();
282             return failDeadObject();
283         }
284         if (!ret.isOk()) {
285             LOG(ERROR) << "execute failure: " << ret.description();
286             return failWithStatus(ErrorStatus::GENERAL_FAILURE);
287         }
288         const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
289         if (status != V1_0::ErrorStatus::NONE) {
290             LOG(ERROR) << "execute returned " << toString(status);
291             return failWithStatus(convertToV1_3(status));
292         }
293         callback->wait();
294         return getResults(*callback);
295     }
296 
297     // No prepared model available
298     LOG(ERROR) << "executeAsynchronously called with no preparedModel";
299     return failWithStatus(ErrorStatus::GENERAL_FAILURE);
300 }
301 
executeSynchronously(const Request & request,MeasureTiming measure,const std::optional<Deadline> & deadline,const OptionalTimeoutDuration & loopTimeoutDuration) const302 std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::executeSynchronously(
303         const Request& request, MeasureTiming measure, const std::optional<Deadline>& deadline,
304         const OptionalTimeoutDuration& loopTimeoutDuration) const {
305     const std::tuple<int, std::vector<OutputShape>, Timing> kDeadObject = {
306             ANEURALNETWORKS_DEAD_OBJECT, {}, kNoTiming};
307     const auto kFailure = getExecutionResult(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
308 
309     // version 1.3+ HAL
310     if (mPreparedModelV1_3 != nullptr) {
311         std::tuple<int, std::vector<OutputShape>, Timing> result;
312         const auto otp = makeTimePoint(deadline);
313         Return<void> ret = mPreparedModelV1_3->executeSynchronously_1_3(
314                 request, measure, otp, loopTimeoutDuration,
315                 [&result](ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
316                           const Timing& timing) {
317                     result = getExecutionResult(error, outputShapes, timing);
318                 });
319         if (ret.isDeadObject()) {
320             LOG(ERROR) << "executeSynchronously_1_3 failure: " << ret.description();
321             return kDeadObject;
322         }
323         if (!ret.isOk()) {
324             LOG(ERROR) << "executeSynchronously_1_3 failure: " << ret.description();
325             return kFailure;
326         }
327         return result;
328     }
329 
330     // version 1.2 HAL
331     if (mPreparedModelV1_2 != nullptr) {
332         const bool compliant = compliantWithV1_2(request);
333         if (!compliant) {
334             LOG(ERROR) << "Could not handle executeSynchronously!";
335             return kFailure;
336         }
337         const V1_0::Request request12 = convertToV1_2(request);
338 
339         std::tuple<int, std::vector<OutputShape>, Timing> result;
340         Return<void> ret = mPreparedModelV1_2->executeSynchronously(
341                 request12, measure,
342                 [&result](V1_0::ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
343                           const Timing& timing) {
344                     result = getExecutionResult(convertToV1_3(error), outputShapes, timing);
345                 });
346         if (ret.isDeadObject()) {
347             LOG(ERROR) << "executeSynchronously failure: " << ret.description();
348             return kDeadObject;
349         }
350         if (!ret.isOk()) {
351             LOG(ERROR) << "executeSynchronously failure: " << ret.description();
352             return kFailure;
353         }
354         return result;
355     }
356 
357     // Fallback to asynchronous execution.
358     return executeAsynchronously(request, measure, deadline, loopTimeoutDuration);
359 }
360 
execute(const Request & request,MeasureTiming measure,const std::optional<Deadline> & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,bool preferSynchronous) const361 std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::execute(
362         const Request& request, MeasureTiming measure, const std::optional<Deadline>& deadline,
363         const OptionalTimeoutDuration& loopTimeoutDuration, bool preferSynchronous) const {
364     if (preferSynchronous) {
365         VLOG(EXECUTION) << "Before executeSynchronously() " << SHOW_IF_DEBUG(toString(request));
366         return executeSynchronously(request, measure, deadline, loopTimeoutDuration);
367     }
368 
369     VLOG(EXECUTION) << "Before executeAsynchronously() " << SHOW_IF_DEBUG(toString(request));
370     return executeAsynchronously(request, measure, deadline, loopTimeoutDuration);
371 }
372 
// This is the amount of time the ExecutionBurstController should spend polling
// the FMQ to see if it has data available before it should fall back to
// waiting on the futex.
//
// Defaults to 50 microseconds. On NN_DEBUGGABLE builds the window (in microseconds, minimum 0)
// can be overridden with the system property "debug.nn.burst-controller-polling-window". The
// historical, misspelled key "debug.nn.burst-conrtoller-polling-window" is still honored as a
// fallback so existing device configurations and scripts keep working. When both are set, the
// correctly spelled key wins.
static std::chrono::microseconds getPollingTimeWindow() {
    constexpr int32_t defaultPollingTimeWindow = 50;
#ifdef NN_DEBUGGABLE
    constexpr int32_t minPollingTimeWindow = 0;
    // The legacy key is read first and becomes the default for the correctly spelled key.
    const int32_t legacyPollingTimeWindow =
            base::GetIntProperty("debug.nn.burst-conrtoller-polling-window",
                                 defaultPollingTimeWindow, minPollingTimeWindow);
    const int32_t selectedPollingTimeWindow =
            base::GetIntProperty("debug.nn.burst-controller-polling-window",
                                 legacyPollingTimeWindow, minPollingTimeWindow);
    return std::chrono::microseconds{selectedPollingTimeWindow};
#else
    return std::chrono::microseconds{defaultPollingTimeWindow};
#endif  // NN_DEBUGGABLE
}
388 
configureExecutionBurst(bool preferPowerOverLatency) const389 std::shared_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExecutionBurst(
390         bool preferPowerOverLatency) const {
391     if (mPreparedModelV1_2 == nullptr) {
392         return nullptr;
393     }
394     const auto pollingTimeWindow =
395             (preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());
396     return ExecutionBurstController::create(mPreparedModelV1_2, pollingTimeWindow);
397 }
398 
getCapabilitiesFunction(V1_3::IDevice * device)399 static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_3::IDevice* device) {
400     CHECK(device != nullptr);
401     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_3");
402     const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
403     std::pair<ErrorStatus, Capabilities> result = kFailure;
404     const Return<void> ret = device->getCapabilities_1_3(
405             [&result](ErrorStatus error, const Capabilities& capabilities) {
406                 result = std::make_pair(error, capabilities);
407             });
408     if (!ret.isOk()) {
409         LOG(ERROR) << "getCapabilities_1_3 failure: " << ret.description();
410         return kFailure;
411     }
412     return result;
413 }
414 
415 std::tuple<int, hal::hidl_handle, sp<hal::IFencedExecutionCallback>, hal::Timing>
executeFenced(const hal::Request & request,const hal::hidl_vec<hal::hidl_handle> & waitFor,MeasureTiming measure,const std::optional<Deadline> & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,const hal::OptionalTimeoutDuration & timeoutDurationAfterFence)416 VersionedIPreparedModel::executeFenced(
417         const hal::Request& request, const hal::hidl_vec<hal::hidl_handle>& waitFor,
418         MeasureTiming measure, const std::optional<Deadline>& deadline,
419         const OptionalTimeoutDuration& loopTimeoutDuration,
420         const hal::OptionalTimeoutDuration& timeoutDurationAfterFence) {
421     // version 1.3+ HAL
422     hal::hidl_handle syncFence;
423     sp<hal::IFencedExecutionCallback> dispatchCallback;
424     hal::Timing timing = {UINT64_MAX, UINT64_MAX};
425     if (mPreparedModelV1_3 != nullptr) {
426         ErrorStatus errorStatus;
427         const auto otp = makeTimePoint(deadline);
428         Return<void> ret = mPreparedModelV1_3->executeFenced(
429                 request, waitFor, measure, otp, loopTimeoutDuration, timeoutDurationAfterFence,
430                 [&syncFence, &errorStatus, &dispatchCallback](
431                         ErrorStatus error, const hidl_handle& handle,
432                         const sp<hal::IFencedExecutionCallback>& callback) {
433                     syncFence = handle;
434                     errorStatus = error;
435                     dispatchCallback = callback;
436                 });
437         if (!ret.isOk()) {
438             LOG(ERROR) << "executeFenced failure: " << ret.description();
439             return std::make_tuple(ANEURALNETWORKS_OP_FAILED, hal::hidl_handle(nullptr), nullptr,
440                                    timing);
441         }
442         if (errorStatus != ErrorStatus::NONE) {
443             LOG(ERROR) << "executeFenced returned "
444                        << toString(static_cast<ErrorStatus>(errorStatus));
445             return std::make_tuple(convertErrorStatusToResultCode(errorStatus),
446                                    hal::hidl_handle(nullptr), nullptr, timing);
447         }
448         return std::make_tuple(ANEURALNETWORKS_NO_ERROR, syncFence, dispatchCallback, timing);
449     }
450 
451     // fallback to synchronous execution if sync_fence is not supported
452     // first wait for all sync fences to be ready.
453     LOG(INFO) << "No drivers able to handle sync fences, falling back to regular execution";
454     for (const auto& fenceHandle : waitFor) {
455         if (!fenceHandle.getNativeHandle()) {
456             return std::make_tuple(ANEURALNETWORKS_BAD_DATA, hal::hidl_handle(nullptr), nullptr,
457                                    timing);
458         }
459         int syncFd = fenceHandle.getNativeHandle()->data[0];
460         if (syncFd <= 0) {
461             return std::make_tuple(ANEURALNETWORKS_BAD_DATA, hal::hidl_handle(nullptr), nullptr,
462                                    timing);
463         }
464         auto r = syncWait(syncFd, -1);
465         if (r != FenceState::SIGNALED) {
466             LOG(ERROR) << "syncWait failed, fd: " << syncFd;
467             return std::make_tuple(ANEURALNETWORKS_OP_FAILED, hal::hidl_handle(nullptr), nullptr,
468                                    timing);
469         }
470     }
471     int errorCode;
472     std::tie(errorCode, std::ignore, timing) =
473             executeSynchronously(request, measure, deadline, loopTimeoutDuration);
474     return std::make_tuple(errorCode, hal::hidl_handle(nullptr), nullptr, timing);
475 }
476 
getCapabilitiesFunction(V1_2::IDevice * device)477 static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_2::IDevice* device) {
478     CHECK(device != nullptr);
479     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_2");
480     const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
481     std::pair<ErrorStatus, Capabilities> result = kFailure;
482     const Return<void> ret = device->getCapabilities_1_2(
483             [&result](V1_0::ErrorStatus error, const V1_2::Capabilities& capabilities) {
484                 result = std::make_pair(convertToV1_3(error), convertToV1_3(capabilities));
485             });
486     if (!ret.isOk()) {
487         LOG(ERROR) << "getCapabilities_1_2 failure: " << ret.description();
488         return kFailure;
489     }
490     return result;
491 }
492 
getCapabilitiesFunction(V1_1::IDevice * device)493 static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_1::IDevice* device) {
494     CHECK(device != nullptr);
495     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_1");
496     const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
497     std::pair<ErrorStatus, Capabilities> result = kFailure;
498     const Return<void> ret = device->getCapabilities_1_1(
499             [&result](V1_0::ErrorStatus error, const V1_1::Capabilities& capabilities) {
500                 // Time taken to convert capabilities is trivial
501                 result = std::make_pair(convertToV1_3(error), convertToV1_3(capabilities));
502             });
503     if (!ret.isOk()) {
504         LOG(ERROR) << "getCapabilities_1_1 failure: " << ret.description();
505         return kFailure;
506     }
507     return result;
508 }
509 
getCapabilitiesFunction(V1_0::IDevice * device)510 static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_0::IDevice* device) {
511     CHECK(device != nullptr);
512     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities");
513     const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
514     std::pair<ErrorStatus, Capabilities> result = kFailure;
515     const Return<void> ret = device->getCapabilities(
516             [&result](V1_0::ErrorStatus error, const V1_0::Capabilities& capabilities) {
517                 // Time taken to convert capabilities is trivial
518                 result = std::make_pair(convertToV1_3(error), convertToV1_3(capabilities));
519             });
520     if (!ret.isOk()) {
521         LOG(ERROR) << "getCapabilities failure: " << ret.description();
522         return kFailure;
523     }
524     return result;
525 }
526 
getSupportedExtensionsFunction(V1_2::IDevice * device)527 static std::pair<ErrorStatus, hidl_vec<Extension>> getSupportedExtensionsFunction(
528         V1_2::IDevice* device) {
529     CHECK(device != nullptr);
530     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getSupportedExtensions");
531     const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
532     std::pair<ErrorStatus, hidl_vec<Extension>> result = kFailure;
533     const Return<void> ret = device->getSupportedExtensions(
534             [&result](V1_0::ErrorStatus error, const hidl_vec<Extension>& extensions) {
535                 result = std::make_pair(convertToV1_3(error), extensions);
536             });
537     if (!ret.isOk()) {
538         LOG(ERROR) << "getSupportedExtensions failure: " << ret.description();
539         return kFailure;
540     }
541     return result;
542 }
543 
getSupportedExtensionsFunction(V1_0::IDevice * device)544 static std::pair<ErrorStatus, hidl_vec<Extension>> getSupportedExtensionsFunction(
545         V1_0::IDevice* device) {
546     CHECK(device != nullptr);
547     return {ErrorStatus::NONE, {/* No extensions. */}};
548 }
549 
getTypeFunction(V1_2::IDevice * device)550 static int32_t getTypeFunction(V1_2::IDevice* device) {
551     CHECK(device != nullptr);
552     constexpr int32_t kFailure = -1;
553     int32_t result = kFailure;
554     const Return<void> ret =
555             device->getType([&result](V1_0::ErrorStatus error, DeviceType deviceType) {
556                 if (error == V1_0::ErrorStatus::NONE) {
557                     result = static_cast<int32_t>(deviceType);
558                 }
559             });
560     if (!ret.isOk()) {
561         LOG(ERROR) << "getType failure: " << ret.description();
562         return kFailure;
563     }
564     return result;
565 }
566 
getTypeFunction(V1_0::IDevice * device)567 static int32_t getTypeFunction(V1_0::IDevice* device) {
568     CHECK(device != nullptr);
569     return ANEURALNETWORKS_DEVICE_UNKNOWN;
570 }
571 
getVersionStringFunction(V1_2::IDevice * device)572 static std::pair<ErrorStatus, hidl_string> getVersionStringFunction(V1_2::IDevice* device) {
573     CHECK(device != nullptr);
574     const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""};
575     std::pair<ErrorStatus, hidl_string> result = kFailure;
576     const Return<void> ret = device->getVersionString(
577             [&result](V1_0::ErrorStatus error, const hidl_string& version) {
578                 result = std::make_pair(convertToV1_3(error), version);
579             });
580     if (!ret.isOk()) {
581         LOG(ERROR) << "getVersion failure: " << ret.description();
582         return kFailure;
583     }
584     return result;
585 }
586 
getVersionStringFunction(V1_0::IDevice * device)587 static std::pair<ErrorStatus, hidl_string> getVersionStringFunction(V1_0::IDevice* device) {
588     CHECK(device != nullptr);
589     return {ErrorStatus::NONE, "UNKNOWN"};
590 }
591 
getNumberOfCacheFilesNeededFunction(V1_2::IDevice * device)592 static std::tuple<ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeededFunction(
593         V1_2::IDevice* device) {
594     CHECK(device != nullptr);
595     constexpr std::tuple<ErrorStatus, uint32_t, uint32_t> kFailure = {ErrorStatus::GENERAL_FAILURE,
596                                                                       0, 0};
597     std::tuple<ErrorStatus, uint32_t, uint32_t> result = kFailure;
598     const Return<void> ret = device->getNumberOfCacheFilesNeeded(
599             [&result](V1_0::ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
600                 result = {convertToV1_3(error), numModelCache, numDataCache};
601             });
602     if (!ret.isOk()) {
603         LOG(ERROR) << "getNumberOfCacheFilesNeeded failure: " << ret.description();
604         return kFailure;
605     }
606     return result;
607 }
608 
getNumberOfCacheFilesNeededFunction(V1_0::IDevice * device)609 static std::tuple<ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeededFunction(
610         V1_0::IDevice* device) {
611     CHECK(device != nullptr);
612     return {ErrorStatus::NONE, 0, 0};
613 }
614 
// Device properties gathered by initializeFunction() through the get*Function() helpers.
// NOTE(review): initializeFunction() builds this with positional aggregate initialization
// (InitialData{/*.capabilities=*/...}), so the member order here is significant.
struct InitialData {
    hal::Capabilities capabilities;
    hal::hidl_vec<hal::Extension> supportedExtensions;
    // presumably the result of getTypeFunction() (a DeviceType value or
    // ANEURALNETWORKS_DEVICE_UNKNOWN) — TODO confirm against initializeFunction()
    int32_t type;
    hal::hidl_string versionString;
    // presumably {numModelCacheFiles, numDataCacheFiles} — TODO confirm
    std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded;
};
622 
623 template <typename Device>
initializeFunction(Device * device)624 static std::optional<InitialData> initializeFunction(Device* device) {
625     CHECK(device != nullptr);
626 
627     auto [capabilitiesStatus, capabilities] = getCapabilitiesFunction(device);
628     if (capabilitiesStatus != ErrorStatus::NONE) {
629         LOG(ERROR) << "IDevice::getCapabilities* returned the error "
630                    << toString(capabilitiesStatus);
631         return std::nullopt;
632     }
633     VLOG(MANAGER) << "Capab " << toString(capabilities);
634 
635     auto [versionStatus, versionString] = getVersionStringFunction(device);
636     if (versionStatus != ErrorStatus::NONE) {
637         LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(versionStatus);
638         return std::nullopt;
639     }
640 
641     const int32_t type = getTypeFunction(device);
642     if (type == -1) {
643         LOG(ERROR) << "IDevice::getType returned an error";
644         return std::nullopt;
645     }
646 
647     auto [extensionsStatus, supportedExtensions] = getSupportedExtensionsFunction(device);
648     if (extensionsStatus != ErrorStatus::NONE) {
649         LOG(ERROR) << "IDevice::getSupportedExtensions returned the error "
650                    << toString(extensionsStatus);
651         return std::nullopt;
652     }
653 
654     const auto [cacheFilesStatus, numModelCacheFiles, numDataCacheFiles] =
655             getNumberOfCacheFilesNeededFunction(device);
656     if (cacheFilesStatus != ErrorStatus::NONE) {
657         LOG(ERROR) << "IDevice::getNumberOfCacheFilesNeeded returned the error "
658                    << toString(cacheFilesStatus);
659         return std::nullopt;
660     }
661 
662     // The following limit is enforced by VTS
663     constexpr uint32_t maxNumCacheFiles =
664             static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES);
665     if (numModelCacheFiles > maxNumCacheFiles || numDataCacheFiles > maxNumCacheFiles) {
666         LOG(ERROR)
667                 << "IDevice::getNumberOfCacheFilesNeeded returned invalid number of cache files: "
668                    "numModelCacheFiles = "
669                 << numModelCacheFiles << ", numDataCacheFiles = " << numDataCacheFiles
670                 << ", maxNumCacheFiles = " << maxNumCacheFiles;
671         return std::nullopt;
672     }
673 
674     return InitialData{
675             /*.capabilities=*/std::move(capabilities),
676             /*.supportedExtensions=*/std::move(supportedExtensions),
677             /*.type=*/type,
678             /*.versionString=*/std::move(versionString),
679             /*.numberOfCacheFilesNeeded=*/{numModelCacheFiles, numDataCacheFiles},
680     };
681 }
682 
683 template <typename Core>
initialize(const Core & core)684 std::optional<InitialData> initialize(const Core& core) {
685     // version 1.3+ HAL
686     if (const auto device = core.template getDevice<V1_3::IDevice>()) {
687         return initializeFunction(device.get());
688     }
689 
690     // version 1.2 HAL
691     if (const auto device = core.template getDevice<V1_2::IDevice>()) {
692         return initializeFunction(device.get());
693     }
694 
695     // version 1.1 HAL
696     if (const auto device = core.template getDevice<V1_1::IDevice>()) {
697         return initializeFunction(device.get());
698     }
699 
700     // version 1.0 HAL
701     if (const auto device = core.template getDevice<V1_0::IDevice>()) {
702         return initializeFunction(device.get());
703     }
704 
705     // No device available
706     LOG(ERROR) << "Device not available!";
707     return std::nullopt;
708 }
709 
create(std::string serviceName,const DeviceFactory & makeDevice)710 std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName,
711                                                            const DeviceFactory& makeDevice) {
712     CHECK(makeDevice != nullptr)
713             << "VersionedIDevice::create passed invalid device factory object.";
714 
715     // get handle to IDevice object
716     sp<V1_0::IDevice> device = makeDevice(/*blocking=*/true);
717     if (device == nullptr) {
718         VLOG(DRIVER) << "VersionedIDevice::create got a null IDevice for " << serviceName;
719         return nullptr;
720     }
721 
722     auto core = Core::create(std::move(device));
723     if (!core.has_value()) {
724         LOG(ERROR) << "VersionedIDevice::create failed to create Core.";
725         return nullptr;
726     }
727 
728     auto initialData = initialize(*core);
729     if (!initialData.has_value()) {
730         LOG(ERROR) << "VersionedIDevice::create failed to initialize.";
731         return nullptr;
732     }
733 
734     auto [capabilities, supportedExtensions, type, versionString, numberOfCacheFilesNeeded] =
735             std::move(*initialData);
736     return std::make_shared<VersionedIDevice>(
737             std::move(capabilities), std::move(supportedExtensions), type, std::move(versionString),
738             numberOfCacheFilesNeeded, std::move(serviceName), makeDevice, std::move(core.value()));
739 }
740 
// Stores the static device information gathered by create(), the service name, the device
// factory (kMakeDevice, used by recoverable()/wait() to reconnect) and the Core holding the
// live IDevice handles. Arguments taken by value are moved into their members; makeDevice is
// copied.
VersionedIDevice::VersionedIDevice(hal::Capabilities capabilities,
                                   std::vector<hal::Extension> supportedExtensions, int32_t type,
                                   std::string versionString,
                                   std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded,
                                   std::string serviceName, const DeviceFactory& makeDevice,
                                   Core core)
    : kCapabilities(std::move(capabilities)),
      kSupportedExtensions(std::move(supportedExtensions)),
      kType(type),
      kVersionString(std::move(versionString)),
      kNumberOfCacheFilesNeeded(numberOfCacheFilesNeeded),
      kServiceName(std::move(serviceName)),
      kMakeDevice(makeDevice),
      mCore(std::move(core)) {}
755 
create(sp<V1_0::IDevice> device)756 std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) {
757     CHECK(device != nullptr) << "VersionedIDevice::Core::create passed invalid device object.";
758 
759     // create death handler object
760     sp<IDeviceDeathHandler> deathHandler = new IDeviceDeathHandler();
761 
762     // linkToDeath registers a callback that will be invoked on service death to
763     // proactively handle service crashes. If the linkToDeath call fails,
764     // asynchronous calls are susceptible to hangs if the service crashes before
765     // providing the response.
766     const Return<bool> ret = device->linkToDeath(deathHandler, 0);
767     if (!ret.isOk()) {
768         LOG(ERROR) << "VersionedIDevice::Core::create failed to register a death recipient for the "
769                       "IDevice object because of failure: "
770                    << ret.description();
771         return {};
772     }
773     if (ret != true) {
774         LOG(ERROR) << "VersionedIDevice::Core::create failed to register a death recipient for the "
775                       "IDevice object.";
776         return {};
777     }
778 
779     // return a valid Core object
780     return Core(std::move(device), std::move(deathHandler));
781 }
782 
// Takes the V1_0 device handle and the death handler already linked to it (Core::create()
// CHECKs the device is non-null and performs the link before calling this).
// The V1_1/V1_2/V1_3 handles are derived with castFrom(...).withDefault(nullptr), so each is
// null if the cast fails (e.g. the service does not implement that version) or the cast call
// itself fails. HIDL guarantees all V1_1 interfaces inherit from their corresponding V1_0
// interfaces, which is why casting from the V1_0 handle is sufficient.
VersionedIDevice::Core::Core(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler)
    : mDeviceV1_0(std::move(device)),
      mDeviceV1_1(V1_1::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeviceV1_2(V1_2::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeviceV1_3(V1_3::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeathHandler(std::move(deathHandler)) {}
790 
~Core()791 VersionedIDevice::Core::~Core() {
792     if (mDeathHandler != nullptr) {
793         CHECK(mDeviceV1_0 != nullptr);
794         // It is safe to ignore any errors resulting from this unlinkToDeath call
795         // because the VersionedIDevice::Core object is already being destroyed and
796         // its underlying IDevice object is no longer being used by the NN runtime.
797         mDeviceV1_0->unlinkToDeath(mDeathHandler).isOk();
798     }
799 }
800 
Core(Core && other)801 VersionedIDevice::Core::Core(Core&& other) noexcept
802     : mDeviceV1_0(std::move(other.mDeviceV1_0)),
803       mDeviceV1_1(std::move(other.mDeviceV1_1)),
804       mDeviceV1_2(std::move(other.mDeviceV1_2)),
805       mDeviceV1_3(std::move(other.mDeviceV1_3)),
806       mDeathHandler(std::move(other.mDeathHandler)) {
807     other.mDeathHandler = nullptr;
808 }
809 
operator =(Core && other)810 VersionedIDevice::Core& VersionedIDevice::Core::operator=(Core&& other) noexcept {
811     if (this != &other) {
812         mDeviceV1_0 = std::move(other.mDeviceV1_0);
813         mDeviceV1_1 = std::move(other.mDeviceV1_1);
814         mDeviceV1_2 = std::move(other.mDeviceV1_2);
815         mDeviceV1_3 = std::move(other.mDeviceV1_3);
816         mDeathHandler = std::move(other.mDeathHandler);
817         other.mDeathHandler = nullptr;
818     }
819     return *this;
820 }
821 
822 template <typename T_IDevice>
getDeviceAndDeathHandler() const823 std::pair<sp<T_IDevice>, sp<IDeviceDeathHandler>> VersionedIDevice::Core::getDeviceAndDeathHandler()
824         const {
825     return {getDevice<T_IDevice>(), mDeathHandler};
826 }
827 
// Invokes fn(device) while `callback` is registered with `deathHandler`, so that a service
// death during the call leads to the callback being signalled (see the case analysis below).
// If the call fails at the transport level, or succeeds but returns a status other than
// T_Return::NONE, the error is logged and a failure message is sent to the callback.
// In every case this then blocks on callback->wait() before returning fn's Return value.
template <typename T_Return, typename T_IDevice, typename T_Callback>
Return<T_Return> callProtected(const char* context,
                               const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
                               const sp<T_IDevice>& device, const sp<T_Callback>& callback,
                               const sp<IDeviceDeathHandler>& deathHandler) {
    // `scoped` lives until this function returns, i.e. past callback->wait() below.
    // NOTE(review): presumably its destructor unregisters the callback; confirm in
    // IDeviceDeathHandler::protectCallback.
    const auto scoped = deathHandler->protectCallback(callback);
    Return<T_Return> ret = fn(device);
    // Suppose there was a transport error.  We have the following cases:
    // 1. Either not due to a dead device, or due to a device that was
    //    already dead at the time of the call to protectCallback().  In
    //    this case, the callback was never signalled.
    // 2. Due to a device that died after the call to protectCallback() but
    //    before fn() completed.  In this case, the callback was (or will
    //    be) signalled by the deathHandler.
    // Furthermore, what if there was no transport error, but the ErrorStatus is
    // other than NONE?  We'll conservatively signal the callback anyway, just in
    // case the driver was sloppy and failed to do so.
    if (!ret.isOk() || ret != T_Return::NONE) {
        // What if the deathHandler has signalled or will signal the callback?
        // This is fine -- we're permitted to signal multiple times; and we're
        // sending the same signal that the deathHandler does.
        //
        // What if the driver signalled the callback?  Then this signal is
        // ignored.

        if (ret.isOk()) {
            LOG(ERROR) << context << " returned " << toString(static_cast<T_Return>(ret));
        } else {
            LOG(ERROR) << context << " failure: " << ret.description();
        }
        sendFailureMessage(callback.get());
    }
    // Block until the callback has been signalled (by the driver, the death handler, or the
    // failure message above).
    callback->wait();
    return ret;
}
863 template <typename T_Return, typename T_IDevice>
callProtected(const char *,const std::function<Return<T_Return> (const sp<T_IDevice> &)> & fn,const sp<T_IDevice> & device,const std::nullptr_t &,const sp<IDeviceDeathHandler> &)864 Return<T_Return> callProtected(const char*,
865                                const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
866                                const sp<T_IDevice>& device, const std::nullptr_t&,
867                                const sp<IDeviceDeathHandler>&) {
868     return fn(device);
869 }
870 
871 template <typename T_Return, typename T_IDevice, typename T_Callback>
recoverable(const char * context,const std::function<Return<T_Return> (const sp<T_IDevice> &)> & fn,const T_Callback & callback) const872 Return<T_Return> VersionedIDevice::recoverable(
873         const char* context, const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
874         const T_Callback& callback) const EXCLUDES(mMutex) {
875     CHECK_EQ(callback == nullptr, (std::is_same_v<T_Callback, std::nullptr_t>));
876 
877     sp<T_IDevice> device;
878     sp<IDeviceDeathHandler> deathHandler;
879     std::tie(device, deathHandler) = getDeviceAndDeathHandler<T_IDevice>();
880 
881     Return<T_Return> ret = callProtected(context, fn, device, callback, deathHandler);
882 
883     if (ret.isDeadObject()) {
884         {
885             std::unique_lock lock(mMutex);
886             // It's possible that another device has already done the recovery.
887             // It's harmless but wasteful for us to do so in this case.
888             auto pingReturn = mCore.getDevice<T_IDevice>()->ping();
889             if (pingReturn.isDeadObject()) {
890                 VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering "
891                              << kServiceName;
892                 sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/false);
893                 if (recoveredDevice == nullptr) {
894                     VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDEVICE for "
895                                  << kServiceName;
896                     return ret;
897                 }
898 
899                 auto core = Core::create(std::move(recoveredDevice));
900                 if (!core.has_value()) {
901                     LOG(ERROR) << "VersionedIDevice::recoverable failed to create Core.";
902                     return ret;
903                 }
904 
905                 mCore = std::move(core.value());
906             } else {
907                 VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context
908                              << ") -- Someone else recovered " << kServiceName;
909                 // Might still have a transport error, which we need to check
910                 // before pingReturn goes out of scope.
911                 (void)pingReturn.isOk();
912             }
913             std::tie(device, deathHandler) = mCore.getDeviceAndDeathHandler<T_IDevice>();
914         }
915         ret = callProtected(context, fn, device, callback, deathHandler);
916         // It's possible that the device died again, but we're only going to
917         // attempt recovery once per call to recoverable().
918     }
919     return ret;
920 }
921 
wait() const922 int VersionedIDevice::wait() const {
923     std::unique_lock lock(mMutex);
924     // It's possible that another device has already done the recovery.
925     // It's harmless but wasteful for us to do so in this case.
926     auto pingReturn = mCore.getDevice<V1_0::IDevice>()->ping();
927     if (pingReturn.isDeadObject()) {
928         VLOG(DRIVER) << "VersionedIDevice::wait -- Recovering " << kServiceName;
929         sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/true);
930         if (recoveredDevice == nullptr) {
931             LOG(ERROR) << "VersionedIDevice::wait got a null IDevice for " << kServiceName;
932             return ANEURALNETWORKS_OP_FAILED;
933         }
934 
935         auto core = Core::create(std::move(recoveredDevice));
936         if (!core.has_value()) {
937             LOG(ERROR) << "VersionedIDevice::wait failed to create Core.";
938             return ANEURALNETWORKS_OP_FAILED;
939         }
940 
941         mCore = std::move(core.value());
942     } else if (!pingReturn.isOk()) {
943         LOG(ERROR) << "VersionedIDevice::wait failed -- IDevice::ping returned "
944                    << pingReturn.description();
945         return ANEURALNETWORKS_OP_FAILED;
946     }
947 
948     return ANEURALNETWORKS_NO_ERROR;
949 }
950 
getCapabilities() const951 const Capabilities& VersionedIDevice::getCapabilities() const {
952     return kCapabilities;
953 }
954 
getSupportedExtensions() const955 const std::vector<Extension>& VersionedIDevice::getSupportedExtensions() const {
956     return kSupportedExtensions;
957 }
958 
getSupportedOperations(const MetaModel & metaModel) const959 std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
960         const MetaModel& metaModel) const {
961     const std::pair<ErrorStatus, hidl_vec<bool>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
962     std::pair<ErrorStatus, hidl_vec<bool>> result;
963 
964     const Model& model = metaModel.getModel();
965 
966     auto noneSupported = [&model] {
967         hidl_vec<bool> supported(model.main.operations.size());
968         std::fill(supported.begin(), supported.end(), false);
969         return std::make_pair(ErrorStatus::NONE, std::move(supported));
970     };
971 
972     auto remappedResult = [&model](const std::pair<ErrorStatus, hidl_vec<bool>>& result,
973                                    const std::function<uint32_t(uint32_t)>&
974                                            slicedModelOperationIndexToModelOperationIndex) {
975         const ErrorStatus status = result.first;
976         const hidl_vec<bool>& supported = result.second;
977         hidl_vec<bool> remappedSupported(model.main.operations.size());
978         std::fill(remappedSupported.begin(), remappedSupported.end(), false);
979         for (size_t i = 0; i < supported.size(); ++i) {
980             if (supported[i]) {
981                 remappedSupported[slicedModelOperationIndexToModelOperationIndex(i)] = true;
982             }
983         }
984         return std::make_pair(status, std::move(remappedSupported));
985     };
986 
987     // version 1.3+ HAL
988     if (getDevice<V1_3::IDevice>() != nullptr) {
989         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_3");
990         Return<void> ret = recoverable<void, V1_3::IDevice>(
991                 __FUNCTION__, [&model, &result](const sp<V1_3::IDevice>& device) {
992                     return device->getSupportedOperations_1_3(
993                             model, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
994                                 result = std::make_pair(error, supported);
995                             });
996                 });
997         if (!ret.isOk()) {
998             LOG(ERROR) << "getSupportedOperations_1_3 failure: " << ret.description();
999             return kFailure;
1000         }
1001         return result;
1002     }
1003 
1004     // version 1.2 HAL
1005     if (getDevice<V1_2::IDevice>() != nullptr) {
1006         const bool compliant = compliantWithV1_2(model);
1007         V1_2::Model model12;
1008         std::function<uint32_t(uint32_t)> slicedModelOperationIndexToModelOperationIndex;
1009         if (compliant) {
1010             model12 = convertToV1_2(model);
1011         } else {
1012             const auto slice12 = metaModel.getSliceV1_2();
1013             if (!slice12.has_value()) {
1014                 return noneSupported();
1015             }
1016             std::tie(model12, slicedModelOperationIndexToModelOperationIndex) = *slice12;
1017         }
1018         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_2");
1019         Return<void> ret = recoverable<void, V1_2::IDevice>(
1020                 __FUNCTION__, [&model12, &result](const sp<V1_2::IDevice>& device) {
1021                     return device->getSupportedOperations_1_2(
1022                             model12,
1023                             [&result](V1_0::ErrorStatus error, const hidl_vec<bool>& supported) {
1024                                 result = std::make_pair(convertToV1_3(error), supported);
1025                             });
1026                 });
1027         if (!ret.isOk()) {
1028             LOG(ERROR) << "getSupportedOperations_1_2 failure: " << ret.description();
1029             return kFailure;
1030         }
1031         if (!compliant) {
1032             return remappedResult(result, slicedModelOperationIndexToModelOperationIndex);
1033         }
1034         return result;
1035     }
1036 
1037     // version 1.1 HAL
1038     if (getDevice<V1_1::IDevice>() != nullptr) {
1039         const bool compliant = compliantWithV1_1(model);
1040         V1_1::Model model11;
1041         std::function<uint32_t(uint32_t)> slicedModelOperationIndexToModelOperationIndex;
1042         if (compliant) {
1043             model11 = convertToV1_1(model);
1044         } else {
1045             const auto slice11 = metaModel.getSliceV1_1();
1046             if (!slice11.has_value()) {
1047                 return noneSupported();
1048             }
1049             std::tie(model11, slicedModelOperationIndexToModelOperationIndex) = *slice11;
1050         }
1051         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_1");
1052         Return<void> ret = recoverable<void, V1_1::IDevice>(
1053                 __FUNCTION__, [&model11, &result](const sp<V1_1::IDevice>& device) {
1054                     return device->getSupportedOperations_1_1(
1055                             model11,
1056                             [&result](V1_0::ErrorStatus error, const hidl_vec<bool>& supported) {
1057                                 result = std::make_pair(convertToV1_3(error), supported);
1058                             });
1059                 });
1060         if (!ret.isOk()) {
1061             LOG(ERROR) << "getSupportedOperations_1_1 failure: " << ret.description();
1062             return kFailure;
1063         }
1064         if (!compliant) {
1065             return remappedResult(result, slicedModelOperationIndexToModelOperationIndex);
1066         }
1067         return result;
1068     }
1069 
1070     // version 1.0 HAL
1071     if (getDevice<V1_0::IDevice>() != nullptr) {
1072         const bool compliant = compliantWithV1_0(model);
1073         V1_0::Model model10;
1074         std::function<uint32_t(uint32_t)> slicedModelOperationIndexToModelOperationIndex;
1075         if (compliant) {
1076             model10 = convertToV1_0(model);
1077         } else {
1078             const auto slice10 = metaModel.getSliceV1_0();
1079             if (!slice10.has_value()) {
1080                 return noneSupported();
1081             }
1082             std::tie(model10, slicedModelOperationIndexToModelOperationIndex) = *slice10;
1083         }
1084         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations");
1085         Return<void> ret = recoverable<void, V1_0::IDevice>(
1086                 __FUNCTION__, [&model10, &result](const sp<V1_0::IDevice>& device) {
1087                     return device->getSupportedOperations(
1088                             model10,
1089                             [&result](V1_0::ErrorStatus error, const hidl_vec<bool>& supported) {
1090                                 result = std::make_pair(convertToV1_3(error), supported);
1091                             });
1092                 });
1093         if (!ret.isOk()) {
1094             LOG(ERROR) << "getSupportedOperations failure: " << ret.description();
1095             return kFailure;
1096         }
1097         if (!compliant) {
1098             return remappedResult(result, slicedModelOperationIndexToModelOperationIndex);
1099         }
1100         return result;
1101     }
1102 
1103     // No device available
1104     LOG(ERROR) << "Device not available!";
1105     return kFailure;
1106 }
1107 
1108 // Opens cache file by filename and sets the handle to the opened fd. Returns false on fail. The
1109 // handle is expected to come in as empty, and is only set to a fd when the function returns true.
1110 // The file descriptor is always opened with both read and write permission.
createCacheHandle(const std::string & cache,bool createIfNotExist,hidl_handle * handle)1111 static bool createCacheHandle(const std::string& cache, bool createIfNotExist,
1112                               hidl_handle* handle) {
1113     CHECK(handle->getNativeHandle() == nullptr);
1114     int fd = open(cache.c_str(), createIfNotExist ? (O_RDWR | O_CREAT) : O_RDWR, S_IRUSR | S_IWUSR);
1115     NN_RET_CHECK_GE(fd, 0);
1116     native_handle_t* cacheNativeHandle = native_handle_create(1, 0);
1117     if (cacheNativeHandle == nullptr) {
1118         close(fd);
1119         return false;
1120     }
1121     cacheNativeHandle->data[0] = fd;
1122     handle->setTo(cacheNativeHandle, /*shouldOwn=*/true);
1123     return true;
1124 }
1125 
1126 // Opens a list of cache files and returns the handle vector. Returns empty vector on fail.
1127 // The file descriptors are always opened with both read and write permission.
createCacheHandleVec(uint32_t numCacheFiles,const std::string & baseFileName,bool createIfNotExist)1128 static hidl_vec<hidl_handle> createCacheHandleVec(uint32_t numCacheFiles,
1129                                                   const std::string& baseFileName,
1130                                                   bool createIfNotExist) {
1131     CHECK(numCacheFiles <= static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES));
1132     hidl_vec<hidl_handle> handles(numCacheFiles);
1133     for (uint32_t i = 0; i < numCacheFiles; i++) {
1134         std::string filename = baseFileName + std::to_string(i);
1135         VLOG(COMPILATION) << "Cache " << i << ": " << filename;
1136         if (!createCacheHandle(filename, createIfNotExist, &handles[i])) {
1137             return hidl_vec<hidl_handle>();
1138         }
1139     }
1140     return handles;
1141 }
1142 
1143 // Maps token to cache file names and sets the handle vectors to the opened fds. Returns false on
1144 // fail and leaves the vectors empty. Each vector is expected to come in as empty.
getCacheHandles(const std::string & cacheDir,const CacheToken & token,const std::pair<uint32_t,uint32_t> & numCacheFiles,bool createIfNotExist,hidl_vec<hidl_handle> * modelCache,hidl_vec<hidl_handle> * dataCache)1145 static bool getCacheHandles(const std::string& cacheDir, const CacheToken& token,
1146                             const std::pair<uint32_t, uint32_t>& numCacheFiles,
1147                             bool createIfNotExist, hidl_vec<hidl_handle>* modelCache,
1148                             hidl_vec<hidl_handle>* dataCache) {
1149     // The filename includes ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN * 2 characters for token,
1150     // and 1 character for model/data cache identifier.
1151     std::string filename(ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN * 2 + 1, '0');
1152     for (uint32_t i = 0; i < ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN; i++) {
1153         filename[i * 2] = 'A' + (token[i] & 0x0F);
1154         filename[i * 2 + 1] = 'A' + (token[i] >> 4);
1155     }
1156     CHECK(cacheDir.empty() || cacheDir.back() == '/');
1157     std::string cacheFileName = cacheDir + filename;
1158 
1159     cacheFileName[ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN * 2] = '1';
1160     *modelCache = createCacheHandleVec(numCacheFiles.first, cacheFileName, createIfNotExist);
1161     if (modelCache->size() != numCacheFiles.first) {
1162         return false;
1163     }
1164     cacheFileName[ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN * 2] = '2';
1165     *dataCache = createCacheHandleVec(numCacheFiles.second, cacheFileName, createIfNotExist);
1166     if (dataCache->size() != numCacheFiles.second) {
1167         modelCache->resize(0);
1168         return false;
1169     }
1170     return true;
1171 }
1172 
prepareModelFailure(ErrorStatus status=ErrorStatus::GENERAL_FAILURE)1173 static std::pair<int, std::shared_ptr<VersionedIPreparedModel>> prepareModelFailure(
1174         ErrorStatus status = ErrorStatus::GENERAL_FAILURE) {
1175     return {convertErrorStatusToResultCode(status), nullptr};
1176 }
1177 
prepareModelResult(const PreparedModelCallback & callback,const char * prepareName,const std::string & serviceName)1178 static std::pair<int, std::shared_ptr<VersionedIPreparedModel>> prepareModelResult(
1179         const PreparedModelCallback& callback, const char* prepareName,
1180         const std::string& serviceName) {
1181     callback.wait();
1182     if (callback.isDeadObject()) {
1183         LOG(ERROR) << prepareName << " on " << serviceName
1184                    << " failed because the PreparedModel object is dead";
1185         return {ANEURALNETWORKS_DEAD_OBJECT, nullptr};
1186     }
1187     const ErrorStatus status = callback.getStatus();
1188     const sp<V1_0::IPreparedModel> preparedModel = callback.getPreparedModel();
1189 
1190     if (status != ErrorStatus::NONE) {
1191         LOG(ERROR) << prepareName << " on " << serviceName << " failed: "
1192                    << "prepareReturnStatus=" << toString(status);
1193         return prepareModelFailure(status);
1194     }
1195     if (preparedModel == nullptr) {
1196         LOG(ERROR) << prepareName << " on " << serviceName << " failed: preparedModel is nullptr";
1197         return prepareModelFailure();
1198     }
1199 
1200     return makeVersionedIPreparedModel(preparedModel);
1201 }
1202 
prepareModelInternal(const Model & model,ExecutionPreference preference,Priority priority,const std::optional<Deadline> & deadline,const std::string & cacheDir,const std::optional<CacheToken> & maybeToken) const1203 std::pair<int, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevice::prepareModelInternal(
1204         const Model& model, ExecutionPreference preference, Priority priority,
1205         const std::optional<Deadline>& deadline, const std::string& cacheDir,
1206         const std::optional<CacheToken>& maybeToken) const {
1207     // Note that some work within VersionedIDevice will be subtracted from the IPC layer
1208     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "prepareModel");
1209     const std::pair<int, std::shared_ptr<VersionedIPreparedModel>> kDeadObject = {
1210             ANEURALNETWORKS_DEAD_OBJECT, nullptr};
1211 
1212     // Get cache files if they exist, otherwise create them.
1213     hidl_vec<hidl_handle> modelCache, dataCache;
1214     if (!maybeToken.has_value() ||
1215         !getCacheHandles(cacheDir, *maybeToken, kNumberOfCacheFilesNeeded,
1216                          /*createIfNotExist=*/true, &modelCache, &dataCache)) {
1217         modelCache.resize(0);
1218         dataCache.resize(0);
1219     }
1220 
1221     // Get the token if it exists, otherwise get a null token.
1222     static const CacheToken kNullToken{};
1223     const CacheToken token = maybeToken.value_or(kNullToken);
1224 
1225     const sp<PreparedModelCallback> callback = new PreparedModelCallback();
1226 
1227     // If 1.3 device, try preparing model
1228     if (getDevice<V1_3::IDevice>() != nullptr) {
1229         const auto otp = makeTimePoint(deadline);
1230         const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_3::IDevice>(
1231                 __FUNCTION__,
1232                 [&model, preference, priority, &otp, &modelCache, &dataCache, &token,
1233                  &callback](const sp<V1_3::IDevice>& device) {
1234                     return device->prepareModel_1_3(model, preference, priority, otp, modelCache,
1235                                                     dataCache, token, callback);
1236                 },
1237                 callback);
1238         if (ret.isDeadObject()) {
1239             LOG(ERROR) << "prepareModel_1_3 failure: " << ret.description();
1240             return kDeadObject;
1241         }
1242         if (!ret.isOk()) {
1243             LOG(ERROR) << "prepareModel_1_3 failure: " << ret.description();
1244             return prepareModelFailure();
1245         }
1246         if (ret != ErrorStatus::NONE) {
1247             LOG(ERROR) << "prepareModel_1_3 returned " << toString(static_cast<ErrorStatus>(ret));
1248             return prepareModelFailure(ret);
1249         }
1250         return prepareModelResult(*callback, "prepareModel_1_3", kServiceName);
1251     }
1252 
1253     // If 1.2 device, try preparing model (requires conversion)
1254     if (getDevice<V1_2::IDevice>() != nullptr) {
1255         bool compliant = false;
1256         V1_2::Model model12;
1257         {
1258             // Attribute time spent in model inspection and conversion to
1259             // Runtime, as the time may be substantial (0.03ms for mobilenet,
1260             // but could be larger for other models).
1261             NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
1262                                   "VersionedIDevice::prepareModel_1_2");
1263             compliant = compliantWithV1_2(model);
1264             if (compliant) {
1265                 model12 = convertToV1_2(model);  // copy is elided
1266             }
1267         }
1268         if (compliant) {
1269             const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_2::IDevice>(
1270                     __FUNCTION__,
1271                     [&model12, &preference, &modelCache, &dataCache, &token,
1272                      &callback](const sp<V1_2::IDevice>& device) {
1273                         return device->prepareModel_1_2(model12, preference, modelCache, dataCache,
1274                                                         token, callback);
1275                     },
1276                     callback);
1277             if (ret.isDeadObject()) {
1278                 LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description();
1279                 return kDeadObject;
1280             }
1281             if (!ret.isOk()) {
1282                 LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description();
1283                 return prepareModelFailure();
1284             }
1285             const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
1286             if (status != V1_0::ErrorStatus::NONE) {
1287                 LOG(ERROR) << "prepareModel_1_2 returned " << toString(status);
1288                 return prepareModelFailure(convertToV1_3(status));
1289             }
1290             return prepareModelResult(*callback, "prepareModel_1_2", kServiceName);
1291         }
1292 
1293         LOG(ERROR) << "Could not handle prepareModel_1_2!";
1294         return prepareModelFailure();
1295     }
1296 
1297     // If 1.1 device, try preparing model (requires conversion)
1298     if (getDevice<V1_1::IDevice>() != nullptr) {
1299         bool compliant = false;
1300         V1_1::Model model11;
1301         {
1302             // Attribute time spent in model inspection and conversion to
1303             // Runtime, as the time may be substantial (0.03ms for mobilenet,
1304             // but could be larger for other models).
1305             NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
1306                                   "VersionedIDevice::prepareModel_1_1");
1307             compliant = compliantWithV1_1(model);
1308             if (compliant) {
1309                 model11 = convertToV1_1(model);  // copy is elided
1310             }
1311         }
1312         if (compliant) {
1313             const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_1::IDevice>(
1314                     __FUNCTION__,
1315                     [&model11, &preference, &callback](const sp<V1_1::IDevice>& device) {
1316                         return device->prepareModel_1_1(model11, preference, callback);
1317                     },
1318                     callback);
1319             if (ret.isDeadObject()) {
1320                 LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description();
1321                 return kDeadObject;
1322             }
1323             if (!ret.isOk()) {
1324                 LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description();
1325                 return prepareModelFailure();
1326             }
1327             const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
1328             if (status != V1_0::ErrorStatus::NONE) {
1329                 LOG(ERROR) << "prepareModel_1_1 returned " << toString(status);
1330                 return prepareModelFailure(convertToV1_3(status));
1331             }
1332             return prepareModelResult(*callback, "prepareModel_1_1", kServiceName);
1333         }
1334 
1335         LOG(ERROR) << "Could not handle prepareModel_1_1!";
1336         return prepareModelFailure();
1337     }
1338 
1339     // If 1.0 device, try preparing model (requires conversion)
1340     if (getDevice<V1_0::IDevice>() != nullptr) {
1341         bool compliant = false;
1342         V1_0::Model model10;
1343         {
1344             // Attribute time spent in model inspection and conversion to
1345             // Runtime, as the time may be substantial (0.03ms for mobilenet,
1346             // but could be larger for other models).
1347             NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
1348                                   "VersionedIDevice::prepareModel");
1349             compliant = compliantWithV1_0(model);
1350             if (compliant) {
1351                 model10 = convertToV1_0(model);  // copy is elided
1352             }
1353         }
1354         if (compliant) {
1355             const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_0::IDevice>(
1356                     __FUNCTION__,
1357                     [&model10, &callback](const sp<V1_0::IDevice>& device) {
1358                         return device->prepareModel(model10, callback);
1359                     },
1360                     callback);
1361             if (ret.isDeadObject()) {
1362                 LOG(ERROR) << "prepareModel failure: " << ret.description();
1363                 return kDeadObject;
1364             }
1365             if (!ret.isOk()) {
1366                 LOG(ERROR) << "prepareModel failure: " << ret.description();
1367                 return prepareModelFailure();
1368             }
1369             const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
1370             if (status != V1_0::ErrorStatus::NONE) {
1371                 LOG(ERROR) << "prepareModel returned " << toString(status);
1372                 return prepareModelFailure(convertToV1_3(status));
1373             }
1374             return prepareModelResult(*callback, "prepareModel", kServiceName);
1375         }
1376 
1377         LOG(ERROR) << "Could not handle prepareModel!";
1378         return prepareModelFailure();
1379     }
1380 
1381     // Return error because there is no valid device
1382     LOG(ERROR) << "prepareModel called with no device";
1383     return prepareModelFailure();
1384 }
1385 
1386 std::pair<int, std::shared_ptr<VersionedIPreparedModel>>
prepareModelFromCacheInternal(const std::optional<Deadline> & deadline,const std::string & cacheDir,const CacheToken & token) const1387 VersionedIDevice::prepareModelFromCacheInternal(const std::optional<Deadline>& deadline,
1388                                                 const std::string& cacheDir,
1389                                                 const CacheToken& token) const {
1390     // Note that some work within VersionedIDevice will be subtracted from the IPC layer
1391     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "prepareModelFromCache");
1392     VLOG(COMPILATION) << "prepareModelFromCache";
1393     const std::pair<int, std::shared_ptr<VersionedIPreparedModel>> kDeadObject = {
1394             ANEURALNETWORKS_DEAD_OBJECT, nullptr};
1395 
1396     // Get cache files if they exist, otherwise return from the function early.
1397     hidl_vec<hidl_handle> modelCache, dataCache;
1398     if (!getCacheHandles(cacheDir, token, kNumberOfCacheFilesNeeded,
1399                          /*createIfNotExist=*/false, &modelCache, &dataCache)) {
1400         return prepareModelFailure();
1401     }
1402 
1403     // version 1.3+ HAL
1404     if (getDevice<V1_3::IDevice>() != nullptr) {
1405         const auto otp = makeTimePoint(deadline);
1406         const sp<PreparedModelCallback> callback = new PreparedModelCallback();
1407         const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_3::IDevice>(
1408                 __FUNCTION__,
1409                 [&otp, &modelCache, &dataCache, &token,
1410                  &callback](const sp<V1_3::IDevice>& device) {
1411                     return device->prepareModelFromCache_1_3(otp, modelCache, dataCache, token,
1412                                                              callback);
1413                 },
1414                 callback);
1415         if (ret.isDeadObject()) {
1416             LOG(ERROR) << "prepareModelFromCache_1_3 failure: " << ret.description();
1417             return kDeadObject;
1418         }
1419         if (!ret.isOk()) {
1420             LOG(ERROR) << "prepareModelFromCache_1_3 failure: " << ret.description();
1421             return prepareModelFailure();
1422         }
1423         if (ret != ErrorStatus::NONE) {
1424             LOG(ERROR) << "prepareModelFromCache_1_3 returned "
1425                        << toString(static_cast<ErrorStatus>(ret));
1426             return prepareModelFailure(ret);
1427         }
1428         return prepareModelResult(*callback, "prepareModelFromCache_1_3", kServiceName);
1429     }
1430 
1431     // version 1.2 HAL
1432     if (getDevice<V1_2::IDevice>() != nullptr) {
1433         const sp<PreparedModelCallback> callback = new PreparedModelCallback();
1434         const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_2::IDevice>(
1435                 __FUNCTION__,
1436                 [&modelCache, &dataCache, &token, &callback](const sp<V1_2::IDevice>& device) {
1437                     return device->prepareModelFromCache(modelCache, dataCache, token, callback);
1438                 },
1439                 callback);
1440         if (ret.isDeadObject()) {
1441             LOG(ERROR) << "prepareModelFromCache failure: " << ret.description();
1442             return kDeadObject;
1443         }
1444         if (!ret.isOk()) {
1445             LOG(ERROR) << "prepareModelFromCache failure: " << ret.description();
1446             return prepareModelFailure();
1447         }
1448         const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
1449         if (status != V1_0::ErrorStatus::NONE) {
1450             LOG(ERROR) << "prepareModelFromCache returned " << toString(status);
1451             return prepareModelFailure(convertToV1_3(status));
1452         }
1453         return prepareModelResult(*callback, "prepareModelFromCache", kServiceName);
1454     }
1455 
1456     // version too low
1457     if (getDevice<V1_0::IDevice>() != nullptr) {
1458         LOG(ERROR) << "prepareModelFromCache called on V1_1 or V1_0 device";
1459         return prepareModelFailure();
1460     }
1461 
1462     // No device available
1463     LOG(ERROR) << "prepareModelFromCache called with no device";
1464     return prepareModelFailure();
1465 }
1466 
prepareModel(const ModelFactory & makeModel,ExecutionPreference preference,Priority priority,const std::optional<Deadline> & deadline,const std::string & cacheDir,const std::optional<CacheToken> & maybeToken) const1467 std::pair<int, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevice::prepareModel(
1468         const ModelFactory& makeModel, ExecutionPreference preference, Priority priority,
1469         const std::optional<Deadline>& deadline, const std::string& cacheDir,
1470         const std::optional<CacheToken>& maybeToken) const {
1471     // Attempt to compile from cache if token is present.
1472     if (maybeToken.has_value()) {
1473         const auto [n, preparedModel] =
1474                 prepareModelFromCacheInternal(deadline, cacheDir, *maybeToken);
1475         if (n == ANEURALNETWORKS_NO_ERROR) {
1476             return {n, preparedModel};
1477         }
1478     }
1479 
1480     // Fallback to full compilation (possibly with token) if
1481     // prepareModelFromCache could not be used or failed.
1482     const Model model = makeModel();
1483     return prepareModelInternal(model, preference, priority, deadline, cacheDir, maybeToken);
1484 }
1485 
getFeatureLevel() const1486 int64_t VersionedIDevice::getFeatureLevel() const {
1487     constexpr int64_t kFailure = -1;
1488 
1489     if (getDevice<V1_3::IDevice>() != nullptr) {
1490         return __ANDROID_API_R__;
1491     } else if (getDevice<V1_2::IDevice>() != nullptr) {
1492         return __ANDROID_API_Q__;
1493     } else if (getDevice<V1_1::IDevice>() != nullptr) {
1494         return __ANDROID_API_P__;
1495     } else if (getDevice<V1_0::IDevice>() != nullptr) {
1496         return __ANDROID_API_O_MR1__;
1497     } else {
1498         LOG(ERROR) << "Device not available!";
1499         return kFailure;
1500     }
1501 }
1502 
getType() const1503 int32_t VersionedIDevice::getType() const {
1504     return kType;
1505 }
1506 
getVersionString() const1507 const std::string& VersionedIDevice::getVersionString() const {
1508     return kVersionString;
1509 }
1510 
getNumberOfCacheFilesNeeded() const1511 std::pair<uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() const {
1512     return kNumberOfCacheFilesNeeded;
1513 }
1514 
getName() const1515 const std::string& VersionedIDevice::getName() const {
1516     return kServiceName;
1517 }
1518 
allocate(const BufferDesc & desc,const std::vector<std::shared_ptr<VersionedIPreparedModel>> & versionedPreparedModels,const hidl_vec<BufferRole> & inputRoles,const hidl_vec<BufferRole> & outputRoles) const1519 std::tuple<ErrorStatus, sp<IBuffer>, uint32_t> VersionedIDevice::allocate(
1520         const BufferDesc& desc,
1521         const std::vector<std::shared_ptr<VersionedIPreparedModel>>& versionedPreparedModels,
1522         const hidl_vec<BufferRole>& inputRoles, const hidl_vec<BufferRole>& outputRoles) const {
1523     const auto kFailure = std::make_tuple<ErrorStatus, sp<IBuffer>, uint32_t>(
1524             ErrorStatus::GENERAL_FAILURE, nullptr, 0);
1525 
1526     // version 1.3+ HAL
1527     if (getDevice<V1_3::IDevice>() != nullptr) {
1528         hidl_vec<sp<V1_3::IPreparedModel>> preparedModels(versionedPreparedModels.size());
1529         std::transform(versionedPreparedModels.begin(), versionedPreparedModels.end(),
1530                        preparedModels.begin(),
1531                        [](const auto& preparedModel) { return preparedModel->getV1_3(); });
1532 
1533         std::tuple<ErrorStatus, sp<IBuffer>, int32_t> result;
1534         const Return<void> ret = recoverable<void, V1_3::IDevice>(
1535                 __FUNCTION__, [&](const sp<V1_3::IDevice>& device) {
1536                     return device->allocate(desc, preparedModels, inputRoles, outputRoles,
1537                                             [&result](ErrorStatus error, const sp<IBuffer>& buffer,
1538                                                       uint32_t token) {
1539                                                 result = {error, buffer, token};
1540                                             });
1541                 });
1542         if (!ret.isOk()) {
1543             LOG(ERROR) << "allocate failure: " << ret.description();
1544             return kFailure;
1545         }
1546         return result;
1547     }
1548 
1549     // version too low or no device available
1550     LOG(ERROR) << "Could not handle allocate";
1551     return kFailure;
1552 }
1553 
1554 }  // namespace nn
1555 }  // namespace android
1556