1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "TypeUtils.h"
18 
19 #include <android-base/logging.h>
20 #include <android-base/properties.h>
21 #include <android-base/strings.h>
22 
23 #include <algorithm>
24 #include <chrono>
25 #include <limits>
26 #include <memory>
27 #include <ostream>
28 #include <string>
29 #include <type_traits>
30 #include <unordered_map>
31 #include <utility>
32 #include <vector>
33 
34 #include "OperandTypes.h"
35 #include "OperationTypes.h"
36 #include "OperationsUtils.h"
37 #include "Result.h"
38 #include "SharedMemory.h"
39 #include "Types.h"
40 
41 namespace android::nn {
42 namespace {
43 
/// Converts an enumerator to the value of its underlying integral type, e.g. so an
/// enumerator with no symbolic name can still be printed.
template <typename Enum>
constexpr std::underlying_type_t<Enum> underlyingType(Enum object) {
    using Underlying = std::underlying_type_t<Enum>;
    return static_cast<Underlying>(object);
}
48 
/// Streams at most the first 20 elements of `vec` as "[e0, e1, ...]". When elements were
/// omitted, "..." is written immediately before the closing bracket.
///
/// Byte-sized integer elements (int8_t / uint8_t) are written as numbers; streaming them
/// directly would emit raw characters instead of their values.
template <typename Type>
std::ostream& operator<<(std::ostream& os, const std::vector<Type>& vec) {
    constexpr size_t kMaxVectorPrint = 20;
    const size_t count = vec.size() < kMaxVectorPrint ? vec.size() : kMaxVectorPrint;
    os << "[";
    for (size_t i = 0; i < count; ++i) {
        if (i > 0) {
            os << ", ";
        }
        if constexpr (std::is_same_v<Type, uint8_t> || std::is_same_v<Type, int8_t>) {
            os << static_cast<int>(vec[i]);
        } else {
            os << vec[i];
        }
    }
    // Only mark truncation when elements were actually dropped, so a vector of exactly
    // kMaxVectorPrint elements is printed in full without a misleading ellipsis.
    if (vec.size() > kMaxVectorPrint) {
        os << "...";
    }
    return os << "]";
}
66 
makeOperandPerformance(const Capabilities::PerformanceInfo & perfInfo)67 std::vector<Capabilities::OperandPerformance> makeOperandPerformance(
68         const Capabilities::PerformanceInfo& perfInfo) {
69     static constexpr OperandType kOperandTypes[] = {
70             OperandType::FLOAT32,
71             OperandType::INT32,
72             OperandType::UINT32,
73             OperandType::TENSOR_FLOAT32,
74             OperandType::TENSOR_INT32,
75             OperandType::TENSOR_QUANT8_ASYMM,
76             OperandType::BOOL,
77             OperandType::TENSOR_QUANT16_SYMM,
78             OperandType::TENSOR_FLOAT16,
79             OperandType::TENSOR_BOOL8,
80             OperandType::FLOAT16,
81             OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL,
82             OperandType::TENSOR_QUANT16_ASYMM,
83             OperandType::TENSOR_QUANT8_SYMM,
84             OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
85             // OperandType::SUBGRAPH, OperandType::OEM, and OperandType::TENSOR_OEM_BYTE
86             // intentionally omitted.
87     };
88 
89     std::vector<Capabilities::OperandPerformance> operandPerformance;
90     operandPerformance.reserve(std::size(kOperandTypes));
91     std::transform(std::begin(kOperandTypes), std::end(kOperandTypes),
92                    std::back_inserter(operandPerformance), [&perfInfo](OperandType op) {
93                        return Capabilities::OperandPerformance{.type = op, .info = perfInfo};
94                    });
95     return operandPerformance;
96 }
97 
update(std::vector<Capabilities::OperandPerformance> * operandPerformance,OperandType type,const Capabilities::PerformanceInfo & info)98 void update(std::vector<Capabilities::OperandPerformance>* operandPerformance, OperandType type,
99             const Capabilities::PerformanceInfo& info) {
100     CHECK(operandPerformance != nullptr);
101     auto it = std::lower_bound(operandPerformance->begin(), operandPerformance->end(), type,
102                                [](const Capabilities::OperandPerformance& perf, OperandType type) {
103                                    return perf.type < type;
104                                });
105     CHECK(it != operandPerformance->end());
106     CHECK_EQ(it->type, type);
107     it->info = info;
108 }
109 
110 }  // namespace
111 
isExtension(OperandType type)112 bool isExtension(OperandType type) {
113     return getExtensionPrefix(underlyingType(type)) != 0;
114 }
115 
isExtension(OperationType type)116 bool isExtension(OperationType type) {
117     return getExtensionPrefix(underlyingType(type)) != 0;
118 }
119 
isNonExtensionScalar(OperandType operandType)120 bool isNonExtensionScalar(OperandType operandType) {
121     CHECK(!isExtension(operandType));
122     switch (operandType) {
123         case OperandType::FLOAT32:
124         case OperandType::INT32:
125         case OperandType::UINT32:
126         case OperandType::BOOL:
127         case OperandType::FLOAT16:
128         case OperandType::SUBGRAPH:
129         case OperandType::OEM:
130             return true;
131         case OperandType::TENSOR_FLOAT32:
132         case OperandType::TENSOR_INT32:
133         case OperandType::TENSOR_QUANT8_ASYMM:
134         case OperandType::TENSOR_QUANT16_SYMM:
135         case OperandType::TENSOR_FLOAT16:
136         case OperandType::TENSOR_BOOL8:
137         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
138         case OperandType::TENSOR_QUANT16_ASYMM:
139         case OperandType::TENSOR_QUANT8_SYMM:
140         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
141         case OperandType::TENSOR_OEM_BYTE:
142             return false;
143     }
144     return false;
145 }
146 
getNonExtensionSize(OperandType operandType)147 size_t getNonExtensionSize(OperandType operandType) {
148     CHECK(!isExtension(operandType));
149     switch (operandType) {
150         case OperandType::SUBGRAPH:
151         case OperandType::OEM:
152             return 0;
153         case OperandType::TENSOR_QUANT8_ASYMM:
154         case OperandType::BOOL:
155         case OperandType::TENSOR_BOOL8:
156         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
157         case OperandType::TENSOR_QUANT8_SYMM:
158         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
159         case OperandType::TENSOR_OEM_BYTE:
160             return 1;
161         case OperandType::TENSOR_QUANT16_SYMM:
162         case OperandType::TENSOR_FLOAT16:
163         case OperandType::FLOAT16:
164         case OperandType::TENSOR_QUANT16_ASYMM:
165             return 2;
166         case OperandType::FLOAT32:
167         case OperandType::INT32:
168         case OperandType::UINT32:
169         case OperandType::TENSOR_FLOAT32:
170         case OperandType::TENSOR_INT32:
171             return 4;
172     }
173     return 0;
174 }
175 
getNonExtensionSize(OperandType operandType,const Dimensions & dimensions)176 std::optional<size_t> getNonExtensionSize(OperandType operandType, const Dimensions& dimensions) {
177     CHECK(!isExtension(operandType)) << "Size of extension operand data is unknown";
178     size_t size = getNonExtensionSize(operandType);
179     if (isNonExtensionScalar(operandType)) {
180         return size;
181     } else if (dimensions.empty()) {
182         return 0;
183     }
184     for (Dimension dimension : dimensions) {
185         if (dimension != 0 && size > std::numeric_limits<size_t>::max() / dimension) {
186             return std::nullopt;
187         }
188         size *= dimension;
189     }
190     return size;
191 }
192 
getNonExtensionSize(const Operand & operand)193 std::optional<size_t> getNonExtensionSize(const Operand& operand) {
194     return getNonExtensionSize(operand.type, operand.dimensions);
195 }
196 
tensorHasUnspecifiedDimensions(OperandType type,const std::vector<uint32_t> & dimensions)197 bool tensorHasUnspecifiedDimensions(OperandType type, const std::vector<uint32_t>& dimensions) {
198     if (!isExtension(type)) {
199         CHECK(!isNonExtensionScalar(type)) << "A scalar type can never have unspecified dimensions";
200     }
201     return dimensions.empty() ||
202            std::find(dimensions.begin(), dimensions.end(), 0) != dimensions.end();
203 }
204 
tensorHasUnspecifiedDimensions(const Operand & operand)205 bool tensorHasUnspecifiedDimensions(const Operand& operand) {
206     return tensorHasUnspecifiedDimensions(operand.type, operand.dimensions);
207 }
208 
/// Reassembles a 64-bit offset from two 32-bit halves. `lower` supplies bits 0..31 and
/// `higher` supplies bits 32..63; each half's bit pattern is kept as-is, so negative ints
/// map to their two's-complement bits. The result is returned as size_t.
size_t getOffsetFromInts(int lower, int higher) {
    // int32 -> uint32 is a modular conversion, which preserves the bit pattern without
    // reading the object through a pointer of a different type.
    const uint64_t lowHalf = static_cast<uint32_t>(static_cast<int32_t>(lower));
    const uint64_t highHalf = static_cast<uint32_t>(static_cast<int32_t>(higher));
    const uint64_t combined = (highHalf << 32) | lowHalf;
    return combined;
}
217 
/// Splits an offset into two signed 32-bit halves, (low bits 0..31, high bits 32..63).
/// Each half keeps its bit pattern, so values with the top bit set come back negative.
std::pair<int32_t, int32_t> getIntsFromOffset(size_t offset) {
    const auto widened = static_cast<uint64_t>(offset);
    const auto lowBits = static_cast<uint32_t>(widened & 0xffffffff);
    const auto highBits = static_cast<uint32_t>(widened >> 32);
    // uint32 -> int32 wraps modulo 2^32 on all supported compilers (and by definition
    // since C++20), which reproduces the two's-complement reinterpretation.
    return std::make_pair(static_cast<int32_t>(lowBits), static_cast<int32_t>(highBits));
}
226 
countNumberOfConsumers(size_t numberOfOperands,const std::vector<nn::Operation> & operations)227 Result<std::vector<uint32_t>> countNumberOfConsumers(size_t numberOfOperands,
228                                                      const std::vector<nn::Operation>& operations) {
229     std::vector<uint32_t> numberOfConsumers(numberOfOperands, 0);
230     for (const auto& operation : operations) {
231         for (uint32_t operandIndex : operation.inputs) {
232             if (operandIndex >= numberOfConsumers.size()) {
233                 return NN_ERROR()
234                        << "countNumberOfConsumers: tried to access out-of-bounds operand ("
235                        << operandIndex << " vs " << numberOfConsumers.size() << ")";
236             }
237             numberOfConsumers[operandIndex]++;
238         }
239     }
240     return numberOfConsumers;
241 }
242 
combineDimensions(const Dimensions & lhs,const Dimensions & rhs)243 Result<Dimensions> combineDimensions(const Dimensions& lhs, const Dimensions& rhs) {
244     if (rhs.empty()) return lhs;
245     if (lhs.empty()) return rhs;
246     if (lhs.size() != rhs.size()) {
247         std::ostringstream os;
248         os << "Incompatible ranks: " << lhs << " and " << rhs;
249         return NN_ERROR() << os.str();
250     }
251     Dimensions combined = lhs;
252     for (size_t i = 0; i < lhs.size(); i++) {
253         if (lhs[i] == 0) {
254             combined[i] = rhs[i];
255         } else if (rhs[i] != 0 && lhs[i] != rhs[i]) {
256             std::ostringstream os;
257             os << "Incompatible dimensions: " << lhs << " and " << rhs;
258             return NN_ERROR() << os.str();
259         }
260     }
261     return combined;
262 }
263 
getMemorySizes(const Model & model)264 std::pair<size_t, std::vector<size_t>> getMemorySizes(const Model& model) {
265     const size_t operandValuesSize = model.operandValues.size();
266 
267     std::vector<size_t> poolSizes;
268     poolSizes.reserve(model.pools.size());
269     std::transform(model.pools.begin(), model.pools.end(), std::back_inserter(poolSizes),
270                    [](const SharedMemory& memory) { return getSize(memory); });
271 
272     return std::make_pair(operandValuesSize, std::move(poolSizes));
273 }
274 
roundUp(size_t size,size_t multiple)275 size_t roundUp(size_t size, size_t multiple) {
276     CHECK(multiple != 0);
277     CHECK((multiple & (multiple - 1)) == 0) << multiple << " is not a power of two";
278     return (size + (multiple - 1)) & ~(multiple - 1);
279 }
280 
/// Chooses an alignment for data of `length` bytes: 4 for 4+ bytes, 2 for 2-3 bytes, and
/// 1 (no alignment) for 0-1 bytes.
size_t getAlignmentForLength(size_t length) {
    if (length >= 4) {
        return 4;
    }
    if (length >= 2) {
        return 2;
    }
    return 1;
}
290 
/// Builds a Capabilities object in which:
///  - every listed core operand type starts with `defaultInfo`;
///  - TENSOR_FLOAT32 and FLOAT32 are then overridden with `float32Info`;
///  - both relaxed-float32 fields use `relaxedInfo`;
///  - IF/WHILE performance use `defaultInfo`.
/// Calls .value() on the table factory's result, so the factory is expected to succeed.
Capabilities makeCapabilities(const Capabilities::PerformanceInfo& defaultInfo,
                              const Capabilities::PerformanceInfo& float32Info,
                              const Capabilities::PerformanceInfo& relaxedInfo) {
    auto perOperand = makeOperandPerformance(defaultInfo);
    for (const OperandType floatType : {OperandType::TENSOR_FLOAT32, OperandType::FLOAT32}) {
        update(&perOperand, floatType, float32Info);
    }
    auto table = Capabilities::OperandPerformanceTable::create(std::move(perOperand)).value();

    return {.relaxedFloat32toFloat16PerformanceScalar = relaxedInfo,
            .relaxedFloat32toFloat16PerformanceTensor = relaxedInfo,
            .operandPerformance = std::move(table),
            .ifPerformance = defaultInfo,
            .whilePerformance = defaultInfo};
}
306 
operator <<(std::ostream & os,const DeviceStatus & deviceStatus)307 std::ostream& operator<<(std::ostream& os, const DeviceStatus& deviceStatus) {
308     switch (deviceStatus) {
309         case DeviceStatus::AVAILABLE:
310             return os << "AVAILABLE";
311         case DeviceStatus::BUSY:
312             return os << "BUSY";
313         case DeviceStatus::OFFLINE:
314             return os << "OFFLINE";
315         case DeviceStatus::UNKNOWN:
316             return os << "UNKNOWN";
317     }
318     return os << "DeviceStatus{" << underlyingType(deviceStatus) << "}";
319 }
320 
operator <<(std::ostream & os,const ExecutionPreference & executionPreference)321 std::ostream& operator<<(std::ostream& os, const ExecutionPreference& executionPreference) {
322     switch (executionPreference) {
323         case ExecutionPreference::LOW_POWER:
324             return os << "LOW_POWER";
325         case ExecutionPreference::FAST_SINGLE_ANSWER:
326             return os << "FAST_SINGLE_ANSWER";
327         case ExecutionPreference::SUSTAINED_SPEED:
328             return os << "SUSTAINED_SPEED";
329     }
330     return os << "ExecutionPreference{" << underlyingType(executionPreference) << "}";
331 }
332 
operator <<(std::ostream & os,const DeviceType & deviceType)333 std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType) {
334     switch (deviceType) {
335         case DeviceType::UNKNOWN:
336             return os << "UNKNOWN";
337         case DeviceType::OTHER:
338             return os << "OTHER";
339         case DeviceType::CPU:
340             return os << "CPU";
341         case DeviceType::GPU:
342             return os << "GPU";
343         case DeviceType::ACCELERATOR:
344             return os << "ACCELERATOR";
345     }
346     return os << "DeviceType{" << underlyingType(deviceType) << "}";
347 }
348 
operator <<(std::ostream & os,const MeasureTiming & measureTiming)349 std::ostream& operator<<(std::ostream& os, const MeasureTiming& measureTiming) {
350     switch (measureTiming) {
351         case MeasureTiming::NO:
352             return os << "NO";
353         case MeasureTiming::YES:
354             return os << "YES";
355     }
356     return os << "MeasureTiming{" << underlyingType(measureTiming) << "}";
357 }
358 
operator <<(std::ostream & os,const OperandType & operandType)359 std::ostream& operator<<(std::ostream& os, const OperandType& operandType) {
360     switch (operandType) {
361         case OperandType::FLOAT32:
362             return os << "FLOAT32";
363         case OperandType::INT32:
364             return os << "INT32";
365         case OperandType::UINT32:
366             return os << "UINT32";
367         case OperandType::TENSOR_FLOAT32:
368             return os << "TENSOR_FLOAT32";
369         case OperandType::TENSOR_INT32:
370             return os << "TENSOR_INT32";
371         case OperandType::TENSOR_QUANT8_ASYMM:
372             return os << "TENSOR_QUANT8_ASYMM";
373         case OperandType::BOOL:
374             return os << "BOOL";
375         case OperandType::TENSOR_QUANT16_SYMM:
376             return os << "TENSOR_QUANT16_SYMM";
377         case OperandType::TENSOR_FLOAT16:
378             return os << "TENSOR_FLOAT16";
379         case OperandType::TENSOR_BOOL8:
380             return os << "TENSOR_BOOL8";
381         case OperandType::FLOAT16:
382             return os << "FLOAT16";
383         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
384             return os << "TENSOR_QUANT8_SYMM_PER_CHANNEL";
385         case OperandType::TENSOR_QUANT16_ASYMM:
386             return os << "TENSOR_QUANT16_ASYMM";
387         case OperandType::TENSOR_QUANT8_SYMM:
388             return os << "TENSOR_QUANT8_SYMM";
389         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
390             return os << "TENSOR_QUANT8_ASYMM_SIGNED";
391         case OperandType::SUBGRAPH:
392             return os << "SUBGRAPH";
393         case OperandType::OEM:
394             return os << "OEM";
395         case OperandType::TENSOR_OEM_BYTE:
396             return os << "TENSOR_OEM_BYTE";
397     }
398     if (isExtension(operandType)) {
399         return os << "Extension OperandType " << underlyingType(operandType);
400     }
401     return os << "OperandType{" << underlyingType(operandType) << "}";
402 }
403 
operator <<(std::ostream & os,const Operand::LifeTime & lifetime)404 std::ostream& operator<<(std::ostream& os, const Operand::LifeTime& lifetime) {
405     switch (lifetime) {
406         case Operand::LifeTime::TEMPORARY_VARIABLE:
407             return os << "TEMPORARY_VARIABLE";
408         case Operand::LifeTime::SUBGRAPH_INPUT:
409             return os << "SUBGRAPH_INPUT";
410         case Operand::LifeTime::SUBGRAPH_OUTPUT:
411             return os << "SUBGRAPH_OUTPUT";
412         case Operand::LifeTime::CONSTANT_COPY:
413             return os << "CONSTANT_COPY";
414         case Operand::LifeTime::CONSTANT_REFERENCE:
415             return os << "CONSTANT_REFERENCE";
416         case Operand::LifeTime::NO_VALUE:
417             return os << "NO_VALUE";
418         case Operand::LifeTime::SUBGRAPH:
419             return os << "SUBGRAPH";
420         case Operand::LifeTime::POINTER:
421             return os << "POINTER";
422     }
423     return os << "Operand::LifeTime{" << underlyingType(lifetime) << "}";
424 }
425 
operator <<(std::ostream & os,const OperationType & operationType)426 std::ostream& operator<<(std::ostream& os, const OperationType& operationType) {
427 #define NN_HANDLE_SWITCH_CASE(opType) \
428     case OperationType::opType:       \
429         return os << #opType;
430     switch (operationType) { NN_FOR_EACH_OPERATION(NN_HANDLE_SWITCH_CASE) }
431 #undef NN_HANDLE_SWITCH_CASE
432 
433     if (isExtension(operationType)) {
434         return os << "Extension OperationType " << underlyingType(operationType);
435     }
436     return os << "OperationType{" << underlyingType(operationType) << "}";
437 }
438 
operator <<(std::ostream & os,const Request::Argument::LifeTime & lifetime)439 std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime) {
440     switch (lifetime) {
441         case Request::Argument::LifeTime::POOL:
442             return os << "POOL";
443         case Request::Argument::LifeTime::NO_VALUE:
444             return os << "NO_VALUE";
445         case Request::Argument::LifeTime::POINTER:
446             return os << "POINTER";
447     }
448     return os << "Request::Argument::LifeTime{" << underlyingType(lifetime) << "}";
449 }
450 
operator <<(std::ostream & os,const Priority & priority)451 std::ostream& operator<<(std::ostream& os, const Priority& priority) {
452     switch (priority) {
453         case Priority::LOW:
454             return os << "LOW";
455         case Priority::MEDIUM:
456             return os << "MEDIUM";
457         case Priority::HIGH:
458             return os << "HIGH";
459     }
460     return os << "Priority{" << underlyingType(priority) << "}";
461 }
462 
operator <<(std::ostream & os,const ErrorStatus & errorStatus)463 std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus) {
464     switch (errorStatus) {
465         case ErrorStatus::NONE:
466             return os << "NONE";
467         case ErrorStatus::DEVICE_UNAVAILABLE:
468             return os << "DEVICE_UNAVAILABLE";
469         case ErrorStatus::GENERAL_FAILURE:
470             return os << "GENERAL_FAILURE";
471         case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
472             return os << "OUTPUT_INSUFFICIENT_SIZE";
473         case ErrorStatus::INVALID_ARGUMENT:
474             return os << "INVALID_ARGUMENT";
475         case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
476             return os << "MISSED_DEADLINE_TRANSIENT";
477         case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
478             return os << "MISSED_DEADLINE_PERSISTENT";
479         case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
480             return os << "RESOURCE_EXHAUSTED_TRANSIENT";
481         case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
482             return os << "RESOURCE_EXHAUSTED_PERSISTENT";
483         case ErrorStatus::DEAD_OBJECT:
484             return os << "DEAD_OBJECT";
485     }
486     return os << "ErrorStatus{" << underlyingType(errorStatus) << "}";
487 }
488 
operator <<(std::ostream & os,const FusedActivationFunc & activation)489 std::ostream& operator<<(std::ostream& os, const FusedActivationFunc& activation) {
490     switch (activation) {
491         case FusedActivationFunc::NONE:
492             return os << "NONE";
493         case FusedActivationFunc::RELU:
494             return os << "RELU";
495         case FusedActivationFunc::RELU1:
496             return os << "RELU1";
497         case FusedActivationFunc::RELU6:
498             return os << "RELU6";
499     }
500     return os << "FusedActivationFunc{" << underlyingType(activation) << "}";
501 }
502 
operator <<(std::ostream & os,const OutputShape & outputShape)503 std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape) {
504     return os << "OutputShape{.dimensions=" << outputShape.dimensions
505               << ", .isSufficient=" << (outputShape.isSufficient ? "true" : "false") << "}";
506 }
507 
operator <<(std::ostream & os,const Timing & timing)508 std::ostream& operator<<(std::ostream& os, const Timing& timing) {
509     return os << "Timing{.timeOnDevice=" << timing.timeOnDevice
510               << ", .timeInDriver=" << timing.timeInDriver << "}";
511 }
512 
operator <<(std::ostream & os,const Capabilities::PerformanceInfo & performanceInfo)513 std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo) {
514     return os << "Capabilities::PerformanceInfo{.execTime=" << performanceInfo.execTime
515               << ", .powerUsage=" << performanceInfo.powerUsage << "}";
516 }
517 
operator <<(std::ostream & os,const Capabilities::OperandPerformance & operandPerformance)518 std::ostream& operator<<(std::ostream& os,
519                          const Capabilities::OperandPerformance& operandPerformance) {
520     return os << "Capabilities::OperandPerformance{.type=" << operandPerformance.type
521               << ", .info=" << operandPerformance.info << "}";
522 }
523 
operator <<(std::ostream & os,const Capabilities::OperandPerformanceTable & operandPerformances)524 std::ostream& operator<<(std::ostream& os,
525                          const Capabilities::OperandPerformanceTable& operandPerformances) {
526     return os << operandPerformances.asVector();
527 }
528 
operator <<(std::ostream & os,const Capabilities & capabilities)529 std::ostream& operator<<(std::ostream& os, const Capabilities& capabilities) {
530     return os << "Capabilities{.relaxedFloat32toFloat16PerformanceScalar="
531               << capabilities.relaxedFloat32toFloat16PerformanceScalar
532               << ", .relaxedFloat32toFloat16PerformanceTensor="
533               << capabilities.relaxedFloat32toFloat16PerformanceTensor
534               << ", .operandPerformance=" << capabilities.operandPerformance
535               << ", .ifPerformance=" << capabilities.ifPerformance
536               << ", .whilePerformance=" << capabilities.whilePerformance << "}";
537 }
538 
operator <<(std::ostream & os,const Extension::OperandTypeInformation & operandTypeInformation)539 std::ostream& operator<<(std::ostream& os,
540                          const Extension::OperandTypeInformation& operandTypeInformation) {
541     return os << "Extension::OperandTypeInformation{.type=" << operandTypeInformation.type
542               << ", .isTensor=" << (operandTypeInformation.isTensor ? "true" : "false")
543               << ", .byteSize=" << operandTypeInformation.byteSize << "}";
544 }
545 
operator <<(std::ostream & os,const Extension & extension)546 std::ostream& operator<<(std::ostream& os, const Extension& extension) {
547     return os << "Extension{.name=" << extension.name
548               << ", .operandTypes=" << extension.operandTypes << "}";
549 }
550 
operator <<(std::ostream & os,const DataLocation & location)551 std::ostream& operator<<(std::ostream& os, const DataLocation& location) {
552     const auto printPointer = [&os](const std::variant<const void*, void*>& pointer) {
553         os << (std::holds_alternative<const void*>(pointer) ? "<constant " : "<mutable ");
554         os << std::visit(
555                 [](const auto* ptr) {
556                     return ptr == nullptr ? "null pointer>" : "non-null pointer>";
557                 },
558                 pointer);
559     };
560     os << "DataLocation{.pointer=";
561     printPointer(location.pointer);
562     return os << ", .poolIndex=" << location.poolIndex << ", .offset=" << location.offset
563               << ", .length=" << location.length << ", .padding=" << location.padding << "}";
564 }
565 
operator <<(std::ostream & os,const Operand::SymmPerChannelQuantParams & symmPerChannelQuantParams)566 std::ostream& operator<<(std::ostream& os,
567                          const Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
568     return os << "Operand::SymmPerChannelQuantParams{.scales=" << symmPerChannelQuantParams.scales
569               << ", .channelDim=" << symmPerChannelQuantParams.channelDim << "}";
570 }
571 
operator <<(std::ostream & os,const Operand::ExtraParams & extraParams)572 std::ostream& operator<<(std::ostream& os, const Operand::ExtraParams& extraParams) {
573     os << "Operand::ExtraParams{";
574     if (std::holds_alternative<Operand::NoParams>(extraParams)) {
575         os << "<no params>";
576     } else if (std::holds_alternative<Operand::SymmPerChannelQuantParams>(extraParams)) {
577         os << std::get<Operand::SymmPerChannelQuantParams>(extraParams);
578     } else if (std::holds_alternative<Operand::ExtensionParams>(extraParams)) {
579         os << std::get<Operand::ExtensionParams>(extraParams);
580     }
581     return os << "}";
582 }
583 
operator <<(std::ostream & os,const Operand & operand)584 std::ostream& operator<<(std::ostream& os, const Operand& operand) {
585     return os << "Operand{.type=" << operand.type << ", .dimensions=" << operand.dimensions
586               << ", .scale=" << operand.scale << ", .zeroPoint=" << operand.zeroPoint
587               << ", lifetime=" << operand.lifetime << ", .location=" << operand.location
588               << ", .extraParams=" << operand.extraParams << "}";
589 }
590 
operator <<(std::ostream & os,const Operation & operation)591 std::ostream& operator<<(std::ostream& os, const Operation& operation) {
592     return os << "Operation{.type=" << operation.type << ", .inputs=" << operation.inputs
593               << ", .outputs=" << operation.outputs << "}";
594 }
595 
operator <<(std::ostream & os,const Handle & handle)596 static std::ostream& operator<<(std::ostream& os, const Handle& handle) {
597     return os << (handle.ok() ? "<valid handle>" : "<invalid handle>");
598 }
599 
operator <<(std::ostream & os,const SharedHandle & handle)600 std::ostream& operator<<(std::ostream& os, const SharedHandle& handle) {
601     if (handle == nullptr) {
602         return os << "<empty handle>";
603     }
604     return os << *handle;
605 }
606 
operator <<(std::ostream & os,const Memory::Ashmem & memory)607 static std::ostream& operator<<(std::ostream& os, const Memory::Ashmem& memory) {
608     return os << "Ashmem{.fd=" << (memory.fd.ok() ? "<valid fd>" : "<invalid fd>")
609               << ", .size=" << memory.size << "}";
610 }
611 
operator <<(std::ostream & os,const Memory::Fd & memory)612 static std::ostream& operator<<(std::ostream& os, const Memory::Fd& memory) {
613     return os << "Fd{.size=" << memory.size << ", .prot=" << memory.prot
614               << ", .fd=" << (memory.fd.ok() ? "<valid fd>" : "<invalid fd>")
615               << ", .offset=" << memory.offset << "}";
616 }
617 
operator <<(std::ostream & os,const Memory::HardwareBuffer & memory)618 static std::ostream& operator<<(std::ostream& os, const Memory::HardwareBuffer& memory) {
619     if (memory.handle.get() == nullptr) {
620         return os << "<empty HardwareBuffer::Handle>";
621     }
622     return os << (isAhwbBlob(memory) ? "<AHardwareBuffer blob>" : "<non-blob AHardwareBuffer>");
623 }
624 
operator <<(std::ostream & os,const Memory::Unknown::Handle & handle)625 static std::ostream& operator<<(std::ostream& os, const Memory::Unknown::Handle& handle) {
626     return os << "<handle with " << handle.fds.size() << " fds and " << handle.ints.size()
627               << " ints>";
628 }
629 
operator <<(std::ostream & os,const Memory::Unknown & memory)630 static std::ostream& operator<<(std::ostream& os, const Memory::Unknown& memory) {
631     return os << "Unknown{.handle=" << memory.handle << ", .size=" << memory.size
632               << ", .name=" << memory.name << "}";
633 }
634 
operator <<(std::ostream & os,const Memory & memory)635 std::ostream& operator<<(std::ostream& os, const Memory& memory) {
636     os << "Memory{.handle=";
637     std::visit([&os](const auto& x) { os << x; }, memory.handle);
638     return os << "}";
639 }
640 
operator <<(std::ostream & os,const SharedMemory & memory)641 std::ostream& operator<<(std::ostream& os, const SharedMemory& memory) {
642     if (memory == nullptr) {
643         return os << "<empty memory>";
644     }
645     return os << *memory;
646 }
647 
operator <<(std::ostream & os,const MemoryPreference & memoryPreference)648 std::ostream& operator<<(std::ostream& os, const MemoryPreference& memoryPreference) {
649     return os << "MemoryPreference{.alignment=" << memoryPreference.alignment
650               << ", .padding=" << memoryPreference.padding << "}";
651 }
652 
operator <<(std::ostream & os,const Model::Subgraph & subgraph)653 std::ostream& operator<<(std::ostream& os, const Model::Subgraph& subgraph) {
654     std::vector<Operand> operands;
655     std::vector<Operation> operations;
656     std::vector<uint32_t> inputIndexes;
657     std::vector<uint32_t> outputIndexes;
658     return os << "Model::Subgraph{.operands=" << subgraph.operands
659               << ", .operations=" << subgraph.operations
660               << ", .inputIndexes=" << subgraph.inputIndexes
661               << ", .outputIndexes=" << subgraph.outputIndexes << "}";
662 }
663 
operator <<(std::ostream & os,const Model::OperandValues & operandValues)664 std::ostream& operator<<(std::ostream& os, const Model::OperandValues& operandValues) {
665     return os << "Model::OperandValues{<" << operandValues.size() << "bytes>}";
666 }
667 
operator <<(std::ostream & os,const ExtensionNameAndPrefix & extensionNameAndPrefix)668 std::ostream& operator<<(std::ostream& os, const ExtensionNameAndPrefix& extensionNameAndPrefix) {
669     return os << "ExtensionNameAndPrefix{.name=" << extensionNameAndPrefix.name
670               << ", .prefix=" << extensionNameAndPrefix.prefix << "}";
671 }
672 
operator <<(std::ostream & os,const Model & model)673 std::ostream& operator<<(std::ostream& os, const Model& model) {
674     return os << "Model{.main=" << model.main << ", .referenced=" << model.referenced
675               << ", .operandValues=" << model.operandValues << ", .pools=" << model.pools
676               << ", .relaxComputationFloat32toFloat16="
677               << (model.relaxComputationFloat32toFloat16 ? "true" : "false")
678               << ", extensionNameToPrefix=" << model.extensionNameToPrefix << "}";
679 }
680 
operator <<(std::ostream & os,const BufferDesc & bufferDesc)681 std::ostream& operator<<(std::ostream& os, const BufferDesc& bufferDesc) {
682     return os << "BufferDesc{.dimensions=" << bufferDesc.dimensions << "}";
683 }
684 
operator <<(std::ostream & os,const BufferRole & bufferRole)685 std::ostream& operator<<(std::ostream& os, const BufferRole& bufferRole) {
686     return os << "BufferRole{.modelIndex=" << bufferRole.modelIndex
687               << ", .ioIndex=" << bufferRole.ioIndex << ", .probability=" << bufferRole.probability
688               << "}";
689 }
690 
operator <<(std::ostream & os,const Request::Argument & requestArgument)691 std::ostream& operator<<(std::ostream& os, const Request::Argument& requestArgument) {
692     return os << "Request::Argument{.lifetime=" << requestArgument.lifetime
693               << ", .location=" << requestArgument.location
694               << ", .dimensions=" << requestArgument.dimensions << "}";
695 }
696 
operator <<(std::ostream & os,const Request::MemoryPool & memoryPool)697 std::ostream& operator<<(std::ostream& os, const Request::MemoryPool& memoryPool) {
698     os << "Request::MemoryPool{";
699     if (std::holds_alternative<SharedMemory>(memoryPool)) {
700         os << std::get<SharedMemory>(memoryPool);
701     } else if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) {
702         const auto& token = std::get<Request::MemoryDomainToken>(memoryPool);
703         if (token == Request::MemoryDomainToken{}) {
704             os << "<invalid MemoryDomainToken>";
705         } else {
706             os << "MemoryDomainToken=" << underlyingType(token);
707         }
708     } else if (std::holds_alternative<SharedBuffer>(memoryPool)) {
709         const auto& buffer = std::get<SharedBuffer>(memoryPool);
710         os << (buffer != nullptr ? "<non-null IBuffer>" : "<null IBuffer>");
711     }
712     return os << "}";
713 }
714 
operator <<(std::ostream & os,const Request & request)715 std::ostream& operator<<(std::ostream& os, const Request& request) {
716     return os << "Request{.inputs=" << request.inputs << ", .outputs=" << request.outputs
717               << ", .pools=" << request.pools << "}";
718 }
719 
operator <<(std::ostream & os,const SyncFence::FenceState & fenceState)720 std::ostream& operator<<(std::ostream& os, const SyncFence::FenceState& fenceState) {
721     switch (fenceState) {
722         case SyncFence::FenceState::ACTIVE:
723             return os << "ACTIVE";
724         case SyncFence::FenceState::SIGNALED:
725             return os << "SIGNALED";
726         case SyncFence::FenceState::ERROR:
727             return os << "ERROR";
728         case SyncFence::FenceState::UNKNOWN:
729             return os << "UNKNOWN";
730     }
731     return os << "SyncFence::FenceState{" << underlyingType(fenceState) << "}";
732 }
733 
operator <<(std::ostream & os,const TimePoint & timePoint)734 std::ostream& operator<<(std::ostream& os, const TimePoint& timePoint) {
735     return os << timePoint.time_since_epoch() << " since epoch";
736 }
737 
operator <<(std::ostream & os,const OptionalTimePoint & optionalTimePoint)738 std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTimePoint) {
739     if (!optionalTimePoint.has_value()) {
740         return os << "<no time point>";
741     }
742     return os << optionalTimePoint.value();
743 }
744 
operator <<(std::ostream & os,const Duration & timeoutDuration)745 std::ostream& operator<<(std::ostream& os, const Duration& timeoutDuration) {
746     return os << timeoutDuration.count() << "ns";
747 }
748 
operator <<(std::ostream & os,const OptionalDuration & optionalTimeoutDuration)749 std::ostream& operator<<(std::ostream& os, const OptionalDuration& optionalTimeoutDuration) {
750     if (!optionalTimeoutDuration.has_value()) {
751         return os << "<no duration>";
752     }
753     return os << optionalTimeoutDuration.value();
754 }
755 
operator <<(std::ostream & os,const Version::Level & versionLevel)756 std::ostream& operator<<(std::ostream& os, const Version::Level& versionLevel) {
757     switch (versionLevel) {
758         case Version::Level::FEATURE_LEVEL_1:
759             return os << "FEATURE_LEVEL_1";
760         case Version::Level::FEATURE_LEVEL_2:
761             return os << "FEATURE_LEVEL_2";
762         case Version::Level::FEATURE_LEVEL_3:
763             return os << "FEATURE_LEVEL_3";
764         case Version::Level::FEATURE_LEVEL_4:
765             return os << "FEATURE_LEVEL_4";
766         case Version::Level::FEATURE_LEVEL_5:
767             return os << "FEATURE_LEVEL_5";
768         case Version::Level::FEATURE_LEVEL_6:
769             return os << "FEATURE_LEVEL_6";
770         case Version::Level::FEATURE_LEVEL_7:
771             return os << "FEATURE_LEVEL_7";
772         case Version::Level::FEATURE_LEVEL_8:
773             return os << "FEATURE_LEVEL_8";
774 #ifdef NN_EXPERIMENTAL_FEATURE
775         case Version::Level::FEATURE_LEVEL_EXPERIMENTAL:
776             return os << "FEATURE_LEVEL_EXPERIMENTAL";
777 #endif  // NN_EXPERIMENTAL_FEATURE
778     }
779     return os << "Version{" << static_cast<uint32_t>(underlyingType(versionLevel)) << "}";
780 }
781 
operator <<(std::ostream & os,const Version & version)782 std::ostream& operator<<(std::ostream& os, const Version& version) {
783     os << version.level;
784     if (version.runtimeOnlyFeatures) {
785         os << " (with runtime-specific features)";
786     }
787     return os;
788 }
789 
operator ==(const Timing & a,const Timing & b)790 bool operator==(const Timing& a, const Timing& b) {
791     return a.timeOnDevice == b.timeOnDevice && a.timeInDriver == b.timeInDriver;
792 }
operator !=(const Timing & a,const Timing & b)793 bool operator!=(const Timing& a, const Timing& b) {
794     return !(a == b);
795 }
796 
operator ==(const Capabilities::PerformanceInfo & a,const Capabilities::PerformanceInfo & b)797 bool operator==(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b) {
798     return a.execTime == b.execTime && a.powerUsage == b.powerUsage;
799 }
operator !=(const Capabilities::PerformanceInfo & a,const Capabilities::PerformanceInfo & b)800 bool operator!=(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b) {
801     return !(a == b);
802 }
803 
operator ==(const Capabilities::OperandPerformance & a,const Capabilities::OperandPerformance & b)804 bool operator==(const Capabilities::OperandPerformance& a,
805                 const Capabilities::OperandPerformance& b) {
806     return a.type == b.type && a.info == b.info;
807 }
operator !=(const Capabilities::OperandPerformance & a,const Capabilities::OperandPerformance & b)808 bool operator!=(const Capabilities::OperandPerformance& a,
809                 const Capabilities::OperandPerformance& b) {
810     return !(a == b);
811 }
812 
operator ==(const Capabilities & a,const Capabilities & b)813 bool operator==(const Capabilities& a, const Capabilities& b) {
814     return a.relaxedFloat32toFloat16PerformanceScalar ==
815                    b.relaxedFloat32toFloat16PerformanceScalar &&
816            a.relaxedFloat32toFloat16PerformanceTensor ==
817                    b.relaxedFloat32toFloat16PerformanceTensor &&
818            a.operandPerformance.asVector() == b.operandPerformance.asVector() &&
819            a.ifPerformance == b.ifPerformance && a.whilePerformance == b.whilePerformance;
820 }
operator !=(const Capabilities & a,const Capabilities & b)821 bool operator!=(const Capabilities& a, const Capabilities& b) {
822     return !(a == b);
823 }
824 
operator ==(const Extension::OperandTypeInformation & a,const Extension::OperandTypeInformation & b)825 bool operator==(const Extension::OperandTypeInformation& a,
826                 const Extension::OperandTypeInformation& b) {
827     return a.type == b.type && a.isTensor == b.isTensor && a.byteSize == b.byteSize;
828 }
operator !=(const Extension::OperandTypeInformation & a,const Extension::OperandTypeInformation & b)829 bool operator!=(const Extension::OperandTypeInformation& a,
830                 const Extension::OperandTypeInformation& b) {
831     return !(a == b);
832 }
833 
operator ==(const Extension & a,const Extension & b)834 bool operator==(const Extension& a, const Extension& b) {
835     return a.name == b.name && a.operandTypes == b.operandTypes;
836 }
operator !=(const Extension & a,const Extension & b)837 bool operator!=(const Extension& a, const Extension& b) {
838     return !(a == b);
839 }
840 
operator ==(const MemoryPreference & a,const MemoryPreference & b)841 bool operator==(const MemoryPreference& a, const MemoryPreference& b) {
842     return a.alignment == b.alignment && a.padding == b.padding;
843 }
operator !=(const MemoryPreference & a,const MemoryPreference & b)844 bool operator!=(const MemoryPreference& a, const MemoryPreference& b) {
845     return !(a == b);
846 }
847 
operator ==(const Operand::SymmPerChannelQuantParams & a,const Operand::SymmPerChannelQuantParams & b)848 bool operator==(const Operand::SymmPerChannelQuantParams& a,
849                 const Operand::SymmPerChannelQuantParams& b) {
850     return a.scales == b.scales && a.channelDim == b.channelDim;
851 }
operator !=(const Operand::SymmPerChannelQuantParams & a,const Operand::SymmPerChannelQuantParams & b)852 bool operator!=(const Operand::SymmPerChannelQuantParams& a,
853                 const Operand::SymmPerChannelQuantParams& b) {
854     return !(a == b);
855 }
856 
operator ==(const DataLocation & a,const DataLocation & b)857 static bool operator==(const DataLocation& a, const DataLocation& b) {
858     constexpr auto toTuple = [](const DataLocation& location) {
859         return std::tie(location.pointer, location.poolIndex, location.offset, location.length,
860                         location.padding);
861     };
862     return toTuple(a) == toTuple(b);
863 }
864 
operator ==(const Operand & a,const Operand & b)865 bool operator==(const Operand& a, const Operand& b) {
866     constexpr auto toTuple = [](const Operand& operand) {
867         return std::tie(operand.type, operand.dimensions, operand.scale, operand.zeroPoint,
868                         operand.lifetime, operand.location, operand.extraParams);
869     };
870     return toTuple(a) == toTuple(b);
871 }
operator !=(const Operand & a,const Operand & b)872 bool operator!=(const Operand& a, const Operand& b) {
873     return !(a == b);
874 }
875 
operator ==(const Operation & a,const Operation & b)876 bool operator==(const Operation& a, const Operation& b) {
877     constexpr auto toTuple = [](const Operation& operation) {
878         return std::tie(operation.type, operation.inputs, operation.outputs);
879     };
880     return toTuple(a) == toTuple(b);
881 }
operator !=(const Operation & a,const Operation & b)882 bool operator!=(const Operation& a, const Operation& b) {
883     return !(a == b);
884 }
885 
operator ==(const Version & a,const Version & b)886 bool operator==(const Version& a, const Version& b) {
887     return a.level == b.level && a.runtimeOnlyFeatures == b.runtimeOnlyFeatures;
888 }
operator !=(const Version & a,const Version & b)889 bool operator!=(const Version& a, const Version& b) {
890     return !(a == b);
891 }
892 
893 const char kVLogPropKey[] = "debug.nn.vlog";
894 int vLogMask = ~0;
895 
896 // Split the space separated list of tags from verbose log setting and build the
897 // logging mask from it. note that '1' and 'all' are special cases to enable all
898 // verbose logging.
899 //
900 // NN API verbose logging setting comes from system property debug.nn.vlog.
901 // Example:
902 // setprop debug.nn.vlog 1 : enable all logging tags.
903 // setprop debug.nn.vlog "model compilation" : only enable logging for MODEL and
904 //                                             COMPILATION tags.
initVLogMask()905 void initVLogMask() {
906     vLogMask = 0;
907     const std::string vLogSetting = android::base::GetProperty(kVLogPropKey, "");
908     if (vLogSetting.empty()) {
909         return;
910     }
911 
912     std::unordered_map<std::string, int> vLogFlags = {{"1", -1},
913                                                       {"all", -1},
914                                                       {"model", MODEL},
915                                                       {"compilation", COMPILATION},
916                                                       {"execution", EXECUTION},
917                                                       {"cpuexe", CPUEXE},
918                                                       {"manager", MANAGER},
919                                                       {"driver", DRIVER},
920                                                       {"memory", MEMORY}};
921 
922     std::vector<std::string> elements = android::base::Split(vLogSetting, " ,:");
923     for (const auto& elem : elements) {
924         const auto& flag = vLogFlags.find(elem);
925         if (flag == vLogFlags.end()) {
926             LOG(ERROR) << "Unknown trace flag: " << elem;
927             continue;
928         }
929 
930         if (flag->second == -1) {
931             // -1 is used for the special values "1" and "all" that enable all
932             // tracing.
933             vLogMask = ~0;
934             return;
935         } else {
936             vLogMask |= 1 << flag->second;
937         }
938     }
939 }
940 
941 }  // namespace android::nn
942