1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_NNAPI_TYPE_UTILS_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_NNAPI_TYPE_UTILS_H
19 
20 #include <android-base/expected.h>
21 #include <android-base/logging.h>
22 #include <android-base/macros.h>
23 
24 #include <ostream>
25 #include <string>
26 #include <utility>
27 #include <vector>
28 
29 #include "nnapi/OperandTypes.h"
30 #include "nnapi/OperationTypes.h"
31 #include "nnapi/Result.h"
32 #include "nnapi/Types.h"
33 
34 namespace android::nn {
35 
// Classification helpers for operand and operation types. These are declarations only; the
// definitions are not part of this header.
// NOTE(review): judging by the name, isExtension() reports whether a type code belongs to an
// extension rather than to the built-in set -- confirm against the definitions.
bool isExtension(OperandType type);
bool isExtension(OperationType type);

// NOTE(review): presumably true when `operandType` is a built-in (non-extension) scalar type --
// confirm against the definition.
bool isNonExtensionScalar(OperandType operandType);

// NOTE(review): presumably the size in bytes of a single element of a built-in (non-extension)
// operand type -- confirm against the definition, including the behavior for extension types.
size_t getNonExtensionSize(OperandType operandType);
42 
getExtensionPrefix(uint32_t type)43 inline uint16_t getExtensionPrefix(uint32_t type) {
44     return static_cast<uint16_t>(type >> kExtensionTypeBits);
45 }
46 
getTypeWithinExtension(uint32_t type)47 inline uint16_t getTypeWithinExtension(uint32_t type) {
48     return static_cast<uint16_t>(type & kTypeWithinExtensionMask);
49 }
50 
// Overloads of getNonExtensionSize() that take the operand's dimensions as well (directly, or via
// the Operand). They return std::optional, so callers must handle an empty result.
// NOTE(review): presumably the total size in bytes, with std::nullopt signalling a size that
// cannot be represented -- confirm against the definitions.
std::optional<size_t> getNonExtensionSize(OperandType operandType, const Dimensions& dimensions);
std::optional<size_t> getNonExtensionSize(const Operand& operand);

// NOTE(review): presumably true when the operand is a tensor with unknown rank or with at least
// one unknown dimension -- confirm against the definitions.
bool tensorHasUnspecifiedDimensions(OperandType type, const Dimensions& dimensions);
bool tensorHasUnspecifiedDimensions(const Operand& operand);

// Convert between a pair of ints and a single size_t offset.
// NOTE(review): the two functions look like inverses of each other; the packing layout is not
// visible in this header -- confirm against the definitions.
size_t getOffsetFromInts(int lower, int higher);
std::pair<int32_t, int32_t> getIntsFromOffset(size_t offset);

// NOTE(review): presumably returns one count per operand (numberOfOperands entries) giving how
// many of `operations` use that operand as an input, or an error Result for invalid operand
// indices -- confirm against the definition.
Result<std::vector<uint32_t>> countNumberOfConsumers(size_t numberOfOperands,
                                                     const std::vector<nn::Operation>& operations);
62 
// Combines two tensor dimensions, both of which may have unspecified dimensions or rank.
Result<Dimensions> combineDimensions(const Dimensions& lhs, const Dimensions& rhs);

// Returns the size of the operandValues and a size for each pool in the provided model.
std::pair<size_t, std::vector<size_t>> getMemorySizes(const Model& model);

// Rounds up "size" to the nearest multiple of "multiple". "multiple" must be a power of 2.
size_t roundUp(size_t size, size_t multiple);

// Returns the alignment for data of the specified length.  It aligns objects of length:
// 2, 3 on a 2 byte boundary,
// 4+ on a 4 byte boundary.
// We may want to have different alignments for tensors.
// TODO: This is arbitrary, more a proof of concept.  We need to determine what this should be.
//
// Note that Types.cpp ensures `new` has sufficient alignment for all alignments returned by this
// function. If this function is changed to return different alignments (e.g., 8 byte boundary
// alignment), the code check in Types.cpp similarly needs to be updated.
size_t getAlignmentForLength(size_t length);

// Makes a Capabilities object given three granularities of performance info.
Capabilities makeCapabilities(const Capabilities::PerformanceInfo& defaultInfo,
                              const Capabilities::PerformanceInfo& float32Info,
                              const Capabilities::PerformanceInfo& relaxedInfo);
87 
// Set of output utility functions: operator<< overloads that write NNAPI types to a std::ostream.
// These are declarations only; the definitions are not part of this header. Types used with the
// NN_RET_CHECK_xx comparison macros need one of these overloads so their values can be streamed
// into the failure message.
std::ostream& operator<<(std::ostream& os, const DeviceStatus& deviceStatus);
std::ostream& operator<<(std::ostream& os, const ExecutionPreference& executionPreference);
std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType);
std::ostream& operator<<(std::ostream& os, const MeasureTiming& measureTiming);
std::ostream& operator<<(std::ostream& os, const OperandType& operandType);
std::ostream& operator<<(std::ostream& os, const Operand::LifeTime& lifetime);
std::ostream& operator<<(std::ostream& os, const OperationType& operationType);
std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime);
std::ostream& operator<<(std::ostream& os, const Priority& priority);
std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus);
std::ostream& operator<<(std::ostream& os, const FusedActivationFunc& activation);
std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape);
std::ostream& operator<<(std::ostream& os, const Timing& timing);
std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo);
std::ostream& operator<<(std::ostream& os,
                         const Capabilities::OperandPerformance& operandPerformance);
std::ostream& operator<<(std::ostream& os,
                         const Capabilities::OperandPerformanceTable& operandPerformances);
std::ostream& operator<<(std::ostream& os, const Capabilities& capabilities);
std::ostream& operator<<(std::ostream& os,
                         const Extension::OperandTypeInformation& operandTypeInformation);
std::ostream& operator<<(std::ostream& os, const Extension& extension);
std::ostream& operator<<(std::ostream& os, const DataLocation& location);
std::ostream& operator<<(std::ostream& os,
                         const Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams);
std::ostream& operator<<(std::ostream& os, const Operand::ExtraParams& extraParams);
std::ostream& operator<<(std::ostream& os, const Operand& operand);
std::ostream& operator<<(std::ostream& os, const Operation& operation);
std::ostream& operator<<(std::ostream& os, const SharedHandle& handle);
std::ostream& operator<<(std::ostream& os, const Memory& memory);
std::ostream& operator<<(std::ostream& os, const SharedMemory& memory);
std::ostream& operator<<(std::ostream& os, const MemoryPreference& memoryPreference);
std::ostream& operator<<(std::ostream& os, const Model::Subgraph& subgraph);
std::ostream& operator<<(std::ostream& os, const Model::OperandValues& operandValues);
std::ostream& operator<<(std::ostream& os, const ExtensionNameAndPrefix& extensionNameAndPrefix);
std::ostream& operator<<(std::ostream& os, const Model& model);
std::ostream& operator<<(std::ostream& os, const BufferDesc& bufferDesc);
std::ostream& operator<<(std::ostream& os, const BufferRole& bufferRole);
std::ostream& operator<<(std::ostream& os, const Request::Argument& requestArgument);
std::ostream& operator<<(std::ostream& os, const Request::MemoryPool& memoryPool);
std::ostream& operator<<(std::ostream& os, const Request& request);
std::ostream& operator<<(std::ostream& os, const SyncFence::FenceState& fenceState);
std::ostream& operator<<(std::ostream& os, const TimePoint& timePoint);
std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTimePoint);
std::ostream& operator<<(std::ostream& os, const Duration& timeoutDuration);
std::ostream& operator<<(std::ostream& os, const OptionalDuration& optionalTimeoutDuration);
std::ostream& operator<<(std::ostream& os, const Version::Level& versionLevel);
std::ostream& operator<<(std::ostream& os, const Version& version);
137 
// Equality and inequality operators for NNAPI types, declared in ==/!= pairs. These are
// declarations only; the definitions are not part of this header.
// NOTE(review): each operator!= is presumably the negation of the matching operator== -- confirm
// against the definitions.
bool operator==(const Timing& a, const Timing& b);
bool operator!=(const Timing& a, const Timing& b);
bool operator==(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b);
bool operator!=(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b);
bool operator==(const Capabilities::OperandPerformance& a,
                const Capabilities::OperandPerformance& b);
bool operator!=(const Capabilities::OperandPerformance& a,
                const Capabilities::OperandPerformance& b);
bool operator==(const Capabilities& a, const Capabilities& b);
bool operator!=(const Capabilities& a, const Capabilities& b);
bool operator==(const Extension::OperandTypeInformation& a,
                const Extension::OperandTypeInformation& b);
bool operator!=(const Extension::OperandTypeInformation& a,
                const Extension::OperandTypeInformation& b);
bool operator==(const Extension& a, const Extension& b);
bool operator!=(const Extension& a, const Extension& b);
bool operator==(const MemoryPreference& a, const MemoryPreference& b);
bool operator!=(const MemoryPreference& a, const MemoryPreference& b);
bool operator==(const Operand::SymmPerChannelQuantParams& a,
                const Operand::SymmPerChannelQuantParams& b);
bool operator!=(const Operand::SymmPerChannelQuantParams& a,
                const Operand::SymmPerChannelQuantParams& b);
bool operator==(const Operand& a, const Operand& b);
bool operator!=(const Operand& a, const Operand& b);
bool operator==(const Operation& a, const Operation& b);
bool operator!=(const Operation& a, const Operation& b);
bool operator==(const Version& a, const Version& b);
bool operator!=(const Version& a, const Version& b);
166 
// Formats an unsigned 32-bit value as its decimal string representation.
inline std::string toString(uint32_t obj) {
    std::string text = std::to_string(obj);
    return text;
}
170 
// Formats a pair as "(first, second)". Both element types must be streamable to a std::ostream.
template <typename A, typename B>
std::string toString(const std::pair<A, B>& pair) {
    const auto& [first, second] = pair;
    std::ostringstream stream;
    stream << "(";
    stream << first;
    stream << ", ";
    stream << second;
    stream << ")";
    return stream.str();
}
177 
// Formats a vector as "[e0, e1, ...]", converting each element with the toString() overload
// selected for `Type`. An empty vector yields "[]".
template <typename Type>
std::string toString(const std::vector<Type>& vec) {
    std::string os = "[";
    for (size_t i = 0; i < vec.size(); ++i) {
        if (i != 0) {
            os += ", ";
        }
        // Append in place rather than building a temporary with operator+.
        os += toString(vec[i]);
    }
    os += "]";
    // Return the local by name: `return os += "]";` would return through a std::string&, which
    // forces a full copy of the result instead of a move/NRVO.
    return os;
}
186 
/* IMPORTANT: if you change the following list, don't
 * forget to update the corresponding 'tags' table in
 * the initVLogMask() function implemented in Utils.cpp.
 */
// Verbose-logging tags. The enumerators are numbered consecutively starting at 0 and are used as
// bit positions in vLogMask (see VLOG_IS_ON).
enum VLogFlags { MODEL = 0, COMPILATION, EXECUTION, CPUEXE, MANAGER, DRIVER, MEMORY };

// Evaluates to true when the bit numbered TAG is set in vLogMask.
#define VLOG_IS_ON(TAG) ((vLogMask & (1 << (TAG))) != 0)

// Conditional logging: `VLOG(TAG) << ...;` streams to LOG(INFO) only when VLOG_IS_ON(TAG) holds.
// The macro ends in the `else` branch of an if/else, so the `<<` chain written after it becomes
// part of that branch and is not evaluated when the tag is disabled.
#define VLOG(TAG)                 \
    if (LIKELY(!VLOG_IS_ON(TAG))) \
        ;                         \
    else                          \
        LOG(INFO)

// Bitmask read by VLOG_IS_ON; bit N corresponds to the VLogFlags enumerator with value N.
// Declared here, defined elsewhere.
extern int vLogMask;
// NOTE(review): presumably initializes vLogMask; the comment above VLogFlags places the
// implementation in Utils.cpp -- confirm there.
void initVLogMask();
203 
204 // The NN_RET_CHECK family of macros defined below is similar to the CHECK family defined in
205 // system/libbase/include/android-base/logging.h
206 //
207 // The difference is that NN_RET_CHECK macros use LOG(ERROR) instead of LOG(FATAL)
208 // and return false instead of aborting.
209 
// Logs an error and returns false. Append context using << after. For example:
//
//   NN_RET_CHECK_FAIL() << "Something went wrong";
//
// The containing function must return a bool or a base::expected (including nn::Result,
// nn::GeneralResult, and nn::ExecutionResult).
//
// The macro expands to a `return` statement whose operand is a temporary NnRetCheckErrorStream,
// so it can only be used where a statement is allowed, and everything streamed after it becomes
// part of the returned expression. The message always starts with
// "NN_RET_CHECK failed (<file>:<line>): ", built from __FILE__ and __LINE__ at the expansion site.
#define NN_RET_CHECK_FAIL()                       \
    return ::android::nn::NnRetCheckErrorStream() \
           << "NN_RET_CHECK failed (" << __FILE__ << ":" << __LINE__ << "): "
219 
// Logs an error and returns false if condition is false. Extra logging can be appended using <<
// after. For example:
//
//   NN_RET_CHECK(false) << "Something went wrong";
//
// The containing function must return a bool or a base::expected (including nn::Result,
// nn::GeneralResult, and nn::ExecutionResult).
//
// The stringified condition followed by a space is added to the message before any appended
// context. The check is written as a `while` whose body is NN_RET_CHECK_FAIL(); that body returns
// from the enclosing function, so it executes at most once.
// NOTE(review): `while` rather than `if` is presumably chosen so the macro cannot capture a
// following `else` -- confirm intent.
#define NN_RET_CHECK(condition) \
    while (UNLIKELY(!(condition))) NN_RET_CHECK_FAIL() << #condition << " "
229 
// Helper for the NN_RET_CHECK_xx(x, y) macros.
//
// The `for` init-statement passes LHS and RHS to ::android::base::MakeEagerEvaluator and keeps
// the result in `_values`; the loop condition then applies OP to `_values.lhs.v` and
// `_values.rhs.v`. The loop body is NN_RET_CHECK_FAIL(), which returns from the enclosing
// function, so the body executes at most once and no increment expression is needed.
// The failure message has the form
//   "<LHS> <OP> <RHS> (<LHS> = <lhs value>, <RHS> = <rhs value>) "
// where <LHS>, <OP> and <RHS> are the stringified macro arguments.
// NOTE(review): LogNullGuard presumably makes null C-string values safe to stream -- confirm in
// android-base/logging.h.
#define NN_RET_CHECK_OP(LHS, RHS, OP)                                                       \
    for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS);                      \
         UNLIKELY(!(_values.lhs.v OP _values.rhs.v));                                       \
         /* empty */)                                                                       \
    NN_RET_CHECK_FAIL()                                                                     \
            << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = "                   \
            << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \
            << ", " << #RHS << " = "                                                        \
            << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \
            << ") "

// Logs an error and returns false if a condition between x and y does not hold. Extra logging can
// be appended using << after. For example:
//
//   NN_RET_CHECK_EQ(a, b) << "Something went wrong";
//
// The values must implement the appropriate comparison operator as well as
// `operator<<(std::ostream&, ...)`.
// The containing function must return a bool or a base::expected (including nn::Result,
// nn::GeneralResult, and nn::ExecutionResult).
//
// Each macro forwards to NN_RET_CHECK_OP with the comparison operator its suffix names:
// EQ is ==, NE is !=, LE is <=, LT is <, GE is >=, GT is >.
#define NN_RET_CHECK_EQ(x, y) NN_RET_CHECK_OP(x, y, ==)
#define NN_RET_CHECK_NE(x, y) NN_RET_CHECK_OP(x, y, !=)
#define NN_RET_CHECK_LE(x, y) NN_RET_CHECK_OP(x, y, <=)
#define NN_RET_CHECK_LT(x, y) NN_RET_CHECK_OP(x, y, <)
#define NN_RET_CHECK_GE(x, y) NN_RET_CHECK_OP(x, y, >=)
#define NN_RET_CHECK_GT(x, y) NN_RET_CHECK_OP(x, y, >)
257 
258 // Ensure that every user of NnRetCheckErrorStream is linked to the
259 // correct instance, using the correct LOG_TAG
260 namespace {
261 
262 // A wrapper around an error message that can be implicitly converted to bool (always evaluates to
263 // false) and logs via LOG(ERROR) or base::expected (always evaluates to base::unexpected). Used to
264 // implement stream logging in NN_RET_CHECK.
265 class NnRetCheckErrorStream {
266     DISALLOW_COPY_AND_ASSIGN(NnRetCheckErrorStream);
267 
268    public:
269     constexpr NnRetCheckErrorStream() = default;
270 
271     template <typename T>
272     NnRetCheckErrorStream& operator<<(const T& value) {
273         (*mBuffer) << value;
274         return *this;
275     }
276 
~NnRetCheckErrorStream()277     ~NnRetCheckErrorStream() {
278         if (mBuffer.has_value()) {
279             LOG(ERROR) << mBuffer->str();
280         }
281     }
282 
283     constexpr operator bool() const { return false; }  // NOLINT(google-explicit-constructor)
284 
285     template <typename T, typename E>
286     constexpr operator base::expected<T, E>() {  // NOLINT(google-explicit-constructor)
287         auto result = base::unexpected<E>(std::move(mBuffer)->str());
288         mBuffer.reset();
289         return result;
290     }
291 
292     // This is needed because conversion to Result<int> is ambiguous
293     // due to the above bool() operator overload
294     operator Result<int>() {  // NOLINT(google-explicit-constructor)
295         auto result = base::unexpected(std::move(mBuffer)->str());
296         mBuffer.reset();
297         return result;
298     }
299 
300    private:
301     std::optional<std::ostringstream> mBuffer = std::ostringstream{};
302 };
303 
304 }  // namespace
305 
306 }  // namespace android::nn
307 
308 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_NNAPI_TYPE_UTILS_H
309