1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Types.h"
18 
19 #include <android-base/logging.h>
20 #include <errno.h>
21 #include <poll.h>
22 
23 #include <algorithm>
24 #include <cstddef>
25 #include <iterator>
26 #include <limits>
27 #include <memory>
28 #include <optional>
29 #include <utility>
30 #include <vector>
31 
32 #include "OperandTypes.h"
33 #include "OperationTypes.h"
34 #include "Result.h"
35 #include "TypeUtils.h"
36 #include "Validation.h"
37 
38 namespace android::nn {
39 
40 // Ensure that std::vector<uint8_t>::data() will always have sufficient alignment to hold all NNAPI
41 // primitive types. "4" is chosen because that is the maximum alignment returned by
42 // `getAlignmentForLength`. However, this value will have to be changed if `getAlignmentForLength`
43 // returns a larger alignment.
44 static_assert(__STDCPP_DEFAULT_NEW_ALIGNMENT__ >= 4, "`New` alignment is not sufficient");
45 
GeneralError(std::string message,ErrorStatus code)46 GeneralError::GeneralError(std::string message, ErrorStatus code)
47     : message(std::move(message)), code(code) {}
48 
ExecutionError(std::string message,ErrorStatus code,std::vector<OutputShape> outputShapes)49 ExecutionError::ExecutionError(std::string message, ErrorStatus code,
50                                std::vector<OutputShape> outputShapes)
51     : message(std::move(message)), code(code), outputShapes(std::move(outputShapes)) {}
52 
ExecutionError(GeneralError error)53 ExecutionError::ExecutionError(GeneralError error)
54     : message(std::move(error.message)), code(error.code) {}
55 
OperandValues()56 Model::OperandValues::OperandValues() {
57     constexpr size_t kNumberBytes = 4 * 1024;
58     mData.reserve(kNumberBytes);
59 }
60 
OperandValues(const uint8_t * data,size_t length)61 Model::OperandValues::OperandValues(const uint8_t* data, size_t length)
62     : mData(data, data + length) {}
63 
append(const uint8_t * data,size_t length)64 DataLocation Model::OperandValues::append(const uint8_t* data, size_t length) {
65     CHECK_GT(length, 0u);
66     CHECK_LE(length, std::numeric_limits<uint32_t>::max());
67     const size_t alignment = getAlignmentForLength(length);
68     const size_t offset = roundUp(size(), alignment);
69     CHECK_LE(offset, std::numeric_limits<uint32_t>::max());
70     mData.resize(offset + length);
71     CHECK_LE(size(), std::numeric_limits<uint32_t>::max());
72     std::memcpy(mData.data() + offset, data, length);
73     return {.offset = static_cast<uint32_t>(offset), .length = static_cast<uint32_t>(length)};
74 }
75 
data() const76 const uint8_t* Model::OperandValues::data() const {
77     return mData.data();
78 }
79 
size() const80 size_t Model::OperandValues::size() const {
81     return mData.size();
82 }
83 
OperandPerformanceTable(std::vector<OperandPerformance> operandPerformances)84 Capabilities::OperandPerformanceTable::OperandPerformanceTable(
85         std::vector<OperandPerformance> operandPerformances)
86     : mSorted(std::move(operandPerformances)) {}
87 
create(std::vector<OperandPerformance> operandPerformances)88 Result<Capabilities::OperandPerformanceTable> Capabilities::OperandPerformanceTable::create(
89         std::vector<OperandPerformance> operandPerformances) {
90     const auto notUnique = [](const auto& lhs, const auto& rhs) { return !(lhs.type < rhs.type); };
91     const bool isUnique = std::adjacent_find(operandPerformances.begin(), operandPerformances.end(),
92                                              notUnique) == operandPerformances.end();
93     if (!isUnique) {
94         return NN_ERROR() << "Failed to create OperandPerformanceTable: Input must be sorted by "
95                              "key (in ascending order), and there must be no duplicate keys";
96     }
97 
98     return Capabilities::OperandPerformanceTable(std::move(operandPerformances));
99 }
100 
lookup(OperandType operandType) const101 Capabilities::PerformanceInfo Capabilities::OperandPerformanceTable::lookup(
102         OperandType operandType) const {
103     // Search for operand type in the sorted collection.
104     constexpr auto cmp = [](const auto& performance, auto type) { return performance.type < type; };
105     const auto it = std::lower_bound(mSorted.begin(), mSorted.end(), operandType, cmp);
106 
107     // If the operand type is found, return its corresponding info.
108     if (it != mSorted.end() && it->type == operandType) {
109         return it->info;
110     }
111 
112     // If no performance info is defined, use the default value (float's max).
113     return Capabilities::PerformanceInfo{};
114 }
115 
116 const std::vector<Capabilities::OperandPerformance>&
asVector() const117 Capabilities::OperandPerformanceTable::asVector() const {
118     return mSorted;
119 }
120 
createAsSignaled()121 SyncFence SyncFence::createAsSignaled() {
122     return SyncFence(nullptr);
123 }
124 
create(base::unique_fd fd)125 SyncFence SyncFence::create(base::unique_fd fd) {
126     CHECK(fd.ok());
127     return SyncFence(std::make_shared<const Handle>(std::move(fd)));
128 }
129 
create(SharedHandle syncFence)130 Result<SyncFence> SyncFence::create(SharedHandle syncFence) {
131     if (!validate(syncFence).ok()) {
132         return NN_ERROR() << "Invalid sync fence handle passed to SyncFence::create";
133     }
134     return SyncFence(std::move(syncFence));
135 }
136 
SyncFence(SharedHandle syncFence)137 SyncFence::SyncFence(SharedHandle syncFence) : mSyncFence(std::move(syncFence)) {}
138 
syncWait(OptionalTimeout optionalTimeout) const139 SyncFence::FenceState SyncFence::syncWait(OptionalTimeout optionalTimeout) const {
140     if (mSyncFence == nullptr) {
141         return FenceState::SIGNALED;
142     }
143 
144     const int fd = mSyncFence->get();
145     const int timeout = optionalTimeout.value_or(Timeout{-1}).count();
146 
147     // This implementation is directly based on the ::sync_wait() implementation.
148 
149     struct pollfd fds;
150     int ret;
151 
152     if (fd < 0) {
153         errno = EINVAL;
154         return FenceState::UNKNOWN;
155     }
156 
157     fds.fd = fd;
158     fds.events = POLLIN;
159 
160     do {
161         ret = poll(&fds, 1, timeout);
162         if (ret > 0) {
163             if (fds.revents & POLLNVAL) {
164                 errno = EINVAL;
165                 return FenceState::UNKNOWN;
166             }
167             if (fds.revents & POLLERR) {
168                 errno = EINVAL;
169                 return FenceState::ERROR;
170             }
171             return FenceState::SIGNALED;
172         } else if (ret == 0) {
173             errno = ETIME;
174             return FenceState::ACTIVE;
175         }
176     } while (ret == -1 && (errno == EINTR || errno == EAGAIN));
177 
178     return FenceState::UNKNOWN;
179 }
180 
getSharedHandle() const181 SharedHandle SyncFence::getSharedHandle() const {
182     return mSyncFence;
183 }
184 
hasFd() const185 bool SyncFence::hasFd() const {
186     return mSyncFence != nullptr;
187 }
188 
getFd() const189 int SyncFence::getFd() const {
190     return mSyncFence == nullptr ? -1 : mSyncFence->get();
191 }
192 
193 }  // namespace android::nn
194