1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_NNAPI_TYPE_UTILS_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_NNAPI_TYPE_UTILS_H
19
20 #include <android-base/expected.h>
21 #include <android-base/logging.h>
22 #include <android-base/macros.h>
23
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
28
29 #include "nnapi/OperandTypes.h"
30 #include "nnapi/OperationTypes.h"
31 #include "nnapi/Result.h"
32 #include "nnapi/Types.h"
33
34 namespace android::nn {
35
36 bool isExtension(OperandType type);
37 bool isExtension(OperationType type);
38
39 bool isNonExtensionScalar(OperandType operandType);
40
41 size_t getNonExtensionSize(OperandType operandType);
42
getExtensionPrefix(uint32_t type)43 inline uint16_t getExtensionPrefix(uint32_t type) {
44 return static_cast<uint16_t>(type >> kExtensionTypeBits);
45 }
46
getTypeWithinExtension(uint32_t type)47 inline uint16_t getTypeWithinExtension(uint32_t type) {
48 return static_cast<uint16_t>(type & kTypeWithinExtensionMask);
49 }
50
51 std::optional<size_t> getNonExtensionSize(OperandType operandType, const Dimensions& dimensions);
52 std::optional<size_t> getNonExtensionSize(const Operand& operand);
53
54 bool tensorHasUnspecifiedDimensions(OperandType type, const Dimensions& dimensions);
55 bool tensorHasUnspecifiedDimensions(const Operand& operand);
56
57 size_t getOffsetFromInts(int lower, int higher);
58 std::pair<int32_t, int32_t> getIntsFromOffset(size_t offset);
59
60 Result<std::vector<uint32_t>> countNumberOfConsumers(size_t numberOfOperands,
61 const std::vector<nn::Operation>& operations);
62
63 // Combine two tensor dimensions, both may have unspecified dimensions or rank.
64 Result<Dimensions> combineDimensions(const Dimensions& lhs, const Dimensions& rhs);
65
// Returns the size of the provided model's operandValues and the size of each of its pools.
67 std::pair<size_t, std::vector<size_t>> getMemorySizes(const Model& model);
68
69 // Round up "size" to the nearest multiple of "multiple". "multiple" must be a power of 2.
70 size_t roundUp(size_t size, size_t multiple);
71
// Returns the alignment for data of the specified length. It aligns objects of length:
73 // 2, 3 on a 2 byte boundary,
74 // 4+ on a 4 byte boundary.
75 // We may want to have different alignments for tensors.
76 // TODO: This is arbitrary, more a proof of concept. We need to determine what this should be.
77 //
78 // Note that Types.cpp ensures `new` has sufficient alignment for all alignments returned by this
79 // function. If this function is changed to return different alignments (e.g., 8 byte boundary
80 // alignment), the code check in Types.cpp similarly needs to be updated.
81 size_t getAlignmentForLength(size_t length);
82
// Makes a Capabilities object from three granularities of performance info.
84 Capabilities makeCapabilities(const Capabilities::PerformanceInfo& defaultInfo,
85 const Capabilities::PerformanceInfo& float32Info,
86 const Capabilities::PerformanceInfo& relaxedInfo);
87
88 // Set of output utility functions.
89 std::ostream& operator<<(std::ostream& os, const DeviceStatus& deviceStatus);
90 std::ostream& operator<<(std::ostream& os, const ExecutionPreference& executionPreference);
91 std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType);
92 std::ostream& operator<<(std::ostream& os, const MeasureTiming& measureTiming);
93 std::ostream& operator<<(std::ostream& os, const OperandType& operandType);
94 std::ostream& operator<<(std::ostream& os, const Operand::LifeTime& lifetime);
95 std::ostream& operator<<(std::ostream& os, const OperationType& operationType);
96 std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime);
97 std::ostream& operator<<(std::ostream& os, const Priority& priority);
98 std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus);
99 std::ostream& operator<<(std::ostream& os, const FusedActivationFunc& activation);
100 std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape);
101 std::ostream& operator<<(std::ostream& os, const Timing& timing);
102 std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo);
103 std::ostream& operator<<(std::ostream& os,
104 const Capabilities::OperandPerformance& operandPerformance);
105 std::ostream& operator<<(std::ostream& os,
106 const Capabilities::OperandPerformanceTable& operandPerformances);
107 std::ostream& operator<<(std::ostream& os, const Capabilities& capabilities);
108 std::ostream& operator<<(std::ostream& os,
109 const Extension::OperandTypeInformation& operandTypeInformation);
110 std::ostream& operator<<(std::ostream& os, const Extension& extension);
111 std::ostream& operator<<(std::ostream& os, const DataLocation& location);
112 std::ostream& operator<<(std::ostream& os,
113 const Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams);
114 std::ostream& operator<<(std::ostream& os, const Operand::ExtraParams& extraParams);
115 std::ostream& operator<<(std::ostream& os, const Operand& operand);
116 std::ostream& operator<<(std::ostream& os, const Operation& operation);
117 std::ostream& operator<<(std::ostream& os, const SharedHandle& handle);
118 std::ostream& operator<<(std::ostream& os, const Memory& memory);
119 std::ostream& operator<<(std::ostream& os, const SharedMemory& memory);
120 std::ostream& operator<<(std::ostream& os, const MemoryPreference& memoryPreference);
121 std::ostream& operator<<(std::ostream& os, const Model::Subgraph& subgraph);
122 std::ostream& operator<<(std::ostream& os, const Model::OperandValues& operandValues);
123 std::ostream& operator<<(std::ostream& os, const ExtensionNameAndPrefix& extensionNameAndPrefix);
124 std::ostream& operator<<(std::ostream& os, const Model& model);
125 std::ostream& operator<<(std::ostream& os, const BufferDesc& bufferDesc);
126 std::ostream& operator<<(std::ostream& os, const BufferRole& bufferRole);
127 std::ostream& operator<<(std::ostream& os, const Request::Argument& requestArgument);
128 std::ostream& operator<<(std::ostream& os, const Request::MemoryPool& memoryPool);
129 std::ostream& operator<<(std::ostream& os, const Request& request);
130 std::ostream& operator<<(std::ostream& os, const SyncFence::FenceState& fenceState);
131 std::ostream& operator<<(std::ostream& os, const TimePoint& timePoint);
132 std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTimePoint);
133 std::ostream& operator<<(std::ostream& os, const Duration& timeoutDuration);
134 std::ostream& operator<<(std::ostream& os, const OptionalDuration& optionalTimeoutDuration);
135 std::ostream& operator<<(std::ostream& os, const Version::Level& versionLevel);
136 std::ostream& operator<<(std::ostream& os, const Version& version);
137
138 bool operator==(const Timing& a, const Timing& b);
139 bool operator!=(const Timing& a, const Timing& b);
140 bool operator==(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b);
141 bool operator!=(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b);
142 bool operator==(const Capabilities::OperandPerformance& a,
143 const Capabilities::OperandPerformance& b);
144 bool operator!=(const Capabilities::OperandPerformance& a,
145 const Capabilities::OperandPerformance& b);
146 bool operator==(const Capabilities& a, const Capabilities& b);
147 bool operator!=(const Capabilities& a, const Capabilities& b);
148 bool operator==(const Extension::OperandTypeInformation& a,
149 const Extension::OperandTypeInformation& b);
150 bool operator!=(const Extension::OperandTypeInformation& a,
151 const Extension::OperandTypeInformation& b);
152 bool operator==(const Extension& a, const Extension& b);
153 bool operator!=(const Extension& a, const Extension& b);
154 bool operator==(const MemoryPreference& a, const MemoryPreference& b);
155 bool operator!=(const MemoryPreference& a, const MemoryPreference& b);
156 bool operator==(const Operand::SymmPerChannelQuantParams& a,
157 const Operand::SymmPerChannelQuantParams& b);
158 bool operator!=(const Operand::SymmPerChannelQuantParams& a,
159 const Operand::SymmPerChannelQuantParams& b);
160 bool operator==(const Operand& a, const Operand& b);
161 bool operator!=(const Operand& a, const Operand& b);
162 bool operator==(const Operation& a, const Operation& b);
163 bool operator!=(const Operation& a, const Operation& b);
164 bool operator==(const Version& a, const Version& b);
165 bool operator!=(const Version& a, const Version& b);
166
// Renders an unsigned 32-bit value as its decimal text, e.g. toString(42u) == "42".
inline std::string toString(uint32_t obj) {
    std::string decimal = std::to_string(obj);
    return decimal;
}
170
// Formats `pair` as "(first, second)", writing each member with its
// operator<<(std::ostream&, ...). Both A and B must therefore be streamable.
// std::ostringstream comes from <sstream>, which this header includes explicitly rather than
// relying on a transitive include.
template <typename A, typename B>
std::string toString(const std::pair<A, B>& pair) {
    std::ostringstream oss;
    oss << "(" << pair.first << ", " << pair.second << ")";
    return oss.str();
}
177
// Formats `vec` as "[e0, e1, ...]", rendering each element with the toString overload that
// applies to Type. An empty vector yields "[]".
template <typename Type>
std::string toString(const std::vector<Type>& vec) {
    std::string os = "[";
    for (size_t i = 0; i < vec.size(); ++i) {
        if (i != 0) {
            os += ", ";
        }
        // Append directly instead of building a temporary "separator + element" string.
        os += toString(vec[i]);
    }
    os += "]";
    // Return the named local so NRVO / implicit move applies. The previous
    // `return os += "]";` returned the lvalue from operator+= and copied the whole string.
    return os;
}
186
/* IMPORTANT: if you change the following list, don't
 * forget to update the corresponding 'tags' table in
 * the initVLogMask() function implemented in Utils.cpp.
 */
// Verbose-logging tags. Each value is a bit position tested in vLogMask by VLOG_IS_ON,
// so the numbering must stay dense and in this order.
enum VLogFlags {
    MODEL = 0,
    COMPILATION = 1,
    EXECUTION = 2,
    CPUEXE = 3,
    MANAGER = 4,
    DRIVER = 5,
    MEMORY = 6,
};
192
// Evaluates to true when bit TAG (a VLogFlags value) is set in vLogMask.
#define VLOG_IS_ON(TAG) ((vLogMask & (1 << (TAG))) != 0)

// Streams to LOG(INFO) only when VLOG_IS_ON(TAG); otherwise the streamed expression is not
// evaluated. The `if (...) ; else` shape lets callers append `<< ...`. It also means an `else`
// written by the caller after VLOG(...) binds to the caller's own `if`, not to this macro's `if`.
#define VLOG(TAG)                 \
    if (LIKELY(!VLOG_IS_ON(TAG))) \
        ;                         \
    else                          \
        LOG(INFO)
200
201 extern int vLogMask;
202 void initVLogMask();
203
204 // The NN_RET_CHECK family of macros defined below is similar to the CHECK family defined in
205 // system/libbase/include/android-base/logging.h
206 //
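// The difference is that NN_RET_CHECK macros return from the enclosing function instead of
// aborting. In a bool-returning function they return false after logging the message with
// LOG(ERROR) (instead of LOG(FATAL)). In a function returning base::expected they return the
// message as a base::unexpected error.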
209
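// Fails unconditionally. It returns false after logging an error via LOG(ERROR), or, in a
// function returning base::expected, returns the message as a base::unexpected error.
// Append context using << after. For example: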
211 //
212 // NN_RET_CHECK_FAIL() << "Something went wrong";
213 //
214 // The containing function must return a bool or a base::expected (including nn::Result,
215 // nn::GeneralResult, and nn::ExecutionResult).
// Expands to a `return` statement whose operand is a temporary NnRetCheckErrorStream with a
// "file:line" prefix already streamed into it. Anything the caller appends with << joins the
// same message. The resulting stream is then implicitly converted to the enclosing function's
// return type.
#define NN_RET_CHECK_FAIL()                       \
    return ::android::nn::NnRetCheckErrorStream() \
           << "NN_RET_CHECK failed (" << __FILE__ << ":" << __LINE__ << "): "
219
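// If condition is false, fails like NN_RET_CHECK_FAIL(): it logs an error and returns false, or
// returns a base::unexpected error from a base::expected-returning function. Extra logging can be
// appended using << after. For example: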
222 //
223 // NN_RET_CHECK(false) << "Something went wrong";
224 //
225 // The containing function must return a bool or a base::expected (including nn::Result,
226 // nn::GeneralResult, and nn::ExecutionResult).
// Uses `while` rather than `if` so the macro can be followed by `<<` without creating a
// dangling-`else` hazard in the caller's code. The loop body always returns, so it runs at most
// once. #condition stringizes the checked expression into the message.
#define NN_RET_CHECK(condition) \
    while (UNLIKELY(!(condition))) NN_RET_CHECK_FAIL() << #condition << " "
229
// Helper for the NN_RET_CHECK_xx(x, y) macros below.
// The `for` init-statement evaluates LHS and RHS exactly once, via
// android::base::MakeEagerEvaluator, so the same values are both compared and printed.
// As in NN_RET_CHECK, the loop body always returns, so it runs at most once.
// LogNullGuard comes from android-base/logging.h; presumably it makes null pointer values safe
// to stream (confirm there).
#define NN_RET_CHECK_OP(LHS, RHS, OP)                                                 \
    for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS);                \
         UNLIKELY(!(_values.lhs.v OP _values.rhs.v));                                 \
         /* empty */)                                                                 \
    NN_RET_CHECK_FAIL()                                                               \
            << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = "             \
            << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \
            << ", " << #RHS << " = "                                                  \
            << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \
            << ") "
241
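// Fails like NN_RET_CHECK_FAIL() if the comparison between x and y does not hold: it logs an error
// and returns false, or returns a base::unexpected error from a base::expected-returning function.
// Extra logging can be appended using << after. For example: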
244 //
245 // NN_RET_CHECK_EQ(a, b) << "Something went wrong";
246 //
247 // The values must implement the appropriate comparison operator as well as
248 // `operator<<(std::ostream&, ...)`.
249 // The containing function must return a bool or a base::expected (including nn::Result,
250 // nn::GeneralResult, and nn::ExecutionResult).
// Each check forwards to NN_RET_CHECK_OP with the comparison operator named by its suffix.
#define NN_RET_CHECK_EQ(x, y) NN_RET_CHECK_OP(x, y, ==)
#define NN_RET_CHECK_NE(x, y) NN_RET_CHECK_OP(x, y, !=)
#define NN_RET_CHECK_LE(x, y) NN_RET_CHECK_OP(x, y, <=)
#define NN_RET_CHECK_LT(x, y) NN_RET_CHECK_OP(x, y, <)
#define NN_RET_CHECK_GE(x, y) NN_RET_CHECK_OP(x, y, >=)
#define NN_RET_CHECK_GT(x, y) NN_RET_CHECK_OP(x, y, >)
257
258 // Ensure that every user of NnRetCheckErrorStream is linked to the
259 // correct instance, using the correct LOG_TAG
260 namespace {
261
262 // A wrapper around an error message that can be implicitly converted to bool (always evaluates to
263 // false) and logs via LOG(ERROR) or base::expected (always evaluates to base::unexpected). Used to
264 // implement stream logging in NN_RET_CHECK.
class NnRetCheckErrorStream {
    DISALLOW_COPY_AND_ASSIGN(NnRetCheckErrorStream);

  public:
    // Starts with an empty message buffer (see mBuffer's default member initializer).
    constexpr NnRetCheckErrorStream() = default;

    // Appends `value` to the pending message using operator<<(std::ostream&, const T&).
    template <typename T>
    NnRetCheckErrorStream& operator<<(const T& value) {
        (*mBuffer) << value;
        return *this;
    }

    // Logs the accumulated message via LOG(ERROR), unless a base::expected / Result conversion
    // below already took the message (those conversions reset mBuffer).
    ~NnRetCheckErrorStream() {
        if (mBuffer.has_value()) {
            LOG(ERROR) << mBuffer->str();
        }
    }

    // Selected when the enclosing function returns bool: always false. The message is logged
    // later by the destructor.
    constexpr operator bool() const { return false; }  // NOLINT(google-explicit-constructor)

    // Selected when the enclosing function returns base::expected<T, E>: produces
    // base::unexpected<E> holding the message, then resets mBuffer so the destructor does not
    // also log it.
    // NOTE(review): optional::operator-> yields a pointer, so `std::move(mBuffer)->str()` calls
    // str() on an lvalue stream and copies the contents. The std::move has no effect here.
    template <typename T, typename E>
    constexpr operator base::expected<T, E>() {  // NOLINT(google-explicit-constructor)
        auto result = base::unexpected<E>(std::move(mBuffer)->str());
        mBuffer.reset();
        return result;
    }

    // This is needed because conversion to Result<int> is ambiguous
    // due to the above bool() operator overload
    // It behaves like the template conversion above: it returns the message as an unexpected
    // error and resets mBuffer so nothing is logged.
    operator Result<int>() {  // NOLINT(google-explicit-constructor)
        auto result = base::unexpected(std::move(mBuffer)->str());
        mBuffer.reset();
        return result;
    }

  private:
    // The message being built. It becomes std::nullopt once the message has been handed to a
    // base::expected / Result, which tells the destructor not to log it.
    std::optional<std::ostringstream> mBuffer = std::ostringstream{};
};
303
304 } // namespace
305
306 } // namespace android::nn
307
308 #endif // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_NNAPI_TYPE_UTILS_H
309