1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ResilientBuffer.h"
18 
19 #include <android-base/logging.h>
20 #include <android-base/thread_annotations.h>
21 #include <nnapi/IBuffer.h>
22 #include <nnapi/Result.h>
23 #include <nnapi/TypeUtils.h>
24 #include <nnapi/Types.h>
25 
26 #include <functional>
27 #include <memory>
28 #include <mutex>
29 #include <utility>
30 #include <vector>
31 
32 namespace android::hardware::neuralnetworks::utils {
33 namespace {
34 
35 template <typename FnType>
protect(const ResilientBuffer & resilientBuffer,const FnType & fn)36 auto protect(const ResilientBuffer& resilientBuffer, const FnType& fn)
37         -> decltype(fn(*resilientBuffer.getBuffer())) {
38     auto buffer = resilientBuffer.getBuffer();
39     auto result = fn(*buffer);
40 
41     // Immediately return if device is not dead.
42     if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
43         return result;
44     }
45 
46     // Attempt recovery and return if it fails.
47     auto maybeBuffer = resilientBuffer.recover(buffer.get());
48     if (!maybeBuffer.has_value()) {
49         const auto& [resultErrorMessage, resultErrorCode] = result.error();
50         const auto& [recoveryErrorMessage, recoveryErrorCode] = maybeBuffer.error();
51         return nn::error(resultErrorCode)
52                << resultErrorMessage << ", and failed to recover dead buffer with error "
53                << recoveryErrorCode << ": " << recoveryErrorMessage;
54     }
55     buffer = std::move(maybeBuffer).value();
56 
57     return fn(*buffer);
58 }
59 
60 }  // namespace
61 
create(Factory makeBuffer)62 nn::GeneralResult<std::shared_ptr<const ResilientBuffer>> ResilientBuffer::create(
63         Factory makeBuffer) {
64     if (makeBuffer == nullptr) {
65         return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
66                << "utils::ResilientBuffer::create must have non-empty makeBuffer";
67     }
68     auto buffer = NN_TRY(makeBuffer());
69     CHECK(buffer != nullptr);
70     return std::make_shared<const ResilientBuffer>(PrivateConstructorTag{}, std::move(makeBuffer),
71                                                    std::move(buffer));
72 }
73 
ResilientBuffer(PrivateConstructorTag,Factory makeBuffer,nn::SharedBuffer buffer)74 ResilientBuffer::ResilientBuffer(PrivateConstructorTag /*tag*/, Factory makeBuffer,
75                                  nn::SharedBuffer buffer)
76     : kMakeBuffer(std::move(makeBuffer)), mBuffer(std::move(buffer)) {
77     CHECK(kMakeBuffer != nullptr);
78     CHECK(mBuffer != nullptr);
79 }
80 
getBuffer() const81 nn::SharedBuffer ResilientBuffer::getBuffer() const {
82     std::lock_guard guard(mMutex);
83     return mBuffer;
84 }
// Replaces the buffer identified by `failingBuffer` with a fresh one built by the stored factory.
//
// The whole operation runs under `mMutex`, so concurrent recoveries are serialized. If
// `failingBuffer` is no longer the current buffer, another caller has already recovered it. In
// that case the current buffer is returned without invoking the factory again.
//
// Returns the (possibly newly created) current buffer. On failure it returns:
//   - any error produced by the factory, or
//   - GENERAL_FAILURE if the factory succeeded but yielded a null buffer.
// On failure the previously held buffer is left in place, preserving the non-null invariant
// established by create() and the constructor.
nn::GeneralResult<nn::SharedBuffer> ResilientBuffer::recover(
        const nn::IBuffer* failingBuffer) const {
    std::lock_guard guard(mMutex);

    // Another caller already replaced the failing buffer; hand back its replacement.
    if (mBuffer.get() != failingBuffer) {
        return mBuffer;
    }

    auto buffer = NN_TRY(kMakeBuffer());
    // Never store a null buffer: callers dereference the result of getBuffer()/recover()
    // unconditionally.
    if (buffer == nullptr) {
        return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
               << "utils::ResilientBuffer::recover: makeBuffer returned a null buffer";
    }
    mBuffer = std::move(buffer);
    return mBuffer;
}
97 
getToken() const98 nn::Request::MemoryDomainToken ResilientBuffer::getToken() const {
99     return getBuffer()->getToken();
100 }
101 
// Copies the current buffer's contents into `dst`. The call goes through protect(), so a
// DEAD_OBJECT failure triggers one recovery attempt followed by a single retry.
nn::GeneralResult<void> ResilientBuffer::copyTo(const nn::SharedMemory& dst) const {
    return protect(*this, [&dst](const nn::IBuffer& target) { return target.copyTo(dst); });
}
106 
// Copies `src` into the current buffer using `dimensions`. The call goes through protect(), so a
// DEAD_OBJECT failure triggers one recovery attempt followed by a single retry.
nn::GeneralResult<void> ResilientBuffer::copyFrom(const nn::SharedMemory& src,
                                                  const nn::Dimensions& dimensions) const {
    return protect(*this, [&src, &dimensions](const nn::IBuffer& target) {
        return target.copyFrom(src, dimensions);
    });
}
114 
115 }  // namespace android::hardware::neuralnetworks::utils
116