1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Conversions.h"
18 
19 #include <aidl/android/hardware/common/Ashmem.h>
20 #include <aidl/android/hardware/common/MappableFile.h>
21 #include <aidl/android/hardware/common/NativeHandle.h>
22 #include <aidl/android/hardware/graphics/common/HardwareBuffer.h>
23 #include <aidlcommonsupport/NativeHandle.h>
24 #include <android-base/logging.h>
25 #include <android-base/mapped_file.h>
26 #include <android-base/unique_fd.h>
27 #include <android/binder_auto_utils.h>
28 #include <cutils/native_handle.h>
29 #include <nnapi/OperandTypes.h>
30 #include <nnapi/OperationTypes.h>
31 #include <nnapi/Result.h>
32 #include <nnapi/SharedMemory.h>
33 #include <nnapi/TypeUtils.h>
34 #include <nnapi/Types.h>
35 #include <nnapi/Validation.h>
36 #include <nnapi/hal/CommonUtils.h>
37 
38 #include <algorithm>
39 #include <chrono>
40 #include <functional>
41 #include <iterator>
42 #include <limits>
43 #include <type_traits>
44 #include <utility>
45 
46 #include "Utils.h"
47 
48 #ifdef __ANDROID__
49 #include <android/hardware_buffer.h>
50 #include <vndk/hardware_buffer.h>
51 #endif  // __ANDROID__
52 
// Returns NN_ERROR() from the enclosing function when `value` is negative. The while/return shape
// lets call sites stream a message into the error:
//     VERIFY_NON_NEGATIVE(x) << "x must not be negative";
// `value` is parenthesized so that a compound argument (e.g. `cond ? a : b`) is compared as a
// whole, instead of `< 0` binding only to its last operand.
#define VERIFY_NON_NEGATIVE(value) \
    while (UNLIKELY((value) < 0)) return NN_ERROR()

// Returns NN_ERROR() from the enclosing function when `value` exceeds INT32_MAX. Supports
// streaming a message the same way as VERIFY_NON_NEGATIVE; the argument is parenthesized for
// the same reason.
#define VERIFY_LE_INT32_MAX(value) \
    while (UNLIKELY((value) > std::numeric_limits<int32_t>::max())) return NN_ERROR()
58 
namespace {

// Sentinel value of the AIDL Timing fields (timeOnDeviceNs / timeInDriverNs). When a field holds
// this value, it is converted to an empty OptionalDuration.
constexpr int64_t kNoTiming{-1};

}  // namespace
64 
65 namespace android::nn {
66 namespace {
67 
68 using ::aidl::android::hardware::common::NativeHandle;
69 using ::aidl::android::hardware::neuralnetworks::utils::underlyingType;
70 
71 template <typename Input>
72 using UnvalidatedConvertOutput =
73         std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
74 
75 template <typename Type>
unvalidatedConvertVec(const std::vector<Type> & arguments)76 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
77         const std::vector<Type>& arguments) {
78     std::vector<UnvalidatedConvertOutput<Type>> canonical;
79     canonical.reserve(arguments.size());
80     for (const auto& argument : arguments) {
81         canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
82     }
83     return canonical;
84 }
85 
86 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)87 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
88         const std::vector<Type>& arguments) {
89     return unvalidatedConvertVec(arguments);
90 }
91 
92 template <typename Type>
validatedConvert(const Type & halObject)93 GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
94     auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
95     NN_TRY(aidl_hal::utils::compliantVersion(canonical));
96     return canonical;
97 }
98 
99 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)100 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
101         const std::vector<Type>& arguments) {
102     std::vector<UnvalidatedConvertOutput<Type>> canonical;
103     canonical.reserve(arguments.size());
104     for (const auto& argument : arguments) {
105         canonical.push_back(NN_TRY(validatedConvert(argument)));
106     }
107     return canonical;
108 }
109 
110 struct NativeHandleDeleter {
operator ()android::nn::__anon497330ac0211::NativeHandleDeleter111     void operator()(native_handle_t* handle) const {
112         if (handle) {
113             native_handle_close(handle);
114             native_handle_delete(handle);
115         }
116     }
117 };
118 
119 using UniqueNativeHandle = std::unique_ptr<native_handle_t, NativeHandleDeleter>;
120 
#ifdef __ANDROID__
// Duplicates `handle` into a new native_handle_t via android::dupFromAidl. Fails if the
// duplication returns null or if any of the resulting fds is negative. The returned handle is
// owned and is closed/freed by NativeHandleDeleter.
GeneralResult<UniqueNativeHandle> nativeHandleFromAidlHandle(const NativeHandle& handle) {
    UniqueNativeHandle duplicated(dupFromAidl(handle));
    if (!duplicated) {
        return NN_ERROR() << "android::dupFromAidl failed to convert the common::NativeHandle to a "
                             "native_handle_t";
    }
    const int* const fdsBegin = duplicated->data;
    const int* const fdsEnd = fdsBegin + duplicated->numFds;
    const bool hasInvalidFd = std::any_of(fdsBegin, fdsEnd, [](int fd) { return fd < 0; });
    if (hasInvalidFd) {
        return NN_ERROR() << "android::dupFromAidl returned an invalid native_handle_t";
    }
    return duplicated;
}
#endif  // __ANDROID__
135 
136 }  // anonymous namespace
137 
unvalidatedConvert(const aidl_hal::OperandType & operandType)138 GeneralResult<OperandType> unvalidatedConvert(const aidl_hal::OperandType& operandType) {
139     VERIFY_NON_NEGATIVE(underlyingType(operandType)) << "Negative operand types are not allowed.";
140     const auto canonical = static_cast<OperandType>(operandType);
141     if (canonical == OperandType::OEM || canonical == OperandType::TENSOR_OEM_BYTE) {
142         return NN_ERROR() << "Unable to convert invalid OperandType " << canonical;
143     }
144     return canonical;
145 }
146 
unvalidatedConvert(const aidl_hal::OperationType & operationType)147 GeneralResult<OperationType> unvalidatedConvert(const aidl_hal::OperationType& operationType) {
148     VERIFY_NON_NEGATIVE(underlyingType(operationType))
149             << "Negative operation types are not allowed.";
150     const auto canonical = static_cast<OperationType>(operationType);
151     if (canonical == OperationType::OEM_OPERATION) {
152         return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION";
153     }
154     return canonical;
155 }
156 
unvalidatedConvert(const aidl_hal::DeviceType & deviceType)157 GeneralResult<DeviceType> unvalidatedConvert(const aidl_hal::DeviceType& deviceType) {
158     return static_cast<DeviceType>(deviceType);
159 }
160 
unvalidatedConvert(const aidl_hal::Priority & priority)161 GeneralResult<Priority> unvalidatedConvert(const aidl_hal::Priority& priority) {
162     return static_cast<Priority>(priority);
163 }
164 
unvalidatedConvert(const aidl_hal::Capabilities & capabilities)165 GeneralResult<Capabilities> unvalidatedConvert(const aidl_hal::Capabilities& capabilities) {
166     const bool validOperandTypes = std::all_of(
167             capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
168             [](const aidl_hal::OperandPerformance& operandPerformance) {
169                 return validatedConvert(operandPerformance.type).has_value();
170             });
171     if (!validOperandTypes) {
172         return NN_ERROR() << "Invalid OperandType when unvalidatedConverting OperandPerformance in "
173                              "Capabilities";
174     }
175 
176     auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
177     auto table =
178             NN_TRY(Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)));
179 
180     const auto relaxedFloat32toFloat16PerformanceScalar =
181             NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
182     const auto relaxedFloat32toFloat16PerformanceTensor =
183             NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
184     const auto ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance));
185     const auto whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance));
186     return Capabilities{
187             .relaxedFloat32toFloat16PerformanceScalar = relaxedFloat32toFloat16PerformanceScalar,
188             .relaxedFloat32toFloat16PerformanceTensor = relaxedFloat32toFloat16PerformanceTensor,
189             .operandPerformance = std::move(table),
190             .ifPerformance = ifPerformance,
191             .whilePerformance = whilePerformance,
192     };
193 }
194 
unvalidatedConvert(const aidl_hal::OperandPerformance & operandPerformance)195 GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
196         const aidl_hal::OperandPerformance& operandPerformance) {
197     const auto type = NN_TRY(unvalidatedConvert(operandPerformance.type));
198     const auto info = NN_TRY(unvalidatedConvert(operandPerformance.info));
199     return Capabilities::OperandPerformance{
200             .type = type,
201             .info = info,
202     };
203 }
204 
unvalidatedConvert(const aidl_hal::PerformanceInfo & performanceInfo)205 GeneralResult<Capabilities::PerformanceInfo> unvalidatedConvert(
206         const aidl_hal::PerformanceInfo& performanceInfo) {
207     return Capabilities::PerformanceInfo{
208             .execTime = performanceInfo.execTime,
209             .powerUsage = performanceInfo.powerUsage,
210     };
211 }
212 
unvalidatedConvert(const aidl_hal::DataLocation & location)213 GeneralResult<DataLocation> unvalidatedConvert(const aidl_hal::DataLocation& location) {
214     VERIFY_NON_NEGATIVE(location.poolIndex) << "DataLocation: pool index must not be negative";
215     VERIFY_NON_NEGATIVE(location.offset) << "DataLocation: offset must not be negative";
216     VERIFY_NON_NEGATIVE(location.length) << "DataLocation: length must not be negative";
217     VERIFY_NON_NEGATIVE(location.padding) << "DataLocation: padding must not be negative";
218     if (location.offset > std::numeric_limits<uint32_t>::max()) {
219         return NN_ERROR() << "DataLocation: offset must be <= std::numeric_limits<uint32_t>::max()";
220     }
221     if (location.length > std::numeric_limits<uint32_t>::max()) {
222         return NN_ERROR() << "DataLocation: length must be <= std::numeric_limits<uint32_t>::max()";
223     }
224     if (location.padding > std::numeric_limits<uint32_t>::max()) {
225         return NN_ERROR()
226                << "DataLocation: padding must be <= std::numeric_limits<uint32_t>::max()";
227     }
228     return DataLocation{
229             .poolIndex = static_cast<uint32_t>(location.poolIndex),
230             .offset = static_cast<uint32_t>(location.offset),
231             .length = static_cast<uint32_t>(location.length),
232             .padding = static_cast<uint32_t>(location.padding),
233     };
234 }
235 
unvalidatedConvert(const aidl_hal::Operation & operation)236 GeneralResult<Operation> unvalidatedConvert(const aidl_hal::Operation& operation) {
237     const auto type = NN_TRY(unvalidatedConvert(operation.type));
238     auto inputs = NN_TRY(toUnsigned(operation.inputs));
239     auto outputs = NN_TRY(toUnsigned(operation.outputs));
240     return Operation{
241             .type = type,
242             .inputs = std::move(inputs),
243             .outputs = std::move(outputs),
244     };
245 }
246 
unvalidatedConvert(const aidl_hal::OperandLifeTime & operandLifeTime)247 GeneralResult<Operand::LifeTime> unvalidatedConvert(
248         const aidl_hal::OperandLifeTime& operandLifeTime) {
249     return static_cast<Operand::LifeTime>(operandLifeTime);
250 }
251 
unvalidatedConvert(const aidl_hal::Operand & operand)252 GeneralResult<Operand> unvalidatedConvert(const aidl_hal::Operand& operand) {
253     const auto type = NN_TRY(unvalidatedConvert(operand.type));
254     auto dimensions = NN_TRY(toUnsigned(operand.dimensions));
255     const auto lifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
256     const auto location = NN_TRY(unvalidatedConvert(operand.location));
257     auto extraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
258     return Operand{
259             .type = type,
260             .dimensions = std::move(dimensions),
261             .scale = operand.scale,
262             .zeroPoint = operand.zeroPoint,
263             .lifetime = lifetime,
264             .location = location,
265             .extraParams = std::move(extraParams),
266     };
267 }
268 
unvalidatedConvert(const std::optional<aidl_hal::OperandExtraParams> & optionalExtraParams)269 GeneralResult<Operand::ExtraParams> unvalidatedConvert(
270         const std::optional<aidl_hal::OperandExtraParams>& optionalExtraParams) {
271     if (!optionalExtraParams.has_value()) {
272         return Operand::NoParams{};
273     }
274     const auto& extraParams = optionalExtraParams.value();
275     using Tag = aidl_hal::OperandExtraParams::Tag;
276     switch (extraParams.getTag()) {
277         case Tag::channelQuant:
278             return unvalidatedConvert(extraParams.get<Tag::channelQuant>());
279         case Tag::extension:
280             return extraParams.get<Tag::extension>();
281     }
282     return NN_ERROR() << "Unrecognized Operand::ExtraParams tag: "
283                       << underlyingType(extraParams.getTag());
284 }
285 
unvalidatedConvert(const aidl_hal::SymmPerChannelQuantParams & symmPerChannelQuantParams)286 GeneralResult<Operand::SymmPerChannelQuantParams> unvalidatedConvert(
287         const aidl_hal::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
288     VERIFY_NON_NEGATIVE(symmPerChannelQuantParams.channelDim)
289             << "Per-channel quantization channel dimension must not be negative.";
290     return Operand::SymmPerChannelQuantParams{
291             .scales = symmPerChannelQuantParams.scales,
292             .channelDim = static_cast<uint32_t>(symmPerChannelQuantParams.channelDim),
293     };
294 }
295 
unvalidatedConvert(const aidl_hal::Model & model)296 GeneralResult<Model> unvalidatedConvert(const aidl_hal::Model& model) {
297     auto main = NN_TRY(unvalidatedConvert(model.main));
298     auto referenced = NN_TRY(unvalidatedConvert(model.referenced));
299     auto operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
300     auto pools = NN_TRY(unvalidatedConvert(model.pools));
301     auto extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));
302     return Model{
303             .main = std::move(main),
304             .referenced = std::move(referenced),
305             .operandValues = std::move(operandValues),
306             .pools = std::move(pools),
307             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
308             .extensionNameToPrefix = std::move(extensionNameToPrefix),
309     };
310 }
311 
unvalidatedConvert(const aidl_hal::Subgraph & subgraph)312 GeneralResult<Model::Subgraph> unvalidatedConvert(const aidl_hal::Subgraph& subgraph) {
313     auto operands = NN_TRY(unvalidatedConvert(subgraph.operands));
314     auto operations = NN_TRY(unvalidatedConvert(subgraph.operations));
315     auto inputIndexes = NN_TRY(toUnsigned(subgraph.inputIndexes));
316     auto outputIndexes = NN_TRY(toUnsigned(subgraph.outputIndexes));
317     return Model::Subgraph{
318             .operands = std::move(operands),
319             .operations = std::move(operations),
320             .inputIndexes = std::move(inputIndexes),
321             .outputIndexes = std::move(outputIndexes),
322     };
323 }
324 
unvalidatedConvert(const aidl_hal::ExtensionNameAndPrefix & extensionNameAndPrefix)325 GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
326         const aidl_hal::ExtensionNameAndPrefix& extensionNameAndPrefix) {
327     return ExtensionNameAndPrefix{
328             .name = extensionNameAndPrefix.name,
329             .prefix = extensionNameAndPrefix.prefix,
330     };
331 }
332 
unvalidatedConvert(const aidl_hal::Extension & extension)333 GeneralResult<Extension> unvalidatedConvert(const aidl_hal::Extension& extension) {
334     auto operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes));
335     return Extension{
336             .name = extension.name,
337             .operandTypes = std::move(operandTypes),
338     };
339 }
340 
unvalidatedConvert(const aidl_hal::ExtensionOperandTypeInformation & operandTypeInformation)341 GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
342         const aidl_hal::ExtensionOperandTypeInformation& operandTypeInformation) {
343     VERIFY_NON_NEGATIVE(operandTypeInformation.byteSize)
344             << "Extension operand type byte size must not be negative";
345     return Extension::OperandTypeInformation{
346             .type = operandTypeInformation.type,
347             .isTensor = operandTypeInformation.isTensor,
348             .byteSize = static_cast<uint32_t>(operandTypeInformation.byteSize),
349     };
350 }
351 
unvalidatedConvert(const aidl_hal::OutputShape & outputShape)352 GeneralResult<OutputShape> unvalidatedConvert(const aidl_hal::OutputShape& outputShape) {
353     auto dimensions = NN_TRY(toUnsigned(outputShape.dimensions));
354     return OutputShape{
355             .dimensions = std::move(dimensions),
356             .isSufficient = outputShape.isSufficient,
357     };
358 }
359 
unvalidatedConvert(bool measureTiming)360 GeneralResult<MeasureTiming> unvalidatedConvert(bool measureTiming) {
361     return measureTiming ? MeasureTiming::YES : MeasureTiming::NO;
362 }
363 
unvalidatedConvert(const aidl_hal::Memory & memory)364 GeneralResult<SharedMemory> unvalidatedConvert(const aidl_hal::Memory& memory) {
365     using Tag = aidl_hal::Memory::Tag;
366     switch (memory.getTag()) {
367         case Tag::ashmem: {
368             const auto& ashmem = memory.get<Tag::ashmem>();
369             VERIFY_NON_NEGATIVE(ashmem.size) << "Memory size must not be negative";
370             if (ashmem.size > std::numeric_limits<size_t>::max()) {
371                 return NN_ERROR() << "Memory: size must be <= std::numeric_limits<size_t>::max()";
372             }
373 
374             auto fd = NN_TRY(dupFd(ashmem.fd.get()));
375             auto handle = Memory::Ashmem{
376                     .fd = std::move(fd),
377                     .size = static_cast<size_t>(ashmem.size),
378             };
379             return std::make_shared<const Memory>(Memory{.handle = std::move(handle)});
380         }
381         case Tag::mappableFile: {
382             const auto& mappableFile = memory.get<Tag::mappableFile>();
383             VERIFY_NON_NEGATIVE(mappableFile.length) << "Memory size must not be negative";
384             VERIFY_NON_NEGATIVE(mappableFile.offset) << "Memory offset must not be negative";
385             if (mappableFile.length > std::numeric_limits<size_t>::max()) {
386                 return NN_ERROR() << "Memory: size must be <= std::numeric_limits<size_t>::max()";
387             }
388             if (mappableFile.offset > std::numeric_limits<size_t>::max()) {
389                 return NN_ERROR() << "Memory: offset must be <= std::numeric_limits<size_t>::max()";
390             }
391 
392             const size_t size = static_cast<size_t>(mappableFile.length);
393             const int prot = mappableFile.prot;
394             const int fd = mappableFile.fd.get();
395             const size_t offset = static_cast<size_t>(mappableFile.offset);
396 
397             return createSharedMemoryFromFd(size, prot, fd, offset);
398         }
399         case Tag::hardwareBuffer: {
400 #ifdef __ANDROID__
401             const auto& hardwareBuffer = memory.get<Tag::hardwareBuffer>();
402 
403             const UniqueNativeHandle handle =
404                     NN_TRY(nativeHandleFromAidlHandle(hardwareBuffer.handle));
405             const native_handle_t* nativeHandle = handle.get();
406 
407             const AHardwareBuffer_Desc desc{
408                     .width = static_cast<uint32_t>(hardwareBuffer.description.width),
409                     .height = static_cast<uint32_t>(hardwareBuffer.description.height),
410                     .layers = static_cast<uint32_t>(hardwareBuffer.description.layers),
411                     .format = static_cast<uint32_t>(hardwareBuffer.description.format),
412                     .usage = static_cast<uint64_t>(hardwareBuffer.description.usage),
413                     .stride = static_cast<uint32_t>(hardwareBuffer.description.stride),
414             };
415             AHardwareBuffer* ahwb = nullptr;
416             const status_t status = AHardwareBuffer_createFromHandle(
417                     &desc, nativeHandle, AHARDWAREBUFFER_CREATE_FROM_HANDLE_METHOD_CLONE, &ahwb);
418             if (status != NO_ERROR) {
419                 return NN_ERROR() << "createFromHandle failed";
420             }
421 
422             return createSharedMemoryFromAHWB(ahwb, /*takeOwnership=*/true);
423 #else   // __ANDROID__
424             LOG(ERROR) << "GeneralResult<SharedMemory> unvalidatedConvert(const aidl_hal::Memory& "
425                           "memory): Not Available on Host Build";
426             return NN_ERROR() << "createFromHandle failed";
427 #endif  // __ANDROID__
428         }
429     }
430     return NN_ERROR() << "Unrecognized Memory::Tag: " << underlyingType(memory.getTag());
431 }
432 
unvalidatedConvert(const aidl_hal::Timing & timing)433 GeneralResult<Timing> unvalidatedConvert(const aidl_hal::Timing& timing) {
434     if (timing.timeInDriverNs < -1) {
435         return NN_ERROR() << "Timing: timeInDriverNs must not be less than -1";
436     }
437     if (timing.timeOnDeviceNs < -1) {
438         return NN_ERROR() << "Timing: timeOnDeviceNs must not be less than -1";
439     }
440     constexpr auto convertTiming = [](int64_t halTiming) -> OptionalDuration {
441         if (halTiming == kNoTiming) {
442             return {};
443         }
444         return nn::Duration(static_cast<uint64_t>(halTiming));
445     };
446     return Timing{.timeOnDevice = convertTiming(timing.timeOnDeviceNs),
447                   .timeInDriver = convertTiming(timing.timeInDriverNs)};
448 }
449 
unvalidatedConvert(const std::vector<uint8_t> & operandValues)450 GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t>& operandValues) {
451     return Model::OperandValues(operandValues.data(), operandValues.size());
452 }
453 
unvalidatedConvert(const aidl_hal::BufferDesc & bufferDesc)454 GeneralResult<BufferDesc> unvalidatedConvert(const aidl_hal::BufferDesc& bufferDesc) {
455     auto dimensions = NN_TRY(toUnsigned(bufferDesc.dimensions));
456     return BufferDesc{.dimensions = std::move(dimensions)};
457 }
458 
unvalidatedConvert(const aidl_hal::BufferRole & bufferRole)459 GeneralResult<BufferRole> unvalidatedConvert(const aidl_hal::BufferRole& bufferRole) {
460     VERIFY_NON_NEGATIVE(bufferRole.modelIndex) << "BufferRole: modelIndex must not be negative";
461     VERIFY_NON_NEGATIVE(bufferRole.ioIndex) << "BufferRole: ioIndex must not be negative";
462     return BufferRole{
463             .modelIndex = static_cast<uint32_t>(bufferRole.modelIndex),
464             .ioIndex = static_cast<uint32_t>(bufferRole.ioIndex),
465             .probability = bufferRole.probability,
466     };
467 }
468 
unvalidatedConvert(const aidl_hal::Request & request)469 GeneralResult<Request> unvalidatedConvert(const aidl_hal::Request& request) {
470     auto inputs = NN_TRY(unvalidatedConvert(request.inputs));
471     auto outputs = NN_TRY(unvalidatedConvert(request.outputs));
472     auto pools = NN_TRY(unvalidatedConvert(request.pools));
473     return Request{
474             .inputs = std::move(inputs),
475             .outputs = std::move(outputs),
476             .pools = std::move(pools),
477     };
478 }
479 
unvalidatedConvert(const aidl_hal::RequestArgument & argument)480 GeneralResult<Request::Argument> unvalidatedConvert(const aidl_hal::RequestArgument& argument) {
481     const auto lifetime = argument.hasNoValue ? Request::Argument::LifeTime::NO_VALUE
482                                               : Request::Argument::LifeTime::POOL;
483     const auto location = NN_TRY(unvalidatedConvert(argument.location));
484     auto dimensions = NN_TRY(toUnsigned(argument.dimensions));
485     return Request::Argument{
486             .lifetime = lifetime,
487             .location = location,
488             .dimensions = std::move(dimensions),
489     };
490 }
491 
unvalidatedConvert(const aidl_hal::RequestMemoryPool & memoryPool)492 GeneralResult<Request::MemoryPool> unvalidatedConvert(
493         const aidl_hal::RequestMemoryPool& memoryPool) {
494     using Tag = aidl_hal::RequestMemoryPool::Tag;
495     switch (memoryPool.getTag()) {
496         case Tag::pool:
497             return unvalidatedConvert(memoryPool.get<Tag::pool>());
498         case Tag::token: {
499             const auto token = memoryPool.get<Tag::token>();
500             VERIFY_NON_NEGATIVE(token) << "Memory pool token must not be negative";
501             return static_cast<Request::MemoryDomainToken>(token);
502         }
503     }
504     return NN_ERROR() << "Invalid Request::MemoryPool tag " << underlyingType(memoryPool.getTag());
505 }
506 
unvalidatedConvert(const aidl_hal::ErrorStatus & status)507 GeneralResult<ErrorStatus> unvalidatedConvert(const aidl_hal::ErrorStatus& status) {
508     switch (status) {
509         case aidl_hal::ErrorStatus::NONE:
510         case aidl_hal::ErrorStatus::DEVICE_UNAVAILABLE:
511         case aidl_hal::ErrorStatus::GENERAL_FAILURE:
512         case aidl_hal::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
513         case aidl_hal::ErrorStatus::INVALID_ARGUMENT:
514         case aidl_hal::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
515         case aidl_hal::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
516         case aidl_hal::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
517         case aidl_hal::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
518             return static_cast<ErrorStatus>(status);
519     }
520     return NN_ERROR() << "Invalid ErrorStatus " << underlyingType(status);
521 }
522 
unvalidatedConvert(const aidl_hal::ExecutionPreference & executionPreference)523 GeneralResult<ExecutionPreference> unvalidatedConvert(
524         const aidl_hal::ExecutionPreference& executionPreference) {
525     return static_cast<ExecutionPreference>(executionPreference);
526 }
527 
unvalidatedConvert(const std::vector<aidl_hal::Operation> & operations)528 GeneralResult<std::vector<Operation>> unvalidatedConvert(
529         const std::vector<aidl_hal::Operation>& operations) {
530     return unvalidatedConvertVec(operations);
531 }
532 
unvalidatedConvert(const ndk::ScopedFileDescriptor & handle)533 GeneralResult<SharedHandle> unvalidatedConvert(const ndk::ScopedFileDescriptor& handle) {
534     auto duplicatedFd = NN_TRY(dupFd(handle.get()));
535     return std::make_shared<const Handle>(std::move(duplicatedFd));
536 }
537 
#ifdef NN_AIDL_V4_OR_ABOVE
// Field-by-field copy of a (token, value) metadata pair.
GeneralResult<TokenValuePair> unvalidatedConvert(const aidl_hal::TokenValuePair& tokenValuePair) {
    TokenValuePair converted{.token = tokenValuePair.token, .value = tokenValuePair.value};
    return converted;
}
#endif  // NN_AIDL_V4_OR_ABOVE
543 
convert(const aidl_hal::Capabilities & capabilities)544 GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities) {
545     return validatedConvert(capabilities);
546 }
547 
convert(const aidl_hal::DeviceType & deviceType)548 GeneralResult<DeviceType> convert(const aidl_hal::DeviceType& deviceType) {
549     return validatedConvert(deviceType);
550 }
551 
convert(const aidl_hal::ErrorStatus & errorStatus)552 GeneralResult<ErrorStatus> convert(const aidl_hal::ErrorStatus& errorStatus) {
553     return validatedConvert(errorStatus);
554 }
555 
convert(const aidl_hal::ExecutionPreference & executionPreference)556 GeneralResult<ExecutionPreference> convert(
557         const aidl_hal::ExecutionPreference& executionPreference) {
558     return validatedConvert(executionPreference);
559 }
560 
convert(const aidl_hal::Memory & operand)561 GeneralResult<SharedMemory> convert(const aidl_hal::Memory& operand) {
562     return validatedConvert(operand);
563 }
564 
convert(const aidl_hal::Model & model)565 GeneralResult<Model> convert(const aidl_hal::Model& model) {
566     return validatedConvert(model);
567 }
568 
convert(const aidl_hal::OperandType & operandType)569 GeneralResult<OperandType> convert(const aidl_hal::OperandType& operandType) {
570     return validatedConvert(operandType);
571 }
572 
convert(const aidl_hal::Priority & priority)573 GeneralResult<Priority> convert(const aidl_hal::Priority& priority) {
574     return validatedConvert(priority);
575 }
576 
convert(const aidl_hal::Request & request)577 GeneralResult<Request> convert(const aidl_hal::Request& request) {
578     return validatedConvert(request);
579 }
580 
convert(const aidl_hal::Timing & timing)581 GeneralResult<Timing> convert(const aidl_hal::Timing& timing) {
582     return validatedConvert(timing);
583 }
584 
convert(const ndk::ScopedFileDescriptor & handle)585 GeneralResult<SharedHandle> convert(const ndk::ScopedFileDescriptor& handle) {
586     return validatedConvert(handle);
587 }
588 
convert(const aidl_hal::BufferDesc & bufferDesc)589 GeneralResult<BufferDesc> convert(const aidl_hal::BufferDesc& bufferDesc) {
590     return validatedConvert(bufferDesc);
591 }
592 
convert(const std::vector<aidl_hal::Extension> & extension)593 GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension) {
594     return validatedConvert(extension);
595 }
596 
convert(const std::vector<aidl_hal::Memory> & memories)597 GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) {
598     return validatedConvert(memories);
599 }
convert(const std::vector<aidl_hal::ExtensionNameAndPrefix> & extensionNameAndPrefix)600 GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
601         const std::vector<aidl_hal::ExtensionNameAndPrefix>& extensionNameAndPrefix) {
602     return unvalidatedConvert(extensionNameAndPrefix);
603 }
604 
#ifdef NN_AIDL_V4_OR_ABOVE
// Validated element-wise conversion of execution metadata (token/value pairs).
GeneralResult<std::vector<TokenValuePair>> convert(
        const std::vector<aidl_hal::TokenValuePair>& metaData) {
    return validatedConvert(metaData);
}
#endif  // NN_AIDL_V4_OR_ABOVE
611 
convert(const std::vector<aidl_hal::OutputShape> & outputShapes)612 GeneralResult<std::vector<OutputShape>> convert(
613         const std::vector<aidl_hal::OutputShape>& outputShapes) {
614     return validatedConvert(outputShapes);
615 }
616 
convert(const std::vector<ndk::ScopedFileDescriptor> & handles)617 GeneralResult<std::vector<SharedHandle>> convert(
618         const std::vector<ndk::ScopedFileDescriptor>& handles) {
619     return validatedConvert(handles);
620 }
621 
convert(const std::vector<aidl_hal::BufferRole> & roles)622 GeneralResult<std::vector<BufferRole>> convert(const std::vector<aidl_hal::BufferRole>& roles) {
623     return validatedConvert(roles);
624 }
625 
toUnsigned(const std::vector<int32_t> & vec)626 GeneralResult<std::vector<uint32_t>> toUnsigned(const std::vector<int32_t>& vec) {
627     if (!std::all_of(vec.begin(), vec.end(), [](int32_t v) { return v >= 0; })) {
628         return NN_ERROR() << "Negative value passed to conversion from signed to unsigned";
629     }
630     return std::vector<uint32_t>(vec.begin(), vec.end());
631 }
632 
633 }  // namespace android::nn
634 
635 namespace aidl::android::hardware::neuralnetworks::utils {
636 namespace {
637 
638 using utils::unvalidatedConvert;
639 
// Helper for std::visit: inherits the call operators of every callable passed to it, so a single
// object can handle each alternative of a std::variant.
template <typename... Visitors>
struct overloaded : Visitors... {
    using Visitors::operator()...;
};
// C++17 deduction guide so that `overloaded{lambdaA, lambdaB}` deduces the lambda types.
template <typename... Visitors>
overloaded(Visitors...) -> overloaded<Visitors...>;
647 
#ifdef __ANDROID__
/**
 * Copies a native_handle_t into an AIDL common::NativeHandle using android::dupToAidl.
 *
 * Fails if any file descriptor in the resulting AIDL handle is negative.
 */
nn::GeneralResult<common::NativeHandle> aidlHandleFromNativeHandle(
        const native_handle_t& nativeHandle) {
    auto aidlHandle = ::android::dupToAidl(&nativeHandle);
    const bool hasInvalidFd =
            std::any_of(aidlHandle.fds.cbegin(), aidlHandle.fds.cend(),
                        [](const ndk::ScopedFileDescriptor& fd) { return fd.get() < 0; });
    if (hasInvalidFd) {
        return NN_ERROR() << "android::dupToAidl returned an invalid common::NativeHandle";
    }
    return aidlHandle;
}
#endif  // __ANDROID__
659 
/**
 * Converts canonical ashmem-backed memory into an AIDL Memory holding a common::Ashmem.
 *
 * The file descriptor is duplicated, so the returned object owns its own fd. Fails if the size
 * does not fit the AIDL int64 field or if duplicating the fd fails.
 */
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::Ashmem& memory) {
    constexpr auto kInt64Max = std::numeric_limits<int64_t>::max();
    // The size can only exceed the int64 range when size_t's range is larger than int64_t's;
    // on other targets the check is compiled out.
    if constexpr (std::numeric_limits<size_t>::max() > kInt64Max) {
        if (memory.size > kInt64Max) {
            return (NN_ERROR()
                    << "Memory::Ashmem: size must be <= std::numeric_limits<int64_t>::max()")
                    .operator nn::GeneralResult<Memory>();
        }
    }

    auto duplicatedFd = NN_TRY(nn::dupFd(memory.fd));
    common::Ashmem ashmem{
            .fd = ndk::ScopedFileDescriptor(duplicatedFd.release()),
            .size = static_cast<int64_t>(memory.size),
    };
    return Memory::make<Memory::Tag::ashmem>(std::move(ashmem));
}
678 
/**
 * Converts canonical fd-backed memory into an AIDL Memory holding a common::MappableFile.
 *
 * The file descriptor is duplicated, so the returned object owns its own fd. Fails if the size
 * or the offset does not fit the AIDL int64 fields, or if duplicating the fd fails.
 */
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::Fd& memory) {
    constexpr auto kInt64Max = std::numeric_limits<int64_t>::max();
    // Out-of-range values are only possible when size_t's range is larger than int64_t's;
    // on other targets both checks are compiled out.
    if constexpr (std::numeric_limits<size_t>::max() > kInt64Max) {
        if (memory.size > kInt64Max) {
            return (NN_ERROR() << "Memory::Fd: size must be <= std::numeric_limits<int64_t>::max()")
                    .operator nn::GeneralResult<Memory>();
        }
        if (memory.offset > kInt64Max) {
            return (NN_ERROR()
                    << "Memory::Fd: offset must be <= std::numeric_limits<int64_t>::max()")
                    .operator nn::GeneralResult<Memory>();
        }
    }

    auto duplicatedFd = NN_TRY(nn::dupFd(memory.fd));
    common::MappableFile mappableFile{
            .length = static_cast<int64_t>(memory.size),
            .prot = memory.prot,
            .fd = ndk::ScopedFileDescriptor(duplicatedFd.release()),
            .offset = static_cast<int64_t>(memory.offset),
    };
    return Memory::make<Memory::Tag::mappableFile>(std::move(mappableFile));
}
704 
/**
 * Converts a canonical AHardwareBuffer-backed memory into an AIDL Memory holding a
 * graphics::common::HardwareBuffer.
 *
 * The buffer's native handle is copied into an AIDL NativeHandle (via
 * aidlHandleFromNativeHandle), and the buffer description reported by AHardwareBuffer_describe
 * is copied field by field. On non-Android (host) builds this conversion is unavailable and
 * always returns an error.
 */
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::HardwareBuffer& memory) {
#ifdef __ANDROID__
    const native_handle_t* nativeHandle = AHardwareBuffer_getNativeHandle(memory.handle.get());
    if (nativeHandle == nullptr) {
        return (NN_ERROR() << "unvalidatedConvert failed because AHardwareBuffer_getNativeHandle "
                              "returned nullptr")
                .operator nn::GeneralResult<Memory>();
    }

    auto handle = NN_TRY(aidlHandleFromNativeHandle(*nativeHandle));

    // Value-initialized so no field is ever read indeterminate; AHardwareBuffer_describe fills it.
    AHardwareBuffer_Desc desc{};
    AHardwareBuffer_describe(memory.handle.get(), &desc);

    // Deliberately non-const: it is moved into the HardwareBuffer below, and std::move on a
    // const object silently degrades to a copy.
    auto description = graphics::common::HardwareBufferDescription{
            .width = static_cast<int32_t>(desc.width),
            .height = static_cast<int32_t>(desc.height),
            .layers = static_cast<int32_t>(desc.layers),
            .format = static_cast<graphics::common::PixelFormat>(desc.format),
            .usage = static_cast<graphics::common::BufferUsage>(desc.usage),
            .stride = static_cast<int32_t>(desc.stride),
    };

    auto hardwareBuffer = graphics::common::HardwareBuffer{
            .description = std::move(description),
            .handle = std::move(handle),
    };
    return Memory::make<Memory::Tag::hardwareBuffer>(std::move(hardwareBuffer));
#else   // __ANDROID__
    LOG(ERROR) << "nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::HardwareBuffer& "
                  "memory): Not Available on Host Build";
    (void)memory;
    return (NN_ERROR() << "unvalidatedConvert failed").operator nn::GeneralResult<Memory>();
#endif  // __ANDROID__
}
741 
/** Memory of an unknown kind has no AIDL representation, so this conversion always fails. */
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::Unknown& /*unknownMemory*/) {
    return (NN_ERROR() << "Unable to convert Unknown memory type")
            .operator nn::GeneralResult<Memory>();
}
747 
/** Copies the canonical execTime/powerUsage pair into an AIDL PerformanceInfo. */
nn::GeneralResult<PerformanceInfo> unvalidatedConvert(
        const nn::Capabilities::PerformanceInfo& info) {
    PerformanceInfo aidlInfo{};
    aidlInfo.execTime = info.execTime;
    aidlInfo.powerUsage = info.powerUsage;
    return aidlInfo;
}
752 
/**
 * Converts one (operand type, performance info) entry.
 *
 * Fails if the operand type cannot be converted.
 */
nn::GeneralResult<OperandPerformance> unvalidatedConvert(
        const nn::Capabilities::OperandPerformance& operandPerformance) {
    OperandPerformance aidlPerformance{};
    aidlPerformance.type = NN_TRY(unvalidatedConvert(operandPerformance.type));
    aidlPerformance.info = NN_TRY(unvalidatedConvert(operandPerformance.info));
    return aidlPerformance;
}
759 
/**
 * Flattens the canonical operand performance table into an AIDL vector, preserving its order.
 *
 * Stops at, and returns, the first entry that fails to convert.
 */
nn::GeneralResult<std::vector<OperandPerformance>> unvalidatedConvert(
        const nn::Capabilities::OperandPerformanceTable& table) {
    const auto& entries = table.asVector();
    std::vector<OperandPerformance> converted;
    converted.reserve(entries.size());
    for (const auto& entry : entries) {
        converted.push_back(NN_TRY(unvalidatedConvert(entry)));
    }
    return converted;
}
769 
/**
 * Converts extension operand type information to its AIDL form.
 *
 * byteSize is narrowed to the AIDL int32 field, so values above INT32_MAX are rejected rather
 * than silently wrapping to a negative size. This matches the range checks applied to
 * BufferRole and DataLocation indices.
 */
nn::GeneralResult<ExtensionOperandTypeInformation> unvalidatedConvert(
        const nn::Extension::OperandTypeInformation& info) {
    VERIFY_LE_INT32_MAX(info.byteSize)
            << "Extension::OperandTypeInformation: byteSize must be <= "
               "std::numeric_limits<int32_t>::max()";
    return ExtensionOperandTypeInformation{.type = info.type,
                                           .isTensor = info.isTensor,
                                           .byteSize = static_cast<int32_t>(info.byteSize)};
}
776 
/**
 * Converts a non-negative duration to a nanosecond count.
 *
 * Negative durations are rejected. The count is clamped to INT64_MAX, which only has an effect
 * if the duration's representation is wider than int64_t.
 */
nn::GeneralResult<int64_t> unvalidatedConvert(const nn::Duration& duration) {
    if (duration < nn::Duration::zero()) {
        return NN_ERROR() << "Unable to convert invalid (negative) duration";
    }
    using Rep = std::chrono::nanoseconds::rep;
    constexpr Rep kMaxRepresentable = std::numeric_limits<int64_t>::max();
    const Rep ticks = duration.count();
    return static_cast<int64_t>(ticks > kMaxRepresentable ? kMaxRepresentable : ticks);
}
785 
// The (decayed) value type produced by unvalidatedConvert for a given canonical input type;
// used to spell the return types of the generic converters below.
template <typename Canonical>
using UnvalidatedConvertOutput =
        typename std::decay<decltype(unvalidatedConvert(std::declval<Canonical>()).value())>::type;
789 
790 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)791 nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
792         const std::vector<Type>& arguments) {
793     std::vector<UnvalidatedConvertOutput<Type>> halObject;
794     halObject.reserve(arguments.size());
795     for (const auto& argument : arguments) {
796         halObject.push_back(NN_TRY(unvalidatedConvert(argument)));
797     }
798     return halObject;
799 }
800 
801 template <typename Type>
validatedConvert(const Type & canonical)802 nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
803     NN_TRY(compliantVersion(canonical));
804     return utils::unvalidatedConvert(canonical);
805 }
806 
807 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)808 nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
809         const std::vector<Type>& arguments) {
810     std::vector<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
811     for (size_t i = 0; i < arguments.size(); ++i) {
812         halObject[i] = NN_TRY(validatedConvert(arguments[i]));
813     }
814     return halObject;
815 }
816 
817 }  // namespace
818 
/** Copies the bytes of a canonical cache token into a byte vector. */
nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(const nn::CacheToken& cacheToken) {
    std::vector<uint8_t> bytes;
    bytes.assign(cacheToken.begin(), cacheToken.end());
    return bytes;
}
822 
/** Converts a buffer descriptor. Fails if any dimension does not fit into int32_t. */
nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc) {
    BufferDesc aidlDesc{};
    aidlDesc.dimensions = NN_TRY(toSigned(bufferDesc.dimensions));
    return aidlDesc;
}
827 
/**
 * Converts a buffer role.
 *
 * Both indices are narrowed to int32, so each must be <= INT32_MAX; modelIndex is checked
 * first.
 */
nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole) {
    VERIFY_LE_INT32_MAX(bufferRole.modelIndex)
            << "BufferRole: modelIndex must be <= std::numeric_limits<int32_t>::max()";
    VERIFY_LE_INT32_MAX(bufferRole.ioIndex)
            << "BufferRole: ioIndex must be <= std::numeric_limits<int32_t>::max()";

    BufferRole aidlRole{};
    aidlRole.modelIndex = static_cast<int32_t>(bufferRole.modelIndex);
    aidlRole.ioIndex = static_cast<int32_t>(bufferRole.ioIndex);
    aidlRole.probability = bufferRole.probability;
    return aidlRole;
}
839 
/**
 * Converts a device type.
 *
 * OTHER, CPU, GPU and ACCELERATOR are cast directly to the AIDL enum. UNKNOWN, and any value
 * outside those four, is rejected.
 */
nn::GeneralResult<DeviceType> unvalidatedConvert(const nn::DeviceType& deviceType) {
    const bool hasAidlEquivalent = deviceType == nn::DeviceType::OTHER ||
                                   deviceType == nn::DeviceType::CPU ||
                                   deviceType == nn::DeviceType::GPU ||
                                   deviceType == nn::DeviceType::ACCELERATOR;
    if (!hasAidlEquivalent) {
        return NN_ERROR() << "Invalid DeviceType " << deviceType;
    }
    return static_cast<DeviceType>(deviceType);
}
852 
/** Maps MeasureTiming to the AIDL boolean flag: true only for MeasureTiming::YES. */
nn::GeneralResult<bool> unvalidatedConvert(const nn::MeasureTiming& measureTiming) {
    const bool shouldMeasure = (measureTiming == nn::MeasureTiming::YES);
    return shouldMeasure;
}
856 
/**
 * Converts shared memory by dispatching on the alternative held in memory->handle to the
 * matching unvalidatedConvert overload. A null memory pointer is rejected.
 */
nn::GeneralResult<Memory> unvalidatedConvert(const nn::SharedMemory& memory) {
    if (memory == nullptr) {
        return (NN_ERROR() << "Unable to convert nullptr memory")
                .operator nn::GeneralResult<Memory>();
    }
    const auto convertAlternative = [](const auto& alternative) {
        return unvalidatedConvert(alternative);
    };
    return std::visit(convertAlternative, memory->handle);
}
865 
/**
 * Converts an error status.
 *
 * The listed statuses are cast directly to the AIDL enum; every other canonical status is
 * reported as GENERAL_FAILURE. This overload never fails.
 */
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus) {
    switch (errorStatus) {
        case nn::ErrorStatus::NONE:
        case nn::ErrorStatus::DEVICE_UNAVAILABLE:
        case nn::ErrorStatus::GENERAL_FAILURE:
        case nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case nn::ErrorStatus::INVALID_ARGUMENT:
        case nn::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
        case nn::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
            // Directly castable (assumes matching numeric values in both enums).
            break;
        default:
            return ErrorStatus::GENERAL_FAILURE;
    }
    return static_cast<ErrorStatus>(errorStatus);
}
882 
/** Converts an output shape. Fails if any dimension does not fit into int32_t. */
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape) {
    OutputShape aidlShape{};
    aidlShape.dimensions = NN_TRY(toSigned(outputShape.dimensions));
    aidlShape.isSufficient = outputShape.isSufficient;
    return aidlShape;
}
888 
/** Casts the canonical execution preference directly to the AIDL enum; never fails. */
nn::GeneralResult<ExecutionPreference> unvalidatedConvert(
        const nn::ExecutionPreference& executionPreference) {
    const auto aidlPreference = static_cast<ExecutionPreference>(executionPreference);
    return aidlPreference;
}
893 
/** Casts an operand type to the AIDL enum; OEM and TENSOR_OEM_BYTE are rejected. */
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    const bool isOemType =
            operandType == nn::OperandType::OEM || operandType == nn::OperandType::TENSOR_OEM_BYTE;
    if (!isOemType) {
        return static_cast<OperandType>(operandType);
    }
    return NN_ERROR() << "Unable to convert invalid OperandType " << operandType;
}
900 
/** Casts the canonical operand lifetime directly to the AIDL enum; never fails. */
nn::GeneralResult<OperandLifeTime> unvalidatedConvert(
        const nn::Operand::LifeTime& operandLifeTime) {
    const auto aidlLifeTime = static_cast<OperandLifeTime>(operandLifeTime);
    return aidlLifeTime;
}
905 
/**
 * Converts a data location.
 *
 * poolIndex is narrowed to int32 and must be <= INT32_MAX. offset and length are widened to
 * int64.
 */
nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location) {
    VERIFY_LE_INT32_MAX(location.poolIndex)
            << "DataLocation: pool index must be <= std::numeric_limits<int32_t>::max()";

    DataLocation aidlLocation{};
    aidlLocation.poolIndex = static_cast<int32_t>(location.poolIndex);
    aidlLocation.offset = static_cast<int64_t>(location.offset);
    aidlLocation.length = static_cast<int64_t>(location.length);
    return aidlLocation;
}
915 
/**
 * Converts operand extra parameters.
 *
 * NoParams maps to an empty optional. Per-channel quantization parameters require
 * channelDim <= INT32_MAX. Extension parameters are copied into the extension alternative.
 */
nn::GeneralResult<std::optional<OperandExtraParams>> unvalidatedConvert(
        const nn::Operand::ExtraParams& extraParams) {
    using Result = nn::GeneralResult<std::optional<OperandExtraParams>>;

    const auto fromNoParams = [](const nn::Operand::NoParams&) -> Result { return std::nullopt; };

    const auto fromChannelQuant =
            [](const nn::Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams) -> Result {
        if (symmPerChannelQuantParams.channelDim > std::numeric_limits<int32_t>::max()) {
            // Using explicit type conversion because std::optional in successful
            // result confuses the compiler.
            return (NN_ERROR() << "symmPerChannelQuantParams.channelDim must be <= "
                                  "std::numeric_limits<int32_t>::max(), received: "
                               << symmPerChannelQuantParams.channelDim)
                    .operator nn::GeneralResult<std::optional<OperandExtraParams>>();
        }
        return OperandExtraParams::make<OperandExtraParams::Tag::channelQuant>(
                SymmPerChannelQuantParams{
                        .scales = symmPerChannelQuantParams.scales,
                        .channelDim = static_cast<int32_t>(symmPerChannelQuantParams.channelDim),
                });
    };

    const auto fromExtension = [](const nn::Operand::ExtensionParams& extensionParams) -> Result {
        return OperandExtraParams::make<OperandExtraParams::Tag::extension>(extensionParams);
    };

    return std::visit(overloaded{fromNoParams, fromChannelQuant, fromExtension}, extraParams);
}
951 
/**
 * Converts an operand.
 *
 * Sub-conversions run in the order type, dimensions, lifetime, location, extraParams. The
 * first failure is returned.
 */
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    Operand aidlOperand{};
    aidlOperand.type = NN_TRY(unvalidatedConvert(operand.type));
    aidlOperand.dimensions = NN_TRY(toSigned(operand.dimensions));
    aidlOperand.lifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
    aidlOperand.location = NN_TRY(unvalidatedConvert(operand.location));
    aidlOperand.extraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
    aidlOperand.scale = operand.scale;
    aidlOperand.zeroPoint = operand.zeroPoint;
    return aidlOperand;
}
968 
/** Casts an operation type to the AIDL enum; OEM_OPERATION is rejected. */
nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    if (operationType != nn::OperationType::OEM_OPERATION) {
        return static_cast<OperationType>(operationType);
    }
    return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION";
}
975 
/**
 * Converts an operation.
 *
 * Fails if the type is invalid or any input/output index does not fit into int32_t.
 */
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
    Operation aidlOperation{};
    aidlOperation.type = NN_TRY(unvalidatedConvert(operation.type));
    aidlOperation.inputs = NN_TRY(toSigned(operation.inputs));
    aidlOperation.outputs = NN_TRY(toSigned(operation.outputs));
    return aidlOperation;
}
986 
/** Converts a subgraph: its operands, operations, and input/output index lists, in that order. */
nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph) {
    Subgraph aidlSubgraph{};
    aidlSubgraph.operands = NN_TRY(unvalidatedConvert(subgraph.operands));
    aidlSubgraph.operations = NN_TRY(unvalidatedConvert(subgraph.operations));
    aidlSubgraph.inputIndexes = NN_TRY(toSigned(subgraph.inputIndexes));
    aidlSubgraph.outputIndexes = NN_TRY(toSigned(subgraph.outputIndexes));
    return aidlSubgraph;
}
999 
/** Copies the model's constant operand value bytes into a byte vector. */
nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
        const nn::Model::OperandValues& operandValues) {
    const auto* first = operandValues.data();
    const auto* last = first + operandValues.size();
    return std::vector<uint8_t>(first, last);
}
1004 
/** Copies an extension name and its numeric prefix into the AIDL struct. */
nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
        const nn::ExtensionNameAndPrefix& extensionNameToPrefix) {
    ExtensionNameAndPrefix aidlEntry{};
    aidlEntry.name = extensionNameToPrefix.name;
    aidlEntry.prefix = extensionNameToPrefix.prefix;
    return aidlEntry;
}
1012 
/**
 * Converts a full model.
 *
 * Models that carry pointer-based data (as reported by hal::utils::hasNoPointerData) are
 * rejected with INVALID_ARGUMENT. Otherwise the main subgraph, referenced subgraphs, operand
 * values, pools and extension table are converted in that order.
 */
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    if (!hal::utils::hasNoPointerData(model)) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    Model aidlModel{};
    aidlModel.main = NN_TRY(unvalidatedConvert(model.main));
    aidlModel.referenced = NN_TRY(unvalidatedConvert(model.referenced));
    aidlModel.operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
    aidlModel.pools = NN_TRY(unvalidatedConvert(model.pools));
    aidlModel.extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));
    aidlModel.relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16;
    return aidlModel;
}
1033 
/** Casts the canonical priority directly to the AIDL enum; never fails. */
nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority) {
    const auto aidlPriority = static_cast<Priority>(priority);
    return aidlPriority;
}
1037 
/**
 * Converts an execution request.
 *
 * Requests carrying pointer-based data are rejected with INVALID_ARGUMENT. Otherwise inputs,
 * outputs and pools are converted in that order.
 */
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
    if (!hal::utils::hasNoPointerData(request)) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    Request aidlRequest{};
    aidlRequest.inputs = NN_TRY(unvalidatedConvert(request.inputs));
    aidlRequest.outputs = NN_TRY(unvalidatedConvert(request.outputs));
    aidlRequest.pools = NN_TRY(unvalidatedConvert(request.pools));
    return aidlRequest;
}
1053 
/**
 * Converts a single request argument.
 *
 * POINTER lifetimes are rejected with INVALID_ARGUMENT. A NO_VALUE lifetime sets hasNoValue.
 */
nn::GeneralResult<RequestArgument> unvalidatedConvert(
        const nn::Request::Argument& requestArgument) {
    using LifeTime = nn::Request::Argument::LifeTime;
    if (requestArgument.lifetime == LifeTime::POINTER) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    RequestArgument aidlArgument{};
    aidlArgument.hasNoValue = (requestArgument.lifetime == LifeTime::NO_VALUE);
    aidlArgument.location = NN_TRY(unvalidatedConvert(requestArgument.location));
    aidlArgument.dimensions = NN_TRY(toSigned(requestArgument.dimensions));
    return aidlArgument;
}
1069 
/**
 * Converts a request memory pool.
 *
 * Shared memory becomes the `pool` alternative and a memory-domain token becomes the `token`
 * alternative (its underlying integer value). An IBuffer-backed pool cannot be expressed and
 * yields GENERAL_FAILURE.
 */
nn::GeneralResult<RequestMemoryPool> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool) {
    using Result = nn::GeneralResult<RequestMemoryPool>;

    const auto fromSharedMemory = [](const nn::SharedMemory& memory) -> Result {
        return RequestMemoryPool::make<RequestMemoryPool::Tag::pool>(
                NN_TRY(unvalidatedConvert(memory)));
    };

    const auto fromToken = [](const nn::Request::MemoryDomainToken& token) -> Result {
        return RequestMemoryPool::make<RequestMemoryPool::Tag::token>(underlyingType(token));
    };

    const auto fromBuffer = [](const nn::SharedBuffer& /*buffer*/) -> Result {
        return (NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
                << "Unable to make memory pool from IBuffer")
                .operator nn::GeneralResult<RequestMemoryPool>();
    };

    return std::visit(overloaded{fromSharedMemory, fromToken, fromBuffer}, memoryPool);
}
1091 
/** Converts timing information; each optional duration becomes nanoseconds or kNoTiming. */
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
    Timing aidlTiming{};
    aidlTiming.timeOnDeviceNs = NN_TRY(unvalidatedConvert(timing.timeOnDevice));
    aidlTiming.timeInDriverNs = NN_TRY(unvalidatedConvert(timing.timeInDriver));
    return aidlTiming;
}
1100 
/** Converts an optional duration to nanoseconds, using kNoTiming when it is absent. */
nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalDuration& optionalDuration) {
    if (optionalDuration.has_value()) {
        return unvalidatedConvert(*optionalDuration);
    }
    return kNoTiming;
}
1107 
/**
 * Converts an optional time point to nanoseconds since its clock's epoch, using kNoTiming when
 * it is absent.
 */
nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalTimePoint& optionalTimePoint) {
    if (optionalTimePoint.has_value()) {
        return unvalidatedConvert(optionalTimePoint->time_since_epoch());
    }
    return kNoTiming;
}
1114 
/** Duplicates the sync fence's fd into a ScopedFileDescriptor that owns the copy. */
nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvert(const nn::SyncFence& syncFence) {
    auto fenceFd = NN_TRY(nn::dupFd(syncFence.getFd()));
    const int ownedFd = fenceFd.release();
    return ndk::ScopedFileDescriptor(ownedFd);
}
1119 
/**
 * Duplicates the fd held by a shared handle into a ScopedFileDescriptor that owns the copy.
 *
 * A null handle is rejected with an error instead of being dereferenced. This mirrors the
 * nullptr check in the SharedMemory overload.
 */
nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvert(const nn::SharedHandle& handle) {
    if (handle == nullptr) {
        return (NN_ERROR() << "Unable to convert nullptr handle")
                .operator nn::GeneralResult<ndk::ScopedFileDescriptor>();
    }
    auto duplicatedFd = NN_TRY(nn::dupFd(handle->get()));
    return ndk::ScopedFileDescriptor(duplicatedFd.release());
}
1124 
/**
 * Converts device capabilities.
 *
 * Components are converted in the order relaxed tensor perf, relaxed scalar perf, operand
 * performance table, IF perf, WHILE perf. The first failure is returned.
 */
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities) {
    Capabilities aidlCapabilities{};
    aidlCapabilities.relaxedFloat32toFloat16PerformanceTensor =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
    aidlCapabilities.relaxedFloat32toFloat16PerformanceScalar =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
    aidlCapabilities.operandPerformance =
            NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
    aidlCapabilities.ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance));
    aidlCapabilities.whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance));
    return aidlCapabilities;
}
1141 
/** Converts an extension description: its name and its operand type information list. */
nn::GeneralResult<Extension> unvalidatedConvert(const nn::Extension& extension) {
    Extension aidlExtension{};
    aidlExtension.operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes));
    aidlExtension.name = extension.name;
    return aidlExtension;
}
#ifdef NN_AIDL_V4_OR_ABOVE
/** Copies a metadata token/value pair into the AIDL struct (AIDL V4 and above only). */
nn::GeneralResult<TokenValuePair> unvalidatedConvert(const nn::TokenValuePair& tokenValuePair) {
    TokenValuePair aidlPair{};
    aidlPair.token = tokenValuePair.token;
    aidlPair.value = tokenValuePair.value;
    return aidlPair;
}
#endif  // NN_AIDL_V4_OR_ABOVE
1151 
convert(const nn::CacheToken & cacheToken)1152 nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken) {
1153     return validatedConvert(cacheToken);
1154 }
1155 
convert(const nn::BufferDesc & bufferDesc)1156 nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
1157     return validatedConvert(bufferDesc);
1158 }
1159 
convert(const nn::DeviceType & deviceType)1160 nn::GeneralResult<DeviceType> convert(const nn::DeviceType& deviceType) {
1161     return validatedConvert(deviceType);
1162 }
1163 
convert(const nn::MeasureTiming & measureTiming)1164 nn::GeneralResult<bool> convert(const nn::MeasureTiming& measureTiming) {
1165     return validatedConvert(measureTiming);
1166 }
1167 
convert(const nn::SharedMemory & memory)1168 nn::GeneralResult<Memory> convert(const nn::SharedMemory& memory) {
1169     return validatedConvert(memory);
1170 }
1171 
convert(const nn::ErrorStatus & errorStatus)1172 nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus) {
1173     return validatedConvert(errorStatus);
1174 }
1175 
convert(const nn::ExecutionPreference & executionPreference)1176 nn::GeneralResult<ExecutionPreference> convert(const nn::ExecutionPreference& executionPreference) {
1177     return validatedConvert(executionPreference);
1178 }
1179 
convert(const nn::Model & model)1180 nn::GeneralResult<Model> convert(const nn::Model& model) {
1181     return validatedConvert(model);
1182 }
1183 
convert(const nn::Priority & priority)1184 nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
1185     return validatedConvert(priority);
1186 }
1187 
convert(const nn::Request & request)1188 nn::GeneralResult<Request> convert(const nn::Request& request) {
1189     return validatedConvert(request);
1190 }
1191 
convert(const nn::Timing & timing)1192 nn::GeneralResult<Timing> convert(const nn::Timing& timing) {
1193     return validatedConvert(timing);
1194 }
1195 
convert(const nn::OptionalDuration & optionalDuration)1196 nn::GeneralResult<int64_t> convert(const nn::OptionalDuration& optionalDuration) {
1197     return validatedConvert(optionalDuration);
1198 }
1199 
convert(const nn::OptionalTimePoint & outputShapes)1200 nn::GeneralResult<int64_t> convert(const nn::OptionalTimePoint& outputShapes) {
1201     return validatedConvert(outputShapes);
1202 }
1203 
convert(const nn::Capabilities & capabilities)1204 nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
1205     return validatedConvert(capabilities);
1206 }
1207 
convert(const nn::Extension & extension)1208 nn::GeneralResult<Extension> convert(const nn::Extension& extension) {
1209     return validatedConvert(extension);
1210 }
1211 
convert(const std::vector<nn::BufferRole> & bufferRoles)1212 nn::GeneralResult<std::vector<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles) {
1213     return validatedConvert(bufferRoles);
1214 }
1215 
convert(const std::vector<nn::OutputShape> & outputShapes)1216 nn::GeneralResult<std::vector<OutputShape>> convert(
1217         const std::vector<nn::OutputShape>& outputShapes) {
1218     return validatedConvert(outputShapes);
1219 }
1220 
convert(const std::vector<nn::SharedHandle> & cacheHandles)1221 nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
1222         const std::vector<nn::SharedHandle>& cacheHandles) {
1223     return validatedConvert(cacheHandles);
1224 }
1225 
convert(const std::vector<nn::SyncFence> & syncFences)1226 nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
1227         const std::vector<nn::SyncFence>& syncFences) {
1228     return validatedConvert(syncFences);
1229 }
convert(const std::vector<nn::ExtensionNameAndPrefix> & extensionNameToPrefix)1230 nn::GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
1231         const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) {
1232     return unvalidatedConvert(extensionNameToPrefix);
1233 }
1234 
1235 #ifdef NN_AIDL_V4_OR_ABOVE
convert(const std::vector<nn::TokenValuePair> & metaData)1236 nn::GeneralResult<std::vector<TokenValuePair>> convert(
1237         const std::vector<nn::TokenValuePair>& metaData) {
1238     return validatedConvert(metaData);
1239 }
1240 #endif  // NN_AIDL_V4_OR_ABOVE
1241 
convert(const std::vector<nn::Extension> & extensions)1242 nn::GeneralResult<std::vector<Extension>> convert(const std::vector<nn::Extension>& extensions) {
1243     return validatedConvert(extensions);
1244 }
1245 
toSigned(const std::vector<uint32_t> & vec)1246 nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec) {
1247     if (!std::all_of(vec.begin(), vec.end(),
1248                      [](uint32_t v) { return v <= std::numeric_limits<int32_t>::max(); })) {
1249         return NN_ERROR() << "Vector contains a value that doesn't fit into int32_t.";
1250     }
1251     return std::vector<int32_t>(vec.begin(), vec.end());
1252 }
1253 
toVec(const std::array<uint8_t,IDevice::BYTE_SIZE_OF_CACHE_TOKEN> & token)1254 std::vector<uint8_t> toVec(const std::array<uint8_t, IDevice::BYTE_SIZE_OF_CACHE_TOKEN>& token) {
1255     return std::vector<uint8_t>(token.begin(), token.end());
1256 }
1257 
1258 }  // namespace aidl::android::hardware::neuralnetworks::utils
1259