1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Conversions.h"
18 
19 #include <aidl/android/hardware/common/Ashmem.h>
20 #include <aidl/android/hardware/common/MappableFile.h>
21 #include <aidl/android/hardware/common/NativeHandle.h>
22 #include <aidl/android/hardware/graphics/common/HardwareBuffer.h>
23 #include <aidlcommonsupport/NativeHandle.h>
24 #include <android-base/logging.h>
25 #include <android-base/mapped_file.h>
26 #include <android-base/unique_fd.h>
27 #include <android/binder_auto_utils.h>
28 #include <android/hardware_buffer.h>
29 #include <cutils/native_handle.h>
30 #include <nnapi/OperandTypes.h>
31 #include <nnapi/OperationTypes.h>
32 #include <nnapi/Result.h>
33 #include <nnapi/SharedMemory.h>
34 #include <nnapi/TypeUtils.h>
35 #include <nnapi/Types.h>
36 #include <nnapi/Validation.h>
37 #include <nnapi/hal/CommonUtils.h>
38 #include <nnapi/hal/HandleError.h>
39 #include <vndk/hardware_buffer.h>
40 
41 #include <algorithm>
42 #include <chrono>
43 #include <functional>
44 #include <iterator>
45 #include <limits>
46 #include <type_traits>
47 #include <utility>
48 
49 #include "Utils.h"
50 
// Returns NN_ERROR() from the enclosing function when `value` is negative. A message can be
// streamed onto the use site, e.g. `VERIFY_NON_NEGATIVE(x) << "x must not be negative";`.
// `value` is parenthesized so that arguments containing lower-precedence operators (`?:`, `&`,
// `|`, ...) are compared as a whole rather than having only their last operand compared.
// The `while (cond) return ...` form (instead of `if`) cannot capture a following `else`.
#define VERIFY_NON_NEGATIVE(value) \
    while (UNLIKELY((value) < 0)) return NN_ERROR()

// Returns NN_ERROR() from the enclosing function when `value` exceeds INT32_MAX. Same usage and
// argument parenthesization as VERIFY_NON_NEGATIVE.
#define VERIFY_LE_INT32_MAX(value) \
    while (UNLIKELY((value) > std::numeric_limits<int32_t>::max())) return NN_ERROR()
56 
namespace {

/// Returns `value` reinterpreted as the underlying integer type of its enumeration.
template <typename Type>
constexpr std::underlying_type_t<Type> underlyingType(Type value) {
    using Underlying = std::underlying_type_t<Type>;
    return static_cast<Underlying>(value);
}

// Sentinel value of an AIDL timing field that carries no measurement.
constexpr int64_t kNoTiming = -1;

}  // namespace
66 
67 namespace android::nn {
68 namespace {
69 
70 using ::aidl::android::hardware::common::NativeHandle;
71 
/// The value type held by the result of the `unvalidatedConvert` overload selected for `Input`,
/// with reference and cv-qualifiers stripped from what `.value()` returns.
template <typename Input>
using UnvalidatedConvertOutput = std::remove_cv_t<
        std::remove_reference_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>>;
75 
76 template <typename Type>
unvalidatedConvertVec(const std::vector<Type> & arguments)77 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
78         const std::vector<Type>& arguments) {
79     std::vector<UnvalidatedConvertOutput<Type>> canonical;
80     canonical.reserve(arguments.size());
81     for (const auto& argument : arguments) {
82         canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
83     }
84     return canonical;
85 }
86 
87 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)88 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
89         const std::vector<Type>& arguments) {
90     return unvalidatedConvertVec(arguments);
91 }
92 
93 template <typename Type>
validatedConvert(const Type & halObject)94 GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
95     auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
96     NN_TRY(aidl_hal::utils::compliantVersion(canonical));
97     return canonical;
98 }
99 
100 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)101 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
102         const std::vector<Type>& arguments) {
103     std::vector<UnvalidatedConvertOutput<Type>> canonical;
104     canonical.reserve(arguments.size());
105     for (const auto& argument : arguments) {
106         canonical.push_back(NN_TRY(validatedConvert(argument)));
107     }
108     return canonical;
109 }
110 
unvalidatedConvertHelper(const NativeHandle & aidlNativeHandle)111 GeneralResult<Handle> unvalidatedConvertHelper(const NativeHandle& aidlNativeHandle) {
112     std::vector<base::unique_fd> fds;
113     fds.reserve(aidlNativeHandle.fds.size());
114     for (const auto& fd : aidlNativeHandle.fds) {
115         auto duplicatedFd = NN_TRY(dupFd(fd.get()));
116         fds.emplace_back(duplicatedFd.release());
117     }
118 
119     return Handle{.fds = std::move(fds), .ints = aidlNativeHandle.ints};
120 }
121 
122 struct NativeHandleDeleter {
operator ()android::nn::__anon497330ac0211::NativeHandleDeleter123     void operator()(native_handle_t* handle) const {
124         if (handle) {
125             native_handle_close(handle);
126             native_handle_delete(handle);
127         }
128     }
129 };
130 
131 using UniqueNativeHandle = std::unique_ptr<native_handle_t, NativeHandleDeleter>;
132 
nativeHandleFromAidlHandle(const NativeHandle & handle)133 GeneralResult<UniqueNativeHandle> nativeHandleFromAidlHandle(const NativeHandle& handle) {
134     auto nativeHandle = UniqueNativeHandle(dupFromAidl(handle));
135     if (nativeHandle.get() == nullptr) {
136         return NN_ERROR() << "android::dupFromAidl failed to convert the common::NativeHandle to a "
137                              "native_handle_t";
138     }
139     if (!std::all_of(nativeHandle->data + 0, nativeHandle->data + nativeHandle->numFds,
140                      [](int fd) { return fd >= 0; })) {
141         return NN_ERROR() << "android::dupFromAidl returned an invalid native_handle_t";
142     }
143     return nativeHandle;
144 }
145 
146 }  // anonymous namespace
147 
unvalidatedConvert(const aidl_hal::OperandType & operandType)148 GeneralResult<OperandType> unvalidatedConvert(const aidl_hal::OperandType& operandType) {
149     VERIFY_NON_NEGATIVE(underlyingType(operandType)) << "Negative operand types are not allowed.";
150     const auto canonical = static_cast<OperandType>(operandType);
151     if (canonical == OperandType::OEM || canonical == OperandType::TENSOR_OEM_BYTE) {
152         return NN_ERROR() << "Unable to convert invalid OperandType " << canonical;
153     }
154     return canonical;
155 }
156 
unvalidatedConvert(const aidl_hal::OperationType & operationType)157 GeneralResult<OperationType> unvalidatedConvert(const aidl_hal::OperationType& operationType) {
158     VERIFY_NON_NEGATIVE(underlyingType(operationType))
159             << "Negative operation types are not allowed.";
160     const auto canonical = static_cast<OperationType>(operationType);
161     if (canonical == OperationType::OEM_OPERATION) {
162         return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION";
163     }
164     return canonical;
165 }
166 
unvalidatedConvert(const aidl_hal::DeviceType & deviceType)167 GeneralResult<DeviceType> unvalidatedConvert(const aidl_hal::DeviceType& deviceType) {
168     return static_cast<DeviceType>(deviceType);
169 }
170 
unvalidatedConvert(const aidl_hal::Priority & priority)171 GeneralResult<Priority> unvalidatedConvert(const aidl_hal::Priority& priority) {
172     return static_cast<Priority>(priority);
173 }
174 
unvalidatedConvert(const aidl_hal::Capabilities & capabilities)175 GeneralResult<Capabilities> unvalidatedConvert(const aidl_hal::Capabilities& capabilities) {
176     const bool validOperandTypes = std::all_of(
177             capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
178             [](const aidl_hal::OperandPerformance& operandPerformance) {
179                 return validatedConvert(operandPerformance.type).has_value();
180             });
181     if (!validOperandTypes) {
182         return NN_ERROR() << "Invalid OperandType when unvalidatedConverting OperandPerformance in "
183                              "Capabilities";
184     }
185 
186     auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
187     auto table = NN_TRY(hal::utils::makeGeneralFailure(
188             Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)),
189             nn::ErrorStatus::GENERAL_FAILURE));
190 
191     return Capabilities{
192             .relaxedFloat32toFloat16PerformanceScalar = NN_TRY(
193                     unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar)),
194             .relaxedFloat32toFloat16PerformanceTensor = NN_TRY(
195                     unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor)),
196             .operandPerformance = std::move(table),
197             .ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance)),
198             .whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance)),
199     };
200 }
201 
unvalidatedConvert(const aidl_hal::OperandPerformance & operandPerformance)202 GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
203         const aidl_hal::OperandPerformance& operandPerformance) {
204     return Capabilities::OperandPerformance{
205             .type = NN_TRY(unvalidatedConvert(operandPerformance.type)),
206             .info = NN_TRY(unvalidatedConvert(operandPerformance.info)),
207     };
208 }
209 
unvalidatedConvert(const aidl_hal::PerformanceInfo & performanceInfo)210 GeneralResult<Capabilities::PerformanceInfo> unvalidatedConvert(
211         const aidl_hal::PerformanceInfo& performanceInfo) {
212     return Capabilities::PerformanceInfo{
213             .execTime = performanceInfo.execTime,
214             .powerUsage = performanceInfo.powerUsage,
215     };
216 }
217 
unvalidatedConvert(const aidl_hal::DataLocation & location)218 GeneralResult<DataLocation> unvalidatedConvert(const aidl_hal::DataLocation& location) {
219     VERIFY_NON_NEGATIVE(location.poolIndex) << "DataLocation: pool index must not be negative";
220     VERIFY_NON_NEGATIVE(location.offset) << "DataLocation: offset must not be negative";
221     VERIFY_NON_NEGATIVE(location.length) << "DataLocation: length must not be negative";
222     VERIFY_NON_NEGATIVE(location.padding) << "DataLocation: padding must not be negative";
223     if (location.offset > std::numeric_limits<uint32_t>::max()) {
224         return NN_ERROR() << "DataLocation: offset must be <= std::numeric_limits<uint32_t>::max()";
225     }
226     if (location.length > std::numeric_limits<uint32_t>::max()) {
227         return NN_ERROR() << "DataLocation: length must be <= std::numeric_limits<uint32_t>::max()";
228     }
229     if (location.padding > std::numeric_limits<uint32_t>::max()) {
230         return NN_ERROR()
231                << "DataLocation: padding must be <= std::numeric_limits<uint32_t>::max()";
232     }
233     return DataLocation{
234             .poolIndex = static_cast<uint32_t>(location.poolIndex),
235             .offset = static_cast<uint32_t>(location.offset),
236             .length = static_cast<uint32_t>(location.length),
237             .padding = static_cast<uint32_t>(location.padding),
238     };
239 }
240 
unvalidatedConvert(const aidl_hal::Operation & operation)241 GeneralResult<Operation> unvalidatedConvert(const aidl_hal::Operation& operation) {
242     return Operation{
243             .type = NN_TRY(unvalidatedConvert(operation.type)),
244             .inputs = NN_TRY(toUnsigned(operation.inputs)),
245             .outputs = NN_TRY(toUnsigned(operation.outputs)),
246     };
247 }
248 
unvalidatedConvert(const aidl_hal::OperandLifeTime & operandLifeTime)249 GeneralResult<Operand::LifeTime> unvalidatedConvert(
250         const aidl_hal::OperandLifeTime& operandLifeTime) {
251     return static_cast<Operand::LifeTime>(operandLifeTime);
252 }
253 
unvalidatedConvert(const aidl_hal::Operand & operand)254 GeneralResult<Operand> unvalidatedConvert(const aidl_hal::Operand& operand) {
255     return Operand{
256             .type = NN_TRY(unvalidatedConvert(operand.type)),
257             .dimensions = NN_TRY(toUnsigned(operand.dimensions)),
258             .scale = operand.scale,
259             .zeroPoint = operand.zeroPoint,
260             .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
261             .location = NN_TRY(unvalidatedConvert(operand.location)),
262             .extraParams = NN_TRY(unvalidatedConvert(operand.extraParams)),
263     };
264 }
265 
unvalidatedConvert(const std::optional<aidl_hal::OperandExtraParams> & optionalExtraParams)266 GeneralResult<Operand::ExtraParams> unvalidatedConvert(
267         const std::optional<aidl_hal::OperandExtraParams>& optionalExtraParams) {
268     if (!optionalExtraParams.has_value()) {
269         return Operand::NoParams{};
270     }
271     const auto& extraParams = optionalExtraParams.value();
272     using Tag = aidl_hal::OperandExtraParams::Tag;
273     switch (extraParams.getTag()) {
274         case Tag::channelQuant:
275             return unvalidatedConvert(extraParams.get<Tag::channelQuant>());
276         case Tag::extension:
277             return extraParams.get<Tag::extension>();
278     }
279     return NN_ERROR() << "Unrecognized Operand::ExtraParams tag: "
280                       << underlyingType(extraParams.getTag());
281 }
282 
unvalidatedConvert(const aidl_hal::SymmPerChannelQuantParams & symmPerChannelQuantParams)283 GeneralResult<Operand::SymmPerChannelQuantParams> unvalidatedConvert(
284         const aidl_hal::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
285     VERIFY_NON_NEGATIVE(symmPerChannelQuantParams.channelDim)
286             << "Per-channel quantization channel dimension must not be negative.";
287     return Operand::SymmPerChannelQuantParams{
288             .scales = symmPerChannelQuantParams.scales,
289             .channelDim = static_cast<uint32_t>(symmPerChannelQuantParams.channelDim),
290     };
291 }
292 
unvalidatedConvert(const aidl_hal::Model & model)293 GeneralResult<Model> unvalidatedConvert(const aidl_hal::Model& model) {
294     return Model{
295             .main = NN_TRY(unvalidatedConvert(model.main)),
296             .referenced = NN_TRY(unvalidatedConvert(model.referenced)),
297             .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
298             .pools = NN_TRY(unvalidatedConvert(model.pools)),
299             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
300             .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
301     };
302 }
303 
unvalidatedConvert(const aidl_hal::Subgraph & subgraph)304 GeneralResult<Model::Subgraph> unvalidatedConvert(const aidl_hal::Subgraph& subgraph) {
305     return Model::Subgraph{
306             .operands = NN_TRY(unvalidatedConvert(subgraph.operands)),
307             .operations = NN_TRY(unvalidatedConvert(subgraph.operations)),
308             .inputIndexes = NN_TRY(toUnsigned(subgraph.inputIndexes)),
309             .outputIndexes = NN_TRY(toUnsigned(subgraph.outputIndexes)),
310     };
311 }
312 
unvalidatedConvert(const aidl_hal::ExtensionNameAndPrefix & extensionNameAndPrefix)313 GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
314         const aidl_hal::ExtensionNameAndPrefix& extensionNameAndPrefix) {
315     return Model::ExtensionNameAndPrefix{
316             .name = extensionNameAndPrefix.name,
317             .prefix = extensionNameAndPrefix.prefix,
318     };
319 }
320 
unvalidatedConvert(const aidl_hal::Extension & extension)321 GeneralResult<Extension> unvalidatedConvert(const aidl_hal::Extension& extension) {
322     return Extension{
323             .name = extension.name,
324             .operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes)),
325     };
326 }
327 
unvalidatedConvert(const aidl_hal::ExtensionOperandTypeInformation & operandTypeInformation)328 GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
329         const aidl_hal::ExtensionOperandTypeInformation& operandTypeInformation) {
330     VERIFY_NON_NEGATIVE(operandTypeInformation.byteSize)
331             << "Extension operand type byte size must not be negative";
332     return Extension::OperandTypeInformation{
333             .type = operandTypeInformation.type,
334             .isTensor = operandTypeInformation.isTensor,
335             .byteSize = static_cast<uint32_t>(operandTypeInformation.byteSize),
336     };
337 }
338 
unvalidatedConvert(const aidl_hal::OutputShape & outputShape)339 GeneralResult<OutputShape> unvalidatedConvert(const aidl_hal::OutputShape& outputShape) {
340     return OutputShape{
341             .dimensions = NN_TRY(toUnsigned(outputShape.dimensions)),
342             .isSufficient = outputShape.isSufficient,
343     };
344 }
345 
unvalidatedConvert(bool measureTiming)346 GeneralResult<MeasureTiming> unvalidatedConvert(bool measureTiming) {
347     return measureTiming ? MeasureTiming::YES : MeasureTiming::NO;
348 }
349 
unvalidatedConvert(const aidl_hal::Memory & memory)350 GeneralResult<SharedMemory> unvalidatedConvert(const aidl_hal::Memory& memory) {
351     using Tag = aidl_hal::Memory::Tag;
352     switch (memory.getTag()) {
353         case Tag::ashmem: {
354             const auto& ashmem = memory.get<Tag::ashmem>();
355             VERIFY_NON_NEGATIVE(ashmem.size) << "Memory size must not be negative";
356             if (ashmem.size > std::numeric_limits<size_t>::max()) {
357                 return NN_ERROR() << "Memory: size must be <= std::numeric_limits<size_t>::max()";
358             }
359 
360             auto handle = Memory::Ashmem{
361                     .fd = NN_TRY(dupFd(ashmem.fd.get())),
362                     .size = static_cast<size_t>(ashmem.size),
363             };
364             return std::make_shared<const Memory>(Memory{.handle = std::move(handle)});
365         }
366         case Tag::mappableFile: {
367             const auto& mappableFile = memory.get<Tag::mappableFile>();
368             VERIFY_NON_NEGATIVE(mappableFile.length) << "Memory size must not be negative";
369             VERIFY_NON_NEGATIVE(mappableFile.offset) << "Memory offset must not be negative";
370             if (mappableFile.length > std::numeric_limits<size_t>::max()) {
371                 return NN_ERROR() << "Memory: size must be <= std::numeric_limits<size_t>::max()";
372             }
373             if (mappableFile.offset > std::numeric_limits<size_t>::max()) {
374                 return NN_ERROR() << "Memory: offset must be <= std::numeric_limits<size_t>::max()";
375             }
376 
377             const size_t size = static_cast<size_t>(mappableFile.length);
378             const int prot = mappableFile.prot;
379             const int fd = mappableFile.fd.get();
380             const size_t offset = static_cast<size_t>(mappableFile.offset);
381 
382             return createSharedMemoryFromFd(size, prot, fd, offset);
383         }
384         case Tag::hardwareBuffer: {
385             const auto& hardwareBuffer = memory.get<Tag::hardwareBuffer>();
386 
387             const UniqueNativeHandle handle =
388                     NN_TRY(nativeHandleFromAidlHandle(hardwareBuffer.handle));
389             const native_handle_t* nativeHandle = handle.get();
390 
391             const AHardwareBuffer_Desc desc{
392                     .width = static_cast<uint32_t>(hardwareBuffer.description.width),
393                     .height = static_cast<uint32_t>(hardwareBuffer.description.height),
394                     .layers = static_cast<uint32_t>(hardwareBuffer.description.layers),
395                     .format = static_cast<uint32_t>(hardwareBuffer.description.format),
396                     .usage = static_cast<uint64_t>(hardwareBuffer.description.usage),
397                     .stride = static_cast<uint32_t>(hardwareBuffer.description.stride),
398             };
399             AHardwareBuffer* ahwb = nullptr;
400             const status_t status = AHardwareBuffer_createFromHandle(
401                     &desc, nativeHandle, AHARDWAREBUFFER_CREATE_FROM_HANDLE_METHOD_CLONE, &ahwb);
402             if (status != NO_ERROR) {
403                 return NN_ERROR() << "createFromHandle failed";
404             }
405 
406             return createSharedMemoryFromAHWB(ahwb, /*takeOwnership=*/true);
407         }
408     }
409     return NN_ERROR() << "Unrecognized Memory::Tag: " << memory.getTag();
410 }
411 
unvalidatedConvert(const aidl_hal::Timing & timing)412 GeneralResult<Timing> unvalidatedConvert(const aidl_hal::Timing& timing) {
413     if (timing.timeInDriverNs < -1) {
414         return NN_ERROR() << "Timing: timeInDriverNs must not be less than -1";
415     }
416     if (timing.timeOnDeviceNs < -1) {
417         return NN_ERROR() << "Timing: timeOnDeviceNs must not be less than -1";
418     }
419     constexpr auto convertTiming = [](int64_t halTiming) -> OptionalDuration {
420         if (halTiming == kNoTiming) {
421             return {};
422         }
423         return nn::Duration(static_cast<uint64_t>(halTiming));
424     };
425     return Timing{.timeOnDevice = convertTiming(timing.timeOnDeviceNs),
426                   .timeInDriver = convertTiming(timing.timeInDriverNs)};
427 }
428 
unvalidatedConvert(const std::vector<uint8_t> & operandValues)429 GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t>& operandValues) {
430     return Model::OperandValues(operandValues.data(), operandValues.size());
431 }
432 
unvalidatedConvert(const aidl_hal::BufferDesc & bufferDesc)433 GeneralResult<BufferDesc> unvalidatedConvert(const aidl_hal::BufferDesc& bufferDesc) {
434     return BufferDesc{.dimensions = NN_TRY(toUnsigned(bufferDesc.dimensions))};
435 }
436 
unvalidatedConvert(const aidl_hal::BufferRole & bufferRole)437 GeneralResult<BufferRole> unvalidatedConvert(const aidl_hal::BufferRole& bufferRole) {
438     VERIFY_NON_NEGATIVE(bufferRole.modelIndex) << "BufferRole: modelIndex must not be negative";
439     VERIFY_NON_NEGATIVE(bufferRole.ioIndex) << "BufferRole: ioIndex must not be negative";
440     return BufferRole{
441             .modelIndex = static_cast<uint32_t>(bufferRole.modelIndex),
442             .ioIndex = static_cast<uint32_t>(bufferRole.ioIndex),
443             .probability = bufferRole.probability,
444     };
445 }
446 
unvalidatedConvert(const aidl_hal::Request & request)447 GeneralResult<Request> unvalidatedConvert(const aidl_hal::Request& request) {
448     return Request{
449             .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
450             .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
451             .pools = NN_TRY(unvalidatedConvert(request.pools)),
452     };
453 }
454 
unvalidatedConvert(const aidl_hal::RequestArgument & argument)455 GeneralResult<Request::Argument> unvalidatedConvert(const aidl_hal::RequestArgument& argument) {
456     const auto lifetime = argument.hasNoValue ? Request::Argument::LifeTime::NO_VALUE
457                                               : Request::Argument::LifeTime::POOL;
458     return Request::Argument{
459             .lifetime = lifetime,
460             .location = NN_TRY(unvalidatedConvert(argument.location)),
461             .dimensions = NN_TRY(toUnsigned(argument.dimensions)),
462     };
463 }
464 
unvalidatedConvert(const aidl_hal::RequestMemoryPool & memoryPool)465 GeneralResult<Request::MemoryPool> unvalidatedConvert(
466         const aidl_hal::RequestMemoryPool& memoryPool) {
467     using Tag = aidl_hal::RequestMemoryPool::Tag;
468     switch (memoryPool.getTag()) {
469         case Tag::pool:
470             return unvalidatedConvert(memoryPool.get<Tag::pool>());
471         case Tag::token: {
472             const auto token = memoryPool.get<Tag::token>();
473             VERIFY_NON_NEGATIVE(token) << "Memory pool token must not be negative";
474             return static_cast<Request::MemoryDomainToken>(token);
475         }
476     }
477     return NN_ERROR() << "Invalid Request::MemoryPool tag " << underlyingType(memoryPool.getTag());
478 }
479 
unvalidatedConvert(const aidl_hal::ErrorStatus & status)480 GeneralResult<ErrorStatus> unvalidatedConvert(const aidl_hal::ErrorStatus& status) {
481     switch (status) {
482         case aidl_hal::ErrorStatus::NONE:
483         case aidl_hal::ErrorStatus::DEVICE_UNAVAILABLE:
484         case aidl_hal::ErrorStatus::GENERAL_FAILURE:
485         case aidl_hal::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
486         case aidl_hal::ErrorStatus::INVALID_ARGUMENT:
487         case aidl_hal::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
488         case aidl_hal::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
489         case aidl_hal::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
490         case aidl_hal::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
491             return static_cast<ErrorStatus>(status);
492     }
493     return NN_ERROR() << "Invalid ErrorStatus " << underlyingType(status);
494 }
495 
unvalidatedConvert(const aidl_hal::ExecutionPreference & executionPreference)496 GeneralResult<ExecutionPreference> unvalidatedConvert(
497         const aidl_hal::ExecutionPreference& executionPreference) {
498     return static_cast<ExecutionPreference>(executionPreference);
499 }
500 
unvalidatedConvert(const NativeHandle & aidlNativeHandle)501 GeneralResult<SharedHandle> unvalidatedConvert(const NativeHandle& aidlNativeHandle) {
502     return std::make_shared<const Handle>(NN_TRY(unvalidatedConvertHelper(aidlNativeHandle)));
503 }
504 
unvalidatedConvert(const std::vector<aidl_hal::Operation> & operations)505 GeneralResult<std::vector<Operation>> unvalidatedConvert(
506         const std::vector<aidl_hal::Operation>& operations) {
507     return unvalidatedConvertVec(operations);
508 }
509 
unvalidatedConvert(const ndk::ScopedFileDescriptor & syncFence)510 GeneralResult<SyncFence> unvalidatedConvert(const ndk::ScopedFileDescriptor& syncFence) {
511     auto duplicatedFd = NN_TRY(dupFd(syncFence.get()));
512     return SyncFence::create(std::move(duplicatedFd));
513 }
514 
convert(const aidl_hal::Capabilities & capabilities)515 GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities) {
516     return validatedConvert(capabilities);
517 }
518 
convert(const aidl_hal::DeviceType & deviceType)519 GeneralResult<DeviceType> convert(const aidl_hal::DeviceType& deviceType) {
520     return validatedConvert(deviceType);
521 }
522 
convert(const aidl_hal::ErrorStatus & errorStatus)523 GeneralResult<ErrorStatus> convert(const aidl_hal::ErrorStatus& errorStatus) {
524     return validatedConvert(errorStatus);
525 }
526 
convert(const aidl_hal::ExecutionPreference & executionPreference)527 GeneralResult<ExecutionPreference> convert(
528         const aidl_hal::ExecutionPreference& executionPreference) {
529     return validatedConvert(executionPreference);
530 }
531 
convert(const aidl_hal::Memory & operand)532 GeneralResult<SharedMemory> convert(const aidl_hal::Memory& operand) {
533     return validatedConvert(operand);
534 }
535 
convert(const aidl_hal::Model & model)536 GeneralResult<Model> convert(const aidl_hal::Model& model) {
537     return validatedConvert(model);
538 }
539 
convert(const aidl_hal::OperandType & operandType)540 GeneralResult<OperandType> convert(const aidl_hal::OperandType& operandType) {
541     return validatedConvert(operandType);
542 }
543 
convert(const aidl_hal::Priority & priority)544 GeneralResult<Priority> convert(const aidl_hal::Priority& priority) {
545     return validatedConvert(priority);
546 }
547 
convert(const aidl_hal::Request & request)548 GeneralResult<Request> convert(const aidl_hal::Request& request) {
549     return validatedConvert(request);
550 }
551 
convert(const aidl_hal::Timing & timing)552 GeneralResult<Timing> convert(const aidl_hal::Timing& timing) {
553     return validatedConvert(timing);
554 }
555 
convert(const ndk::ScopedFileDescriptor & syncFence)556 GeneralResult<SyncFence> convert(const ndk::ScopedFileDescriptor& syncFence) {
557     return validatedConvert(syncFence);
558 }
559 
convert(const std::vector<aidl_hal::Extension> & extension)560 GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension) {
561     return validatedConvert(extension);
562 }
563 
convert(const std::vector<aidl_hal::Memory> & memories)564 GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) {
565     return validatedConvert(memories);
566 }
567 
convert(const std::vector<aidl_hal::OutputShape> & outputShapes)568 GeneralResult<std::vector<OutputShape>> convert(
569         const std::vector<aidl_hal::OutputShape>& outputShapes) {
570     return validatedConvert(outputShapes);
571 }
572 
toUnsigned(const std::vector<int32_t> & vec)573 GeneralResult<std::vector<uint32_t>> toUnsigned(const std::vector<int32_t>& vec) {
574     if (!std::all_of(vec.begin(), vec.end(), [](int32_t v) { return v >= 0; })) {
575         return NN_ERROR() << "Negative value passed to conversion from signed to unsigned";
576     }
577     return std::vector<uint32_t>(vec.begin(), vec.end());
578 }
579 
580 }  // namespace android::nn
581 
582 namespace aidl::android::hardware::neuralnetworks::utils {
583 namespace {
584 
/// The value type held by the result of the `unvalidatedConvert` overload selected for the
/// canonical `Input`, with reference and cv-qualifiers stripped from what `.value()` returns.
template <typename Input>
using UnvalidatedConvertOutput = std::remove_cv_t<
        std::remove_reference_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>>;
588 
589 template <typename Type>
unvalidatedConvertVec(const std::vector<Type> & arguments)590 nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
591         const std::vector<Type>& arguments) {
592     std::vector<UnvalidatedConvertOutput<Type>> halObject;
593     halObject.reserve(arguments.size());
594     for (const auto& argument : arguments) {
595         halObject.push_back(NN_TRY(unvalidatedConvert(argument)));
596     }
597     return halObject;
598 }
599 
600 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)601 nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
602         const std::vector<Type>& arguments) {
603     return unvalidatedConvertVec(arguments);
604 }
605 
606 template <typename Type>
validatedConvert(const Type & canonical)607 nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
608     NN_TRY(compliantVersion(canonical));
609     return utils::unvalidatedConvert(canonical);
610 }
611 
612 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)613 nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
614         const std::vector<Type>& arguments) {
615     std::vector<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
616     for (size_t i = 0; i < arguments.size(); ++i) {
617         halObject[i] = NN_TRY(validatedConvert(arguments[i]));
618     }
619     return halObject;
620 }
621 
/// Converts a canonical Handle to an AIDL NativeHandle. Every fd is duplicated with nn::dupFd, so
/// the result owns its own descriptors; the ints are copied.
nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::Handle& handle) {
    common::NativeHandle aidlNativeHandle;
    auto& aidlFds = aidlNativeHandle.fds;
    const size_t fdCount = handle.fds.size();
    aidlFds.reserve(fdCount);
    for (size_t i = 0; i < fdCount; ++i) {
        auto duplicatedFd = NN_TRY(nn::dupFd(handle.fds[i].get()));
        aidlFds.emplace_back(duplicatedFd.release());
    }
    aidlNativeHandle.ints = handle.ints;
    return aidlNativeHandle;
}
632 
// Merges several callables into one overload set, e.g. for std::visit:
//   std::visit(overloaded{[](int) { ... }, [](float) { ... }}, someVariant);
template <typename... Callables>
struct overloaded : Callables... {
    using Callables::operator()...;
};
// Deduction guide so that `overloaded{lambdas...}` deduces Callables from its arguments.
template <typename... Callables>
overloaded(Callables...) -> overloaded<Callables...>;
640 
aidlHandleFromNativeHandle(const native_handle_t & nativeHandle)641 nn::GeneralResult<common::NativeHandle> aidlHandleFromNativeHandle(
642         const native_handle_t& nativeHandle) {
643     auto handle = ::android::dupToAidl(&nativeHandle);
644     if (!std::all_of(handle.fds.begin(), handle.fds.end(),
645                      [](const ndk::ScopedFileDescriptor& fd) { return fd.get() >= 0; })) {
646         return NN_ERROR() << "android::dupToAidl returned an invalid common::NativeHandle";
647     }
648     return handle;
649 }
650 
// Wraps an ashmem region in the AIDL Memory union. The fd is duplicated, so the returned Memory
// owns its own descriptor. Fails if the size does not fit into int64_t or if nn::dupFd fails.
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::Ashmem& memory) {
    // The range check is only compiled in when size_t can hold values above INT64_MAX.
    if constexpr (std::numeric_limits<size_t>::max() > std::numeric_limits<int64_t>::max()) {
        if (memory.size > std::numeric_limits<int64_t>::max()) {
            // The error builder is converted explicitly, as elsewhere in this file for
            // GeneralResult<Memory>.
            return (NN_ERROR()
                    << "Memory::Ashmem: size must be <= std::numeric_limits<int64_t>::max()")
                    .operator nn::GeneralResult<Memory>();
        }
    }

    auto duplicatedFd = NN_TRY(nn::dupFd(memory.fd));
    return Memory::make<Memory::Tag::ashmem>(common::Ashmem{
            .fd = ndk::ScopedFileDescriptor(duplicatedFd.release()),
            .size = static_cast<int64_t>(memory.size),
    });
}
669 
// Wraps an fd-backed memory region in the AIDL Memory union as a MappableFile. The fd is
// duplicated; size and offset must each fit into int64_t.
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::Fd& memory) {
    // Only compiled in when size_t can hold values above INT64_MAX.
    if constexpr (std::numeric_limits<size_t>::max() > std::numeric_limits<int64_t>::max()) {
        constexpr auto kInt64Max = std::numeric_limits<int64_t>::max();
        if (memory.size > kInt64Max) {
            return (NN_ERROR() << "Memory::Fd: size must be <= std::numeric_limits<int64_t>::max()")
                    .operator nn::GeneralResult<Memory>();
        }
        if (memory.offset > kInt64Max) {
            return (NN_ERROR()
                    << "Memory::Fd: offset must be <= std::numeric_limits<int64_t>::max()")
                    .operator nn::GeneralResult<Memory>();
        }
    }

    auto duplicatedFd = NN_TRY(nn::dupFd(memory.fd));
    return Memory::make<Memory::Tag::mappableFile>(common::MappableFile{
            .length = static_cast<int64_t>(memory.size),
            .prot = memory.prot,
            .fd = ndk::ScopedFileDescriptor(duplicatedFd.release()),
            .offset = static_cast<int64_t>(memory.offset),
    });
}
695 
// Converts an AHardwareBuffer-backed memory into the AIDL Memory union.
//
// The buffer's native handle is duplicated (via aidlHandleFromNativeHandle) and its description is
// queried with AHardwareBuffer_describe. Fails if the buffer has no native handle or if the handle
// cannot be duplicated.
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::HardwareBuffer& memory) {
    const native_handle_t* nativeHandle = AHardwareBuffer_getNativeHandle(memory.handle.get());
    if (nativeHandle == nullptr) {
        return (NN_ERROR() << "unvalidatedConvert failed because AHardwareBuffer_getNativeHandle "
                              "returned nullptr")
                .operator nn::GeneralResult<Memory>();
    }

    auto handle = NN_TRY(aidlHandleFromNativeHandle(*nativeHandle));

    // Value-initialize so no field is ever read uninitialized, regardless of which fields
    // AHardwareBuffer_describe fills in.
    AHardwareBuffer_Desc desc{};
    AHardwareBuffer_describe(memory.handle.get(), &desc);

    const auto description = graphics::common::HardwareBufferDescription{
            .width = static_cast<int32_t>(desc.width),
            .height = static_cast<int32_t>(desc.height),
            .layers = static_cast<int32_t>(desc.layers),
            .format = static_cast<graphics::common::PixelFormat>(desc.format),
            .usage = static_cast<graphics::common::BufferUsage>(desc.usage),
            .stride = static_cast<int32_t>(desc.stride),
    };

    // `description` is const, so it is copied deliberately. The previous std::move on a const
    // object was a silent copy that read as a move.
    auto hardwareBuffer = graphics::common::HardwareBuffer{
            .description = description,
            .handle = std::move(handle),
    };
    return Memory::make<Memory::Tag::hardwareBuffer>(std::move(hardwareBuffer));
}
725 
// Unknown memory has no AIDL representation, so this conversion always fails.
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::Unknown& /*memory*/) {
    return (NN_ERROR() << "Unable to convert Unknown memory type")
            .operator nn::GeneralResult<Memory>();
}
731 
732 }  // namespace
733 
// Copies the bytes of the cache token into a byte vector.
nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(const nn::CacheToken& cacheToken) {
    std::vector<uint8_t> bytes(cacheToken.begin(), cacheToken.end());
    return bytes;
}
737 
// Converts a buffer descriptor; fails if any dimension does not fit into int32_t.
nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc) {
    auto signedDimensions = NN_TRY(toSigned(bufferDesc.dimensions));
    return BufferDesc{.dimensions = std::move(signedDimensions)};
}
741 
// Converts a buffer role. Both indexes are narrowed to int32_t after checking they do not exceed
// INT32_MAX; the probability is copied unchanged.
nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole) {
    VERIFY_LE_INT32_MAX(bufferRole.modelIndex)
            << "BufferRole: modelIndex must be <= std::numeric_limits<int32_t>::max()";
    VERIFY_LE_INT32_MAX(bufferRole.ioIndex)
            << "BufferRole: ioIndex must be <= std::numeric_limits<int32_t>::max()";

    const auto modelIndex = static_cast<int32_t>(bufferRole.modelIndex);
    const auto ioIndex = static_cast<int32_t>(bufferRole.ioIndex);
    return BufferRole{
            .modelIndex = modelIndex,
            .ioIndex = ioIndex,
            .probability = bufferRole.probability,
    };
}
753 
// MeasureTiming is represented as a bool: true exactly when the canonical value is YES.
nn::GeneralResult<bool> unvalidatedConvert(const nn::MeasureTiming& measureTiming) {
    const bool shouldMeasure = (measureTiming == nn::MeasureTiming::YES);
    return shouldMeasure;
}
757 
// Converts a shared canonical handle into an AIDL NativeHandle (fds duplicated, ints copied).
//
// A null handle used to trip a CHECK and abort the process. It now yields an error, matching how
// the nn::SharedMemory overload treats nullptr.
nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::SharedHandle& sharedHandle) {
    if (sharedHandle == nullptr) {
        return NN_ERROR() << "Unable to convert nullptr handle";
    }
    return unvalidatedConvert(*sharedHandle);
}
762 
// Converts shared memory by dispatching on the alternative stored in `memory->handle` to the
// matching unvalidatedConvert overload. A null SharedMemory is rejected.
nn::GeneralResult<Memory> unvalidatedConvert(const nn::SharedMemory& memory) {
    if (memory != nullptr) {
        const auto convertAlternative = [](const auto& alternative) {
            return unvalidatedConvert(alternative);
        };
        return std::visit(convertAlternative, memory->handle);
    }
    return (NN_ERROR() << "Unable to convert nullptr memory")
            .operator nn::GeneralResult<Memory>();
}
771 
// Maps a canonical error status onto the AIDL enum. The listed statuses are cast directly
// (presumably the two enums share numeric values for them); any other status collapses to
// GENERAL_FAILURE.
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus) {
    switch (errorStatus) {
        case nn::ErrorStatus::NONE:
        case nn::ErrorStatus::DEVICE_UNAVAILABLE:
        case nn::ErrorStatus::GENERAL_FAILURE:
        case nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case nn::ErrorStatus::INVALID_ARGUMENT:
        case nn::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
        case nn::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
            break;
        default:
            return ErrorStatus::GENERAL_FAILURE;
    }
    return static_cast<ErrorStatus>(errorStatus);
}
788 
// Converts an output shape; fails if any dimension does not fit into int32_t.
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape) {
    auto signedDimensions = NN_TRY(toSigned(outputShape.dimensions));
    return OutputShape{
            .dimensions = std::move(signedDimensions),
            .isSufficient = outputShape.isSufficient,
    };
}
793 
// Converts the execution preference by a plain enum cast (assumes the canonical and AIDL
// enumerators share numeric values).
nn::GeneralResult<ExecutionPreference> unvalidatedConvert(
        const nn::ExecutionPreference& executionPreference) {
    const auto aidlPreference = static_cast<ExecutionPreference>(executionPreference);
    return aidlPreference;
}
798 
// Converts an operand type by enum cast. OEM types (OEM, TENSOR_OEM_BYTE) are rejected.
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    const bool isOemType =
            operandType == nn::OperandType::OEM || operandType == nn::OperandType::TENSOR_OEM_BYTE;
    if (!isOemType) {
        return static_cast<OperandType>(operandType);
    }
    return NN_ERROR() << "Unable to convert invalid OperandType " << operandType;
}
805 
// Converts an operand lifetime by a plain enum cast (assumes matching enumerator values).
nn::GeneralResult<OperandLifeTime> unvalidatedConvert(
        const nn::Operand::LifeTime& operandLifeTime) {
    const auto aidlLifeTime = static_cast<OperandLifeTime>(operandLifeTime);
    return aidlLifeTime;
}
810 
// Converts a data location. The pool index is narrowed to int32_t after a range check; offset and
// length are widened to int64_t.
nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location) {
    VERIFY_LE_INT32_MAX(location.poolIndex)
            << "DataLocation: pool index must be <= std::numeric_limits<int32_t>::max()";

    const auto poolIndex = static_cast<int32_t>(location.poolIndex);
    const auto offset = static_cast<int64_t>(location.offset);
    const auto length = static_cast<int64_t>(location.length);
    return DataLocation{.poolIndex = poolIndex, .offset = offset, .length = length};
}
820 
// Converts an operand's extra parameters. NoParams becomes an empty optional; per-channel
// quantization and extension parameters become the matching OperandExtraParams alternative.
nn::GeneralResult<std::optional<OperandExtraParams>> unvalidatedConvert(
        const nn::Operand::ExtraParams& extraParams) {
    const auto fromNoParams =
            [](const nn::Operand::NoParams&) -> nn::GeneralResult<std::optional<OperandExtraParams>> {
        return std::nullopt;
    };

    const auto fromChannelQuant = [](const nn::Operand::SymmPerChannelQuantParams& params)
            -> nn::GeneralResult<std::optional<OperandExtraParams>> {
        if (params.channelDim > std::numeric_limits<int32_t>::max()) {
            // Explicit conversion: the std::optional in the success type makes implicit
            // conversion of the error builder ambiguous.
            return (NN_ERROR() << "symmPerChannelQuantParams.channelDim must be <= "
                                  "std::numeric_limits<int32_t>::max(), received: "
                               << params.channelDim)
                    .operator nn::GeneralResult<std::optional<OperandExtraParams>>();
        }
        return OperandExtraParams::make<OperandExtraParams::Tag::channelQuant>(
                SymmPerChannelQuantParams{
                        .scales = params.scales,
                        .channelDim = static_cast<int32_t>(params.channelDim),
                });
    };

    const auto fromExtension = [](const nn::Operand::ExtensionParams& extensionParams)
            -> nn::GeneralResult<std::optional<OperandExtraParams>> {
        return OperandExtraParams::make<OperandExtraParams::Tag::extension>(extensionParams);
    };

    return std::visit(overloaded{fromNoParams, fromChannelQuant, fromExtension}, extraParams);
}
856 
// Converts an operand field by field. Fallible fields are converted in declaration order and the
// first failure is returned.
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    auto type = NN_TRY(unvalidatedConvert(operand.type));
    auto dimensions = NN_TRY(toSigned(operand.dimensions));
    auto lifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
    auto location = NN_TRY(unvalidatedConvert(operand.location));
    auto extraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
    return Operand{
            .type = type,
            .dimensions = std::move(dimensions),
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = lifetime,
            .location = std::move(location),
            .extraParams = std::move(extraParams),
    };
}
868 
// Converts an operation type by enum cast; OEM_OPERATION is rejected.
nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    const bool isOemOperation = operationType == nn::OperationType::OEM_OPERATION;
    if (!isOemOperation) {
        return static_cast<OperationType>(operationType);
    }
    return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION";
}
875 
// Converts an operation: its type plus signed input/output operand indexes.
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
    auto type = NN_TRY(unvalidatedConvert(operation.type));
    auto inputs = NN_TRY(toSigned(operation.inputs));
    auto outputs = NN_TRY(toSigned(operation.outputs));
    return Operation{
            .type = type,
            .inputs = std::move(inputs),
            .outputs = std::move(outputs),
    };
}
883 
// Converts a subgraph: operands, operations, and signed input/output indexes, in that order.
nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph) {
    auto operands = NN_TRY(unvalidatedConvert(subgraph.operands));
    auto operations = NN_TRY(unvalidatedConvert(subgraph.operations));
    auto inputIndexes = NN_TRY(toSigned(subgraph.inputIndexes));
    auto outputIndexes = NN_TRY(toSigned(subgraph.outputIndexes));
    return Subgraph{
            .operands = std::move(operands),
            .operations = std::move(operations),
            .inputIndexes = std::move(inputIndexes),
            .outputIndexes = std::move(outputIndexes),
    };
}
892 
// Copies the model's constant operand bytes into a plain byte vector.
nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
        const nn::Model::OperandValues& operandValues) {
    const auto* const first = operandValues.data();
    const auto* const last = first + operandValues.size();
    return std::vector<uint8_t>(first, last);
}
897 
// Copies an extension's name and prefix into the AIDL struct.
nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
        const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix) {
    ExtensionNameAndPrefix converted;
    converted.name = extensionNameToPrefix.name;
    converted.prefix = extensionNameToPrefix.prefix;
    return converted;
}
905 
// Converts a full model. Fallible members are converted in declaration order; the relaxed-float
// flag is copied as-is.
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    auto mainSubgraph = NN_TRY(unvalidatedConvert(model.main));
    auto referencedSubgraphs = NN_TRY(unvalidatedConvert(model.referenced));
    auto operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
    auto pools = NN_TRY(unvalidatedConvert(model.pools));
    auto extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));
    return Model{
            .main = std::move(mainSubgraph),
            .referenced = std::move(referencedSubgraphs),
            .operandValues = std::move(operandValues),
            .pools = std::move(pools),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = std::move(extensionNameToPrefix),
    };
}
916 
// Converts a priority by a plain enum cast (assumes matching enumerator values).
nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority) {
    const auto aidlPriority = static_cast<Priority>(priority);
    return aidlPriority;
}
920 
// Converts a request: input arguments, output arguments, then memory pools.
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
    auto inputs = NN_TRY(unvalidatedConvert(request.inputs));
    auto outputs = NN_TRY(unvalidatedConvert(request.outputs));
    auto pools = NN_TRY(unvalidatedConvert(request.pools));
    return Request{
            .inputs = std::move(inputs),
            .outputs = std::move(outputs),
            .pools = std::move(pools),
    };
}
928 
// Converts a request argument. POINTER lifetimes cannot be expressed in AIDL and are rejected
// with INVALID_ARGUMENT; NO_VALUE is encoded through the hasNoValue flag.
nn::GeneralResult<RequestArgument> unvalidatedConvert(
        const nn::Request::Argument& requestArgument) {
    using LifeTime = nn::Request::Argument::LifeTime;
    const auto lifetime = requestArgument.lifetime;
    if (lifetime == LifeTime::POINTER) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
    }
    auto location = NN_TRY(unvalidatedConvert(requestArgument.location));
    auto dimensions = NN_TRY(toSigned(requestArgument.dimensions));
    return RequestArgument{
            .hasNoValue = (lifetime == LifeTime::NO_VALUE),
            .location = std::move(location),
            .dimensions = std::move(dimensions),
    };
}
942 
// Converts a request memory pool. Shared memory becomes a `pool` alternative and a memory-domain
// token becomes a `token` alternative (its underlying value). IBuffer-backed pools cannot be
// converted and yield GENERAL_FAILURE.
nn::GeneralResult<RequestMemoryPool> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool) {
    const auto fromMemory =
            [](const nn::SharedMemory& memory) -> nn::GeneralResult<RequestMemoryPool> {
        return RequestMemoryPool::make<RequestMemoryPool::Tag::pool>(
                NN_TRY(unvalidatedConvert(memory)));
    };
    const auto fromToken = [](const nn::Request::MemoryDomainToken& token)
            -> nn::GeneralResult<RequestMemoryPool> {
        return RequestMemoryPool::make<RequestMemoryPool::Tag::token>(underlyingType(token));
    };
    const auto fromBuffer =
            [](const nn::SharedBuffer& /*buffer*/) -> nn::GeneralResult<RequestMemoryPool> {
        return (NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
                << "Unable to make memory pool from IBuffer")
                .operator nn::GeneralResult<RequestMemoryPool>();
    };
    return std::visit(overloaded{fromMemory, fromToken, fromBuffer}, memoryPool);
}
964 
// Converts timing information; each optional duration becomes nanoseconds or kNoTiming.
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
    const auto onDeviceNs = NN_TRY(unvalidatedConvert(timing.timeOnDevice));
    const auto inDriverNs = NN_TRY(unvalidatedConvert(timing.timeInDriver));
    return Timing{.timeOnDeviceNs = onDeviceNs, .timeInDriverNs = inDriverNs};
}
971 
// Converts a duration into a nanosecond count, clamped to INT64_MAX. Negative durations are
// rejected.
nn::GeneralResult<int64_t> unvalidatedConvert(const nn::Duration& duration) {
    if (duration >= nn::Duration::zero()) {
        constexpr std::chrono::nanoseconds::rep kIntMax = std::numeric_limits<int64_t>::max();
        const auto nanoseconds = duration.count();
        return static_cast<int64_t>(nanoseconds > kIntMax ? kIntMax : nanoseconds);
    }
    return NN_ERROR() << "Unable to convert invalid (negative) duration";
}
980 
// Converts an optional duration; an absent value is encoded as kNoTiming.
nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalDuration& optionalDuration) {
    if (optionalDuration.has_value()) {
        return unvalidatedConvert(*optionalDuration);
    }
    return kNoTiming;
}
987 
// Converts an optional time point into its duration since the clock's epoch; an absent value is
// encoded as kNoTiming.
nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalTimePoint& optionalTimePoint) {
    if (optionalTimePoint.has_value()) {
        return unvalidatedConvert(optionalTimePoint->time_since_epoch());
    }
    return kNoTiming;
}
994 
// Converts a sync fence into an owned, duplicated file descriptor.
nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvert(const nn::SyncFence& syncFence) {
    const auto fenceFd = syncFence.getFd();
    auto ownedFd = NN_TRY(nn::dupFd(fenceFd));
    return ndk::ScopedFileDescriptor(ownedFd.release());
}
999 
// Converts a cache handle into the single file descriptor the AIDL interface expects.
//
// The handle must be non-null, contain no ints, and contain exactly one fd; otherwise an error is
// returned. The fd is duplicated, so the returned descriptor is owned independently of `handle`.
nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvertCache(
        const nn::SharedHandle& handle) {
    if (handle == nullptr) {
        return NN_ERROR() << "Cache handle must not be null";
    }
    // Previously these two errors were built but never returned. That let a handle with zero fds
    // reach fds.front() on an empty vector (undefined behavior).
    if (handle->ints.size() != 0) {
        return NN_ERROR() << "Cache handle must not contain ints";
    }
    if (handle->fds.size() != 1) {
        return NN_ERROR() << "Cache handle must contain exactly one fd but contains "
                          << handle->fds.size();
    }
    auto duplicatedFd = NN_TRY(nn::dupFd(handle->fds.front().get()));
    return ndk::ScopedFileDescriptor(duplicatedFd.release());
}
1012 
convert(const nn::CacheToken & cacheToken)1013 nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken) {
1014     return validatedConvert(cacheToken);
1015 }
1016 
convert(const nn::BufferDesc & bufferDesc)1017 nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
1018     return validatedConvert(bufferDesc);
1019 }
1020 
convert(const nn::MeasureTiming & measureTiming)1021 nn::GeneralResult<bool> convert(const nn::MeasureTiming& measureTiming) {
1022     return validatedConvert(measureTiming);
1023 }
1024 
convert(const nn::SharedMemory & memory)1025 nn::GeneralResult<Memory> convert(const nn::SharedMemory& memory) {
1026     return validatedConvert(memory);
1027 }
1028 
convert(const nn::ErrorStatus & errorStatus)1029 nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus) {
1030     return validatedConvert(errorStatus);
1031 }
1032 
convert(const nn::ExecutionPreference & executionPreference)1033 nn::GeneralResult<ExecutionPreference> convert(const nn::ExecutionPreference& executionPreference) {
1034     return validatedConvert(executionPreference);
1035 }
1036 
convert(const nn::Model & model)1037 nn::GeneralResult<Model> convert(const nn::Model& model) {
1038     return validatedConvert(model);
1039 }
1040 
convert(const nn::Priority & priority)1041 nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
1042     return validatedConvert(priority);
1043 }
1044 
convert(const nn::Request & request)1045 nn::GeneralResult<Request> convert(const nn::Request& request) {
1046     return validatedConvert(request);
1047 }
1048 
convert(const nn::Timing & timing)1049 nn::GeneralResult<Timing> convert(const nn::Timing& timing) {
1050     return validatedConvert(timing);
1051 }
1052 
convert(const nn::OptionalDuration & optionalDuration)1053 nn::GeneralResult<int64_t> convert(const nn::OptionalDuration& optionalDuration) {
1054     return validatedConvert(optionalDuration);
1055 }
1056 
// Version-checked conversion of an optional time point to nanoseconds since epoch or kNoTiming
// (see validatedConvert). The parameter was previously misnamed `outputShapes`.
nn::GeneralResult<int64_t> convert(const nn::OptionalTimePoint& optionalTimePoint) {
    return validatedConvert(optionalTimePoint);
}
1060 
convert(const std::vector<nn::BufferRole> & bufferRoles)1061 nn::GeneralResult<std::vector<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles) {
1062     return validatedConvert(bufferRoles);
1063 }
1064 
convert(const std::vector<nn::OutputShape> & outputShapes)1065 nn::GeneralResult<std::vector<OutputShape>> convert(
1066         const std::vector<nn::OutputShape>& outputShapes) {
1067     return validatedConvert(outputShapes);
1068 }
1069 
convert(const std::vector<nn::SharedHandle> & cacheHandles)1070 nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
1071         const std::vector<nn::SharedHandle>& cacheHandles) {
1072     const auto version = NN_TRY(hal::utils::makeGeneralFailure(nn::validate(cacheHandles)));
1073     if (version > kVersion) {
1074         return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
1075     }
1076     std::vector<ndk::ScopedFileDescriptor> cacheFds;
1077     cacheFds.reserve(cacheHandles.size());
1078     for (const auto& cacheHandle : cacheHandles) {
1079         cacheFds.push_back(NN_TRY(unvalidatedConvertCache(cacheHandle)));
1080     }
1081     return cacheFds;
1082 }
1083 
convert(const std::vector<nn::SyncFence> & syncFences)1084 nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
1085         const std::vector<nn::SyncFence>& syncFences) {
1086     return validatedConvert(syncFences);
1087 }
1088 
toSigned(const std::vector<uint32_t> & vec)1089 nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec) {
1090     if (!std::all_of(vec.begin(), vec.end(),
1091                      [](uint32_t v) { return v <= std::numeric_limits<int32_t>::max(); })) {
1092         return NN_ERROR() << "Vector contains a value that doesn't fit into int32_t.";
1093     }
1094     return std::vector<int32_t>(vec.begin(), vec.end());
1095 }
1096 
1097 }  // namespace aidl::android::hardware::neuralnetworks::utils
1098