1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ResilientBurst.h"
18 
19 #include <android-base/logging.h>
20 #include <android-base/thread_annotations.h>
21 #include <nnapi/IBurst.h>
22 #include <nnapi/IPreparedModel.h>
23 #include <nnapi/Result.h>
24 #include <nnapi/TypeUtils.h>
25 #include <nnapi/Types.h>
26 
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <sstream>
#include <utility>
32 
33 #include "InvalidExecution.h"
34 #include "ResilientExecution.h"
35 
36 namespace android::hardware::neuralnetworks::utils {
37 namespace {
38 
39 template <typename FnType>
protect(const ResilientBurst & resilientBurst,const FnType & fn)40 auto protect(const ResilientBurst& resilientBurst, const FnType& fn)
41         -> decltype(fn(*resilientBurst.getBurst())) {
42     auto burst = resilientBurst.getBurst();
43     auto result = fn(*burst);
44 
45     // Immediately return if burst is not dead.
46     if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
47         return result;
48     }
49 
50     // Attempt recovery and return if it fails.
51     auto maybeBurst = resilientBurst.recover(burst.get());
52     if (!maybeBurst.has_value()) {
53         const auto& [message, code] = maybeBurst.error();
54         std::ostringstream oss;
55         oss << ", and failed to recover dead burst object with error " << code << ": " << message;
56         result.error().message += oss.str();
57         return result;
58     }
59     burst = std::move(maybeBurst).value();
60 
61     return fn(*burst);
62 }
63 
64 }  // namespace
65 
create(Factory makeBurst)66 nn::GeneralResult<std::shared_ptr<const ResilientBurst>> ResilientBurst::create(Factory makeBurst) {
67     if (makeBurst == nullptr) {
68         return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
69                << "utils::ResilientBurst::create must have non-empty makeBurst";
70     }
71     auto burst = NN_TRY(makeBurst());
72     CHECK(burst != nullptr);
73     return std::make_shared<ResilientBurst>(PrivateConstructorTag{}, std::move(makeBurst),
74                                             std::move(burst));
75 }
76 
ResilientBurst(PrivateConstructorTag,Factory makeBurst,nn::SharedBurst burst)77 ResilientBurst::ResilientBurst(PrivateConstructorTag /*tag*/, Factory makeBurst,
78                                nn::SharedBurst burst)
79     : kMakeBurst(std::move(makeBurst)), mBurst(std::move(burst)) {
80     CHECK(kMakeBurst != nullptr);
81     CHECK(mBurst != nullptr);
82 }
83 
getBurst() const84 nn::SharedBurst ResilientBurst::getBurst() const {
85     std::lock_guard guard(mMutex);
86     return mBurst;
87 }
88 
// Replaces the burst after `failingBurst` was observed to be dead.
//
// If `failingBurst` is no longer the installed burst (another caller already recovered), the
// current burst is returned without invoking the factory. Otherwise a new burst is created
// with kMakeBurst and installed. If the factory returns an error, that error is propagated and
// mBurst is left unchanged.
//
// A factory that reports success must produce a non-null burst. This is the same invariant
// create() and the constructor enforce; installing a null burst would make later
// dereferences (e.g. in cacheMemory() or a retried call) undefined behavior.
nn::GeneralResult<nn::SharedBurst> ResilientBurst::recover(const nn::IBurst* failingBurst) const {
    std::lock_guard guard(mMutex);

    // Another caller updated the failing burst.
    if (mBurst.get() != failingBurst) {
        return mBurst;
    }

    // Validate the replacement before publishing it, so mBurst never holds a null burst.
    auto burst = NN_TRY(kMakeBurst());
    CHECK(burst != nullptr);
    mBurst = std::move(burst);
    return mBurst;
}
100 
cacheMemory(const nn::SharedMemory & memory) const101 ResilientBurst::OptionalCacheHold ResilientBurst::cacheMemory(
102         const nn::SharedMemory& memory) const {
103     return getBurst()->cacheMemory(memory);
104 }
105 
// Executes `request` on the current burst. If that burst reports DEAD_OBJECT, it is recovered
// once and the execution is retried (see protect()).
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> ResilientBurst::execute(
        const nn::Request& request, nn::MeasureTiming measure,
        const nn::OptionalTimePoint& deadline,
        const nn::OptionalDuration& loopTimeoutDuration) const {
    // Capturing by reference is safe: protect() invokes the callable synchronously, before
    // this function returns.
    const auto executeOnBurst = [&](const nn::IBurst& burst) {
        return burst.execute(request, measure, deadline, loopTimeoutDuration);
    };
    return protect(*this, executeOnBurst);
}
115 
// Creates a reusable execution for `request`.
//
// NOTE(review): an alternative that wraps the execution in a ResilientExecution is compiled
// out below. That alternative would use a factory that re-invokes
// createReusableExecutionInternal. The active path creates the execution directly from the
// current burst instead.
nn::GeneralResult<nn::SharedExecution> ResilientBurst::createReusableExecution(
        const nn::Request& request, nn::MeasureTiming measure,
        const nn::OptionalDuration& loopTimeoutDuration) const {
#if 0
    auto self = shared_from_this();
    ResilientExecution::Factory makeExecution =
            [burst = std::move(self), request, measure, loopTimeoutDuration] {
        return burst->createReusableExecutionInternal(request, measure, loopTimeoutDuration);
    };
    return ResilientExecution::create(std::move(makeExecution));
#else
    auto execution = createReusableExecutionInternal(request, measure, loopTimeoutDuration);
    return execution;
#endif
}
130 
// Creates a reusable execution on the current burst, recovering the burst once and retrying
// if it reports DEAD_OBJECT (see protect()). If isValidInternal() reports false, an
// InvalidExecution is returned instead.
nn::GeneralResult<nn::SharedExecution> ResilientBurst::createReusableExecutionInternal(
        const nn::Request& request, nn::MeasureTiming measure,
        const nn::OptionalDuration& loopTimeoutDuration) const {
    if (isValidInternal()) {
        // Invoked synchronously by protect(), so reference captures cannot dangle.
        const auto createOnBurst = [&](const nn::IBurst& burst) {
            return burst.createReusableExecution(request, measure, loopTimeoutDuration);
        };
        return protect(*this, createOnBurst);
    }
    return std::make_shared<const InvalidExecution>();
}
142 
isValidInternal() const143 bool ResilientBurst::isValidInternal() const {
144     return true;
145 }
146 
147 }  // namespace android::hardware::neuralnetworks::utils
148