1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "VersionedInterfaces.h"
18 
19 #include "Callbacks.h"
20 #include "ExecutionBurstController.h"
21 #include "Tracing.h"
22 #include "Utils.h"
23 
24 #include <android-base/logging.h>
25 #include <android-base/scopeguard.h>
26 #include <android-base/thread_annotations.h>
27 #include <functional>
28 #include <type_traits>
29 
30 namespace android {
31 namespace nn {
32 
33 // anonymous namespace
34 namespace {
35 
36 using HidlToken = hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
37 
38 const Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
39 
sendFailureMessage(const sp<IPreparedModelCallback> & cb)40 void sendFailureMessage(const sp<IPreparedModelCallback>& cb) {
41     cb->notify(ErrorStatus::GENERAL_FAILURE, nullptr);
42 }
43 
sendFailureMessage(const sp<PreparedModelCallback> & cb)44 void sendFailureMessage(const sp<PreparedModelCallback>& cb) {
45     sendFailureMessage(static_cast<sp<IPreparedModelCallback>>(cb));
46 }
47 
sendFailureMessage(const sp<IExecutionCallback> & cb)48 void sendFailureMessage(const sp<IExecutionCallback>& cb) {
49     cb->notify(ErrorStatus::GENERAL_FAILURE);
50 }
51 
sendFailureMessage(const sp<ExecutionCallback> & cb)52 void sendFailureMessage(const sp<ExecutionCallback>& cb) {
53     sendFailureMessage(static_cast<sp<IExecutionCallback>>(cb));
54 }
55 
56 // This class is thread safe
57 template <typename ICallback>
58 class DeathHandler : public hardware::hidl_death_recipient {
59    public:
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)60     void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
61         LOG(ERROR) << "DeathHandler::serviceDied -- service unexpectedly died!";
62         std::lock_guard<std::mutex> hold(mMutex);
63         std::for_each(mCallbacks.begin(), mCallbacks.end(),
64                       [](const auto& cb) { sendFailureMessage(cb); });
65     }
66 
protectCallback(const sp<ICallback> & callback)67     [[nodiscard]] base::ScopeGuard<std::function<void()>> protectCallback(
68             const sp<ICallback>& callback) {
69         registerCallback(callback);
70         return ::android::base::make_scope_guard(
71                 [this, callback] { unregisterCallback(callback); });
72     }
73 
registerCallback(const sp<ICallback> & callback)74     private : void registerCallback(const sp<ICallback>& callback) {
75         std::lock_guard<std::mutex> hold(mMutex);
76         mCallbacks.push_back(callback);
77     }
78 
unregisterCallback(const sp<ICallback> & callback)79     void unregisterCallback(const sp<ICallback>& callback) {
80         std::lock_guard<std::mutex> hold(mMutex);
81         mCallbacks.erase(std::remove(mCallbacks.begin(), mCallbacks.end(), callback),
82                          mCallbacks.end());
83     }
84 
85     std::mutex mMutex;
86     std::vector<sp<ICallback>> mCallbacks GUARDED_BY(mMutex);
87 };
88 
89 }  // anonymous namespace
90 
// Named death-handler types: IDeviceDeathHandler protects prepareModel callbacks
// (IPreparedModelCallback) for an IDevice, and IPreparedModelDeathHandler protects execution
// callbacks (IExecutionCallback) for an IPreparedModel.
// NOTE(review): presumably declared as distinct non-template classes so that
// VersionedInterfaces.h can forward-declare them — confirm.
class IDeviceDeathHandler : public DeathHandler<IPreparedModelCallback> {};
class IPreparedModelDeathHandler : public DeathHandler<IExecutionCallback> {};
93 
makeVersionedIPreparedModel(sp<V1_0::IPreparedModel> preparedModel)94 static std::shared_ptr<VersionedIPreparedModel> makeVersionedIPreparedModel(
95         sp<V1_0::IPreparedModel> preparedModel) {
96     // verify input
97     if (!preparedModel) {
98         LOG(ERROR) << "makeVersionedIPreparedModel -- passed invalid preparedModel object.";
99         return nullptr;
100     }
101 
102     // create death handler object
103     sp<IPreparedModelDeathHandler> deathHandler = new (std::nothrow) IPreparedModelDeathHandler();
104     if (!deathHandler) {
105         LOG(ERROR) << "makeVersionedIPreparedModel -- Failed to create IPreparedModelDeathHandler.";
106         return nullptr;
107     }
108 
109     // linkToDeath registers a callback that will be invoked on service death to
110     // proactively handle service crashes. If the linkToDeath call fails,
111     // asynchronous calls are susceptible to hangs if the service crashes before
112     // providing the response.
113     const Return<bool> ret = preparedModel->linkToDeath(deathHandler, 0);
114     if (!ret.isOk() || ret != true) {
115         LOG(ERROR) << "makeVersionedIPreparedModel -- Failed to register a death recipient for the "
116                       "IPreparedModel object.";
117         return nullptr;
118     }
119 
120     // return a valid VersionedIPreparedModel object
121     return std::make_shared<VersionedIPreparedModel>(std::move(preparedModel),
122                                                      std::move(deathHandler));
123 }
124 
// Takes the V1_0 prepared-model handle and its death handler (makeVersionedIPreparedModel links
// the handler before constructing). mPreparedModelV1_2 is derived from mPreparedModelV1_0 via
// castFrom and is null when the cast yields nothing or the transport call fails
// (withDefault(nullptr)).
// NOTE(review): the mPreparedModelV1_2 initializer reads mPreparedModelV1_0, which is only
// correct if mPreparedModelV1_0 is declared before mPreparedModelV1_2 in the header — confirm.
VersionedIPreparedModel::VersionedIPreparedModel(sp<V1_0::IPreparedModel> preparedModel,
                                                 sp<IPreparedModelDeathHandler> deathHandler)
    : mPreparedModelV1_0(std::move(preparedModel)),
      mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(mPreparedModelV1_0).withDefault(nullptr)),
      mDeathHandler(std::move(deathHandler)) {}
130 
~VersionedIPreparedModel()131 VersionedIPreparedModel::~VersionedIPreparedModel() {
132     // It is safe to ignore any errors resulting from this unlinkToDeath call
133     // because the VersionedIPreparedModel object is already being destroyed and
134     // its underlying IPreparedModel object is no longer being used by the NN
135     // runtime.
136     mPreparedModelV1_0->unlinkToDeath(mDeathHandler).isOk();
137 }
138 
execute(const Request & request,MeasureTiming measure,const sp<ExecutionCallback> & callback)139 ErrorStatus VersionedIPreparedModel::execute(const Request& request, MeasureTiming measure,
140                                              const sp<ExecutionCallback>& callback) {
141     const auto scoped = mDeathHandler->protectCallback(callback);
142 
143     if (mPreparedModelV1_2 != nullptr) {
144         Return<ErrorStatus> ret = mPreparedModelV1_2->execute_1_2(request, measure, callback);
145         if (!ret.isOk()) {
146             sendFailureMessage(callback);
147             LOG(ERROR) << "execute_1_2 failure: " << ret.description();
148             return ErrorStatus::GENERAL_FAILURE;
149         }
150         if (ret != ErrorStatus::NONE) {
151             sendFailureMessage(callback);
152             LOG(ERROR) << "execute_1_2 returned " << toString(static_cast<ErrorStatus>(ret));
153             return static_cast<ErrorStatus>(ret);
154         }
155         callback->wait();
156         return static_cast<ErrorStatus>(ret);
157     } else if (mPreparedModelV1_0 != nullptr) {
158         Return<ErrorStatus> ret = mPreparedModelV1_0->execute(request, callback);
159         if (!ret.isOk()) {
160             sendFailureMessage(callback);
161             LOG(ERROR) << "execute failure: " << ret.description();
162             return ErrorStatus::GENERAL_FAILURE;
163         }
164         if (ret != ErrorStatus::NONE) {
165             sendFailureMessage(callback);
166             LOG(ERROR) << "execute returned " << toString(static_cast<ErrorStatus>(ret));
167             return static_cast<ErrorStatus>(ret);
168         }
169         callback->wait();
170         return static_cast<ErrorStatus>(ret);
171     } else {
172         sendFailureMessage(callback);
173         LOG(ERROR) << "execute called with no preparedModel";
174         return ErrorStatus::GENERAL_FAILURE;
175     }
176 }
177 
178 std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing>
executeSynchronously(const Request & request,MeasureTiming measure)179 VersionedIPreparedModel::executeSynchronously(const Request& request, MeasureTiming measure) {
180     const std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> kFailure = {
181             ErrorStatus::GENERAL_FAILURE, {}, kBadTiming};
182 
183     if (mPreparedModelV1_2 != nullptr) {
184         std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> result;
185         Return<void> ret = mPreparedModelV1_2->executeSynchronously(
186                 request, measure,
187                 [&result](ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
188                           const Timing& timing) {
189                     result = std::make_tuple(error, outputShapes, timing);
190                 });
191         if (!ret.isOk()) {
192             LOG(ERROR) << "executeSynchronously failure: " << ret.description();
193             return kFailure;
194         }
195         return result;
196     } else {
197         // Simulate synchronous execution.
198         sp<ExecutionCallback> callback = new ExecutionCallback();
199         ErrorStatus ret = execute(request, measure, callback);
200         if (ret != ErrorStatus::NONE) {
201             return {ret, {}, kBadTiming};
202         }
203         callback->wait();
204         // callback->getOutputShapes() will always return an empty hidl vector.
205         // callback->getTiming() will always return values indicating no measurement.
206         return {callback->getStatus(), callback->getOutputShapes(), callback->getTiming()};
207     }
208 }
209 
configureExecutionBurst(bool blocking) const210 std::shared_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExecutionBurst(
211         bool blocking) const {
212     if (mPreparedModelV1_2 != nullptr) {
213         return ExecutionBurstController::create(mPreparedModelV1_2, blocking);
214     } else {
215         return nullptr;
216     }
217 }
218 
operator ==(nullptr_t) const219 bool VersionedIPreparedModel::operator==(nullptr_t) const {
220     return mPreparedModelV1_0 == nullptr;
221 }
222 
operator !=(nullptr_t) const223 bool VersionedIPreparedModel::operator!=(nullptr_t) const {
224     return mPreparedModelV1_0 != nullptr;
225 }
226 
create(std::string serviceName,sp<V1_0::IDevice> device)227 std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName,
228                                                            sp<V1_0::IDevice> device) {
229     auto core = Core::create(std::move(device));
230     if (!core.has_value()) {
231         LOG(ERROR) << "VersionedIDevice::create -- Failed to create Core.";
232         return nullptr;
233     }
234 
235     // return a valid VersionedIDevice object
236     return std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value()));
237 }
238 
// Stores the service name (recoverable() uses it with tryGetService to re-fetch the service
// after a crash) together with the Core that holds the device handles and death handler.
VersionedIDevice::VersionedIDevice(std::string serviceName, Core core)
    : mServiceName(std::move(serviceName)), mCore(std::move(core)) {}
241 
create(sp<V1_0::IDevice> device)242 std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) {
243     // verify input
244     if (!device) {
245         LOG(ERROR) << "VersionedIDevice::Core::create -- passed invalid device object.";
246         return {};
247     }
248 
249     // create death handler object
250     sp<IDeviceDeathHandler> deathHandler = new (std::nothrow) IDeviceDeathHandler();
251     if (!deathHandler) {
252         LOG(ERROR) << "VersionedIDevice::Core::create -- Failed to create IDeviceDeathHandler.";
253         return {};
254     }
255 
256     // linkToDeath registers a callback that will be invoked on service death to
257     // proactively handle service crashes. If the linkToDeath call fails,
258     // asynchronous calls are susceptible to hangs if the service crashes before
259     // providing the response.
260     const Return<bool> ret = device->linkToDeath(deathHandler, 0);
261     if (!ret.isOk() || ret != true) {
262         LOG(ERROR)
263                 << "VersionedIDevice::Core::create -- Failed to register a death recipient for the "
264                    "IDevice object.";
265         return {};
266     }
267 
268     // return a valid Core object
269     return Core(std::move(device), std::move(deathHandler));
270 }
271 
// Stores the V1_0 device handle and its death handler, and derives the V1_1 / V1_2 views of the
// same object via castFrom; each view is null when the cast yields nothing or the transport
// call fails (withDefault(nullptr)).
// HIDL guarantees all V1_1 interfaces inherit from their corresponding V1_0 interfaces.
// NOTE(review): the castFrom initializers read mDeviceV1_0, so they rely on mDeviceV1_0 being
// declared before mDeviceV1_1 / mDeviceV1_2 in the header — confirm.
VersionedIDevice::Core::Core(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler)
    : mDeviceV1_0(std::move(device)),
      mDeviceV1_1(V1_1::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeviceV1_2(V1_2::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeathHandler(std::move(deathHandler)) {}
278 
~Core()279 VersionedIDevice::Core::~Core() {
280     if (mDeathHandler != nullptr) {
281         CHECK(mDeviceV1_0 != nullptr);
282         // It is safe to ignore any errors resulting from this unlinkToDeath call
283         // because the VersionedIDevice::Core object is already being destroyed and
284         // its underlying IDevice object is no longer being used by the NN runtime.
285         mDeviceV1_0->unlinkToDeath(mDeathHandler).isOk();
286     }
287 }
288 
Core(Core && other)289 VersionedIDevice::Core::Core(Core&& other) noexcept
290     : mDeviceV1_0(std::move(other.mDeviceV1_0)),
291       mDeviceV1_1(std::move(other.mDeviceV1_1)),
292       mDeviceV1_2(std::move(other.mDeviceV1_2)),
293       mDeathHandler(std::move(other.mDeathHandler)) {
294     other.mDeathHandler = nullptr;
295 }
296 
operator =(Core && other)297 VersionedIDevice::Core& VersionedIDevice::Core::operator=(Core&& other) noexcept {
298     if (this != &other) {
299         mDeviceV1_0 = std::move(other.mDeviceV1_0);
300         mDeviceV1_1 = std::move(other.mDeviceV1_1);
301         mDeviceV1_2 = std::move(other.mDeviceV1_2);
302         mDeathHandler = std::move(other.mDeathHandler);
303         other.mDeathHandler = nullptr;
304     }
305     return *this;
306 }
307 
308 template <typename T_IDevice>
getDeviceAndDeathHandler() const309 std::pair<sp<T_IDevice>, sp<IDeviceDeathHandler>> VersionedIDevice::Core::getDeviceAndDeathHandler()
310         const {
311     return {getDevice<T_IDevice>(), mDeathHandler};
312 }
313 
314 template <typename T_IDevice, typename T_Callback>
callProtected(const char * context,const std::function<Return<ErrorStatus> (const sp<T_IDevice> &)> & fn,const sp<T_IDevice> & device,const sp<T_Callback> & callback,const sp<IDeviceDeathHandler> & deathHandler)315 Return<ErrorStatus> callProtected(
316         const char* context, const std::function<Return<ErrorStatus>(const sp<T_IDevice>&)>& fn,
317         const sp<T_IDevice>& device, const sp<T_Callback>& callback,
318         const sp<IDeviceDeathHandler>& deathHandler) {
319     const auto scoped = deathHandler->protectCallback(callback);
320     Return<ErrorStatus> ret = fn(device);
321     // Suppose there was a transport error.  We have the following cases:
322     // 1. Either not due to a dead device, or due to a device that was
323     //    already dead at the time of the call to protectCallback().  In
324     //    this case, the callback was never signalled.
325     // 2. Due to a device that died after the call to protectCallback() but
326     //    before fn() completed.  In this case, the callback was (or will
327     //    be) signalled by the deathHandler.
328     // Furthermore, what if there was no transport error, but the ErrorStatus is
329     // other than NONE?  We'll conservatively signal the callback anyway, just in
330     // case the driver was sloppy and failed to do so.
331     if (!ret.isOk() || ret != ErrorStatus::NONE) {
332         // What if the deathHandler has signalled or will signal the callback?
333         // This is fine -- we're permitted to signal multiple times; and we're
334         // sending the same signal that the deathHandler does.
335         //
336         // What if the driver signalled the callback?  Then this signal is
337         // ignored.
338 
339         if (ret.isOk()) {
340             LOG(ERROR) << context << " returned " << toString(static_cast<ErrorStatus>(ret));
341         } else {
342             LOG(ERROR) << context << " failure: " << ret.description();
343         }
344         sendFailureMessage(callback);
345     }
346     callback->wait();
347     return ret;
348 }
349 template <typename T_Return, typename T_IDevice>
callProtected(const char *,const std::function<Return<T_Return> (const sp<T_IDevice> &)> & fn,const sp<T_IDevice> & device,const std::nullptr_t &,const sp<IDeviceDeathHandler> &)350 Return<T_Return> callProtected(const char*,
351                                const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
352                                const sp<T_IDevice>& device, const std::nullptr_t&,
353                                const sp<IDeviceDeathHandler>&) {
354     return fn(device);
355 }
356 
// Calls `fn` on the current T_IDevice (via callProtected) and, if the call fails with a dead
// object, tries once to recover the service and repeat the call.
//
// Recovery runs under mMutex: the current device is pinged; if the ping also reports a dead
// object, the service is re-fetched by mServiceName, a new Core is created, and mCore is
// replaced. If the ping succeeds, another caller is assumed to have recovered already. The
// retry itself runs after the lock is released. If the service cannot be re-fetched or a new
// Core cannot be created, the original (dead-object) result is returned.
//
// `callback` must be nullptr exactly when T_Callback is std::nullptr_t (enforced by CHECK_EQ).
//
// NOTE(review): mCore.getDevice<T_IDevice>() and the re-fetched `device` are dereferenced
// without a null check; if a recovered service exposed a lower interface version than
// T_IDevice these could be null — confirm this cannot happen.
template <typename T_Return, typename T_IDevice, typename T_Callback>
Return<T_Return> VersionedIDevice::recoverable(
        const char* context, const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
        const T_Callback& callback) const EXCLUDES(mMutex) {
    CHECK_EQ(callback == nullptr, (std::is_same_v<T_Callback, std::nullptr_t>));

    sp<T_IDevice> device;
    sp<IDeviceDeathHandler> deathHandler;
    std::tie(device, deathHandler) = getDeviceAndDeathHandler<T_IDevice>();

    Return<T_Return> ret = callProtected(context, fn, device, callback, deathHandler);

    if (ret.isDeadObject()) {
        {
            std::unique_lock lock(mMutex);
            // It's possible that another caller has already done the recovery.
            // It's harmless but wasteful for us to do so in this case.
            auto pingReturn = mCore.getDevice<T_IDevice>()->ping();
            if (pingReturn.isDeadObject()) {
                VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering "
                             << mServiceName;
                sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::tryGetService(mServiceName);
                if (recoveredDevice == nullptr) {
                    VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDEVICE for "
                                 << mServiceName;
                    return ret;
                }

                auto core = Core::create(std::move(recoveredDevice));
                if (!core.has_value()) {
                    LOG(ERROR) << "VersionedIDevice::recoverable -- Failed to create Core.";
                    return ret;
                }

                mCore = std::move(core.value());
            } else {
                VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context
                             << ") -- Someone else recovered " << mServiceName;
                // Might still have a transport error, which we need to check
                // before pingReturn goes out of scope.
                (void)pingReturn.isOk();
            }
            // Snapshot the (possibly new) device and handler while still holding the lock.
            std::tie(device, deathHandler) = mCore.getDeviceAndDeathHandler<T_IDevice>();
        }
        ret = callProtected(context, fn, device, callback, deathHandler);
        // It's possible that the device died again, but we're only going to
        // attempt recovery once per call to recoverable().
    }
    return ret;
}
407 
getCapabilities()408 std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() {
409     const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
410     std::pair<ErrorStatus, Capabilities> result;
411 
412     if (getDevice<V1_2::IDevice>() != nullptr) {
413         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_2");
414         Return<void> ret = recoverable<void, V1_2::IDevice>(
415                 __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
416                     return device->getCapabilities_1_2(
417                             [&result](ErrorStatus error, const Capabilities& capabilities) {
418                                 result = std::make_pair(error, capabilities);
419                             });
420                 });
421         if (!ret.isOk()) {
422             LOG(ERROR) << "getCapabilities_1_2 failure: " << ret.description();
423             return {ErrorStatus::GENERAL_FAILURE, {}};
424         }
425     } else if (getDevice<V1_1::IDevice>() != nullptr) {
426         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_1");
427         Return<void> ret = recoverable<void, V1_1::IDevice>(
428                 __FUNCTION__, [&result](const sp<V1_1::IDevice>& device) {
429                     return device->getCapabilities_1_1(
430                             [&result](ErrorStatus error, const V1_1::Capabilities& capabilities) {
431                                 // Time taken to convert capabilities is trivial
432                                 result = std::make_pair(error, convertToV1_2(capabilities));
433                             });
434                 });
435         if (!ret.isOk()) {
436             LOG(ERROR) << "getCapabilities_1_1 failure: " << ret.description();
437             return kFailure;
438         }
439     } else if (getDevice<V1_0::IDevice>() != nullptr) {
440         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities");
441         Return<void> ret = recoverable<void, V1_0::IDevice>(
442                 __FUNCTION__, [&result](const sp<V1_0::IDevice>& device) {
443                     return device->getCapabilities(
444                             [&result](ErrorStatus error, const V1_0::Capabilities& capabilities) {
445                                 // Time taken to convert capabilities is trivial
446                                 result = std::make_pair(error, convertToV1_2(capabilities));
447                             });
448                 });
449         if (!ret.isOk()) {
450             LOG(ERROR) << "getCapabilities failure: " << ret.description();
451             return kFailure;
452         }
453     } else {
454         LOG(ERROR) << "Device not available!";
455         return {ErrorStatus::DEVICE_UNAVAILABLE, {}};
456     }
457 
458     return result;
459 }
460 
getSupportedExtensions()461 std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensions() {
462     const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
463     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedExtensions");
464     if (getDevice<V1_2::IDevice>() != nullptr) {
465         std::pair<ErrorStatus, hidl_vec<Extension>> result;
466         Return<void> ret = recoverable<void, V1_2::IDevice>(
467                 __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
468                     return device->getSupportedExtensions(
469                             [&result](ErrorStatus error, const hidl_vec<Extension>& extensions) {
470                                 result = std::make_pair(error, extensions);
471                             });
472                 });
473         if (!ret.isOk()) {
474             LOG(ERROR) << "getSupportedExtensions failure: " << ret.description();
475             return kFailure;
476         }
477         return result;
478     } else if (getDevice<V1_0::IDevice>() != nullptr) {
479         return {ErrorStatus::NONE, {/* No extensions. */}};
480     } else {
481         LOG(ERROR) << "Device not available!";
482         return {ErrorStatus::DEVICE_UNAVAILABLE, {}};
483     }
484 }
485 
getSupportedOperations(const Model & model,IModelSlicer * slicer)486 std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
487         const Model& model, IModelSlicer* slicer) {
488     const std::pair<ErrorStatus, hidl_vec<bool>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
489     std::pair<ErrorStatus, hidl_vec<bool>> result;
490 
491     auto noneSupported = [&model] {
492         hidl_vec<bool> supported(model.operations.size());
493         std::fill(supported.begin(), supported.end(), false);
494         return std::make_pair(ErrorStatus::NONE, std::move(supported));
495     };
496 
497     auto remappedResult = [&model](const std::pair<ErrorStatus, hidl_vec<bool>>& result,
498                                    const std::function<uint32_t(uint32_t)>&
499                                            submodelOperationIndexToModelOperationIndex) {
500         const ErrorStatus status = result.first;
501         const hidl_vec<bool>& supported = result.second;
502         hidl_vec<bool> remappedSupported(model.operations.size());
503         std::fill(remappedSupported.begin(), remappedSupported.end(), false);
504         for (size_t i = 0; i < supported.size(); ++i) {
505             if (supported[i]) {
506                 remappedSupported[submodelOperationIndexToModelOperationIndex(i)] = true;
507             }
508         }
509         return std::make_pair(status, std::move(remappedSupported));
510     };
511 
512     if (getDevice<V1_2::IDevice>() != nullptr) {
513         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_2");
514         Return<void> ret = recoverable<void, V1_2::IDevice>(
515                 __FUNCTION__, [&model, &result](const sp<V1_2::IDevice>& device) {
516                     return device->getSupportedOperations_1_2(
517                             model, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
518                                 result = std::make_pair(error, supported);
519                             });
520                 });
521         if (!ret.isOk()) {
522             LOG(ERROR) << "getSupportedOperations_1_2 failure: " << ret.description();
523             return kFailure;
524         }
525         return result;
526     }
527 
528     if (getDevice<V1_1::IDevice>() != nullptr) {
529         const bool compliant = compliantWithV1_1(model);
530         if (compliant || slicer) {
531             V1_1::Model model11;
532             std::function<uint32_t(uint32_t)> submodelOperationIndexToModelOperationIndex;
533             if (compliant) {
534                 model11 = convertToV1_1(model);
535             } else {
536                 const auto slice11 = slicer->getSliceV1_1();
537                 if (!slice11.has_value()) {
538                     return noneSupported();
539                 }
540                 std::tie(model11, submodelOperationIndexToModelOperationIndex) = *slice11;
541             }
542             NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION,
543                          "getSupportedOperations_1_1");
544             Return<void> ret = recoverable<void, V1_1::IDevice>(
545                     __FUNCTION__, [&model11, &result](const sp<V1_1::IDevice>& device) {
546                         return device->getSupportedOperations_1_1(
547                                 model11,
548                                 [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
549                                     result = std::make_pair(error, supported);
550                                 });
551                     });
552             if (!ret.isOk()) {
553                 LOG(ERROR) << "getSupportedOperations_1_1 failure: " << ret.description();
554                 return kFailure;
555             }
556             if (!compliant) {
557                 return remappedResult(result, submodelOperationIndexToModelOperationIndex);
558             }
559         }
560         return result;
561     }
562 
563     if (getDevice<V1_0::IDevice>() != nullptr) {
564         const bool compliant = compliantWithV1_0(model);
565         if (compliant || slicer) {
566             V1_0::Model model10;
567             std::function<uint32_t(uint32_t)> submodelOperationIndexToModelOperationIndex;
568             if (compliant) {
569                 model10 = convertToV1_0(model);
570             } else {
571                 const auto slice10 = slicer->getSliceV1_0();
572                 if (!slice10.has_value()) {
573                     return noneSupported();
574                 }
575                 std::tie(model10, submodelOperationIndexToModelOperationIndex) = *slice10;
576             }
577             NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations");
578             Return<void> ret = recoverable<void, V1_0::IDevice>(
579                     __FUNCTION__, [&model10, &result](const sp<V1_0::IDevice>& device) {
580                         return device->getSupportedOperations(
581                                 model10,
582                                 [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
583                                     result = std::make_pair(error, supported);
584                                 });
585                     });
586             if (!ret.isOk()) {
587                 LOG(ERROR) << "getSupportedOperations failure: " << ret.description();
588                 return kFailure;
589             }
590             if (!compliant) {
591                 return remappedResult(result, submodelOperationIndexToModelOperationIndex);
592             }
593         }
594         return result;
595     }
596 
597     return kFailure;
598 }
599 
prepareModel(const Model & model,ExecutionPreference preference,const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,const HidlToken & token)600 std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevice::prepareModel(
601         const Model& model, ExecutionPreference preference, const hidl_vec<hidl_handle>& modelCache,
602         const hidl_vec<hidl_handle>& dataCache, const HidlToken& token) {
603     const std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> kFailure = {
604             ErrorStatus::GENERAL_FAILURE, nullptr};
605 
606     const sp<PreparedModelCallback> callback = new (std::nothrow) PreparedModelCallback();
607     if (callback == nullptr) {
608         LOG(ERROR) << "prepareModel failed to create callback object";
609         return kFailure;
610     }
611 
612     // If 1.2 device, try preparing model
613     if (getDevice<V1_2::IDevice>() != nullptr) {
614         const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_2::IDevice>(
615                 __FUNCTION__,
616                 [&model, &preference, &modelCache, &dataCache, &token,
617                  &callback](const sp<V1_2::IDevice>& device) {
618                     return device->prepareModel_1_2(model, preference, modelCache, dataCache, token,
619                                                     callback);
620                 },
621                 callback);
622         if (!ret.isOk()) {
623             LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description();
624             return kFailure;
625         }
626         if (ret != ErrorStatus::NONE) {
627             LOG(ERROR) << "prepareModel_1_2 returned " << toString(static_cast<ErrorStatus>(ret));
628             return kFailure;
629         }
630         callback->wait();
631         return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())};
632     }
633 
634     // If 1.1 device, try preparing model (requires conversion)
635     if (getDevice<V1_1::IDevice>() != nullptr) {
636         bool compliant = false;
637         V1_1::Model model11;
638         {
639             // Attribute time spent in model inspection and conversion to
640             // Runtime, as the time may be substantial (0.03ms for mobilenet,
641             // but could be larger for other models).
642             NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
643                                   "VersionedIDevice::prepareModel_1_1");
644             compliant = compliantWithV1_1(model);
645             if (compliant) {
646                 model11 = convertToV1_1(model);  // copy is elided
647             }
648         }
649         if (compliant) {
650             const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_1::IDevice>(
651                     __FUNCTION__,
652                     [&model11, &preference, &callback](const sp<V1_1::IDevice>& device) {
653                         return device->prepareModel_1_1(model11, preference, callback);
654                     },
655                     callback);
656             if (!ret.isOk()) {
657                 LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description();
658                 return kFailure;
659             }
660             if (ret != ErrorStatus::NONE) {
661                 LOG(ERROR) << "prepareModel_1_1 returned "
662                            << toString(static_cast<ErrorStatus>(ret));
663                 return kFailure;
664             }
665             callback->wait();
666             return {callback->getStatus(),
667                     makeVersionedIPreparedModel(callback->getPreparedModel())};
668         }
669 
670         LOG(ERROR) << "Could not handle prepareModel_1_1!";
671         return kFailure;
672     }
673 
674     // If 1.0 device, try preparing model (requires conversion)
675     if (getDevice<V1_0::IDevice>() != nullptr) {
676         bool compliant = false;
677         V1_0::Model model10;
678         {
679             // Attribute time spent in model inspection and conversion to
680             // Runtime, as the time may be substantial (0.03ms for mobilenet,
681             // but could be larger for other models).
682             NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
683                                   "VersionedIDevice::prepareModel");
684             compliant = compliantWithV1_0(model);
685             if (compliant) {
686                 model10 = convertToV1_0(model);  // copy is elided
687             }
688         }
689         if (compliant) {
690             const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_0::IDevice>(
691                     __FUNCTION__,
692                     [&model10, &callback](const sp<V1_0::IDevice>& device) {
693                         return device->prepareModel(model10, callback);
694                     },
695                     callback);
696             if (!ret.isOk()) {
697                 LOG(ERROR) << "prepareModel failure: " << ret.description();
698                 return kFailure;
699             }
700             if (ret != ErrorStatus::NONE) {
701                 LOG(ERROR) << "prepareModel returned " << toString(static_cast<ErrorStatus>(ret));
702                 return kFailure;
703             }
704             callback->wait();
705             return {callback->getStatus(),
706                     makeVersionedIPreparedModel(callback->getPreparedModel())};
707         }
708 
709         LOG(ERROR) << "Could not handle prepareModel!";
710         return kFailure;
711     }
712 
713     // Return error because there is no valid device
714     LOG(ERROR) << "prepareModel called with no device";
715     return kFailure;
716 }
717 
718 std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>>
prepareModelFromCache(const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,const HidlToken & token)719 VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
720                                         const hidl_vec<hidl_handle>& dataCache,
721                                         const HidlToken& token) {
722     const std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> kFailure = {
723             ErrorStatus::GENERAL_FAILURE, nullptr};
724 
725     const sp<PreparedModelCallback> callback = new (std::nothrow) PreparedModelCallback();
726     if (callback == nullptr) {
727         LOG(ERROR) << "prepareModelFromCache failed to create callback object";
728         return kFailure;
729     }
730 
731     if (getDevice<V1_2::IDevice>() != nullptr) {
732         const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_2::IDevice>(
733                 __FUNCTION__,
734                 [&modelCache, &dataCache, &token, &callback](const sp<V1_2::IDevice>& device) {
735                     return device->prepareModelFromCache(modelCache, dataCache, token, callback);
736                 },
737                 callback);
738         if (!ret.isOk()) {
739             LOG(ERROR) << "prepareModelFromCache failure: " << ret.description();
740             return kFailure;
741         }
742         if (ret != ErrorStatus::NONE) {
743             LOG(ERROR) << "prepareModelFromCache returned "
744                        << toString(static_cast<ErrorStatus>(ret));
745             return kFailure;
746         }
747         callback->wait();
748         return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())};
749     }
750 
751     if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
752         LOG(ERROR) << "prepareModelFromCache called on V1_1 or V1_0 device";
753         return kFailure;
754     }
755 
756     LOG(ERROR) << "prepareModelFromCache called with no device";
757     return kFailure;
758 }
759 
getStatus()760 DeviceStatus VersionedIDevice::getStatus() {
761     if (getDevice<V1_0::IDevice>() == nullptr) {
762         LOG(ERROR) << "Device not available!";
763         return DeviceStatus::UNKNOWN;
764     }
765 
766     Return<DeviceStatus> ret = recoverable<DeviceStatus, V1_0::IDevice>(
767             __FUNCTION__, [](const sp<V1_0::IDevice>& device) { return device->getStatus(); });
768 
769     if (!ret.isOk()) {
770         LOG(ERROR) << "getStatus failure: " << ret.description();
771         return DeviceStatus::UNKNOWN;
772     }
773     return static_cast<DeviceStatus>(ret);
774 }
775 
getFeatureLevel()776 int64_t VersionedIDevice::getFeatureLevel() {
777     constexpr int64_t kFailure = -1;
778 
779     if (getDevice<V1_2::IDevice>() != nullptr) {
780         return __ANDROID_API_Q__;
781     } else if (getDevice<V1_1::IDevice>() != nullptr) {
782         return __ANDROID_API_P__;
783     } else if (getDevice<V1_0::IDevice>() != nullptr) {
784         return __ANDROID_API_O_MR1__;
785     } else {
786         LOG(ERROR) << "Device not available!";
787         return kFailure;
788     }
789 }
790 
getType() const791 int32_t VersionedIDevice::getType() const {
792     constexpr int32_t kFailure = -1;
793     std::pair<ErrorStatus, DeviceType> result;
794 
795     if (getDevice<V1_2::IDevice>() != nullptr) {
796         Return<void> ret = recoverable<void, V1_2::IDevice>(
797                 __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
798                     return device->getType([&result](ErrorStatus error, DeviceType deviceType) {
799                         result = std::make_pair(error, deviceType);
800                     });
801                 });
802         if (!ret.isOk()) {
803             LOG(ERROR) << "getType failure: " << ret.description();
804             return kFailure;
805         }
806         return static_cast<int32_t>(result.second);
807     } else {
808         LOG(INFO) << "Unknown NNAPI device type.";
809         return ANEURALNETWORKS_DEVICE_UNKNOWN;
810     }
811 }
812 
getVersionString()813 std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() {
814     const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""};
815     std::pair<ErrorStatus, hidl_string> result;
816 
817     if (getDevice<V1_2::IDevice>() != nullptr) {
818         Return<void> ret = recoverable<void, V1_2::IDevice>(
819                 __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
820                     return device->getVersionString(
821                             [&result](ErrorStatus error, const hidl_string& version) {
822                                 result = std::make_pair(error, version);
823                             });
824                 });
825         if (!ret.isOk()) {
826             LOG(ERROR) << "getVersion failure: " << ret.description();
827             return kFailure;
828         }
829         return result;
830     } else if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
831         return {ErrorStatus::NONE, "UNKNOWN"};
832     } else {
833         LOG(ERROR) << "Could not handle getVersionString";
834         return kFailure;
835     }
836 }
837 
getNumberOfCacheFilesNeeded()838 std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() {
839     constexpr std::tuple<ErrorStatus, uint32_t, uint32_t> kFailure = {ErrorStatus::GENERAL_FAILURE,
840                                                                       0, 0};
841     std::tuple<ErrorStatus, uint32_t, uint32_t> result;
842 
843     if (getDevice<V1_2::IDevice>() != nullptr) {
844         Return<void> ret = recoverable<void, V1_2::IDevice>(
845                 __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
846                     return device->getNumberOfCacheFilesNeeded([&result](ErrorStatus error,
847                                                                          uint32_t numModelCache,
848                                                                          uint32_t numDataCache) {
849                         result = {error, numModelCache, numDataCache};
850                     });
851                 });
852         if (!ret.isOk()) {
853             LOG(ERROR) << "getNumberOfCacheFilesNeeded failure: " << ret.description();
854             return kFailure;
855         }
856         return result;
857     } else if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
858         return {ErrorStatus::NONE, 0, 0};
859     } else {
860         LOG(ERROR) << "Could not handle getNumberOfCacheFilesNeeded";
861         return kFailure;
862     }
863 }
864 
operator ==(nullptr_t) const865 bool VersionedIDevice::operator==(nullptr_t) const {
866     return getDevice<V1_0::IDevice>() == nullptr;
867 }
868 
operator !=(nullptr_t) const869 bool VersionedIDevice::operator!=(nullptr_t) const {
870     return getDevice<V1_0::IDevice>() != nullptr;
871 }
872 
873 }  // namespace nn
874 }  // namespace android
875