1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "Conversions.h"
18
19 #include <aidl/android/hardware/common/Ashmem.h>
20 #include <aidl/android/hardware/common/MappableFile.h>
21 #include <aidl/android/hardware/common/NativeHandle.h>
22 #include <aidl/android/hardware/graphics/common/HardwareBuffer.h>
23 #include <aidlcommonsupport/NativeHandle.h>
24 #include <android-base/logging.h>
25 #include <android-base/mapped_file.h>
26 #include <android-base/unique_fd.h>
27 #include <android/binder_auto_utils.h>
28 #include <android/hardware_buffer.h>
29 #include <cutils/native_handle.h>
30 #include <nnapi/OperandTypes.h>
31 #include <nnapi/OperationTypes.h>
32 #include <nnapi/Result.h>
33 #include <nnapi/SharedMemory.h>
34 #include <nnapi/TypeUtils.h>
35 #include <nnapi/Types.h>
36 #include <nnapi/Validation.h>
37 #include <nnapi/hal/CommonUtils.h>
38 #include <nnapi/hal/HandleError.h>
39 #include <vndk/hardware_buffer.h>
40
41 #include <algorithm>
42 #include <chrono>
43 #include <functional>
44 #include <iterator>
45 #include <limits>
46 #include <type_traits>
47 #include <utility>
48
49 #include "Utils.h"
50
// Returns NN_ERROR() from the enclosing function when `value` is negative. It is written as a
// `while (...) return` statement so callers can stream a message onto the error, e.g.
//     VERIFY_NON_NEGATIVE(index) << "index must not be negative";
// The argument is parenthesized so that expressions containing lower-precedence operators
// (for example `a & b` or `a ? b : c`) are compared as a whole.
#define VERIFY_NON_NEGATIVE(value) \
    while (UNLIKELY((value) < 0)) return NN_ERROR()

// Returns NN_ERROR() from the enclosing function when `value` is greater than
// std::numeric_limits<int32_t>::max(). It supports the same `<< "message"` streaming as
// VERIFY_NON_NEGATIVE.
#define VERIFY_LE_INT32_MAX(value) \
    while (UNLIKELY((value) > std::numeric_limits<int32_t>::max())) return NN_ERROR()
56
namespace {

/// Returns the integral value of an enumerator, typed as the enum's underlying type.
template <typename Enum>
constexpr std::underlying_type_t<Enum> underlyingType(Enum enumerator) {
    using Underlying = std::underlying_type_t<Enum>;
    return static_cast<Underlying>(enumerator);
}

// Sentinel timing value: HAL timings equal to this are converted to an empty OptionalDuration.
constexpr int64_t kNoTiming = -1;

}  // namespace
66
67 namespace android::nn {
68 namespace {
69
70 using ::aidl::android::hardware::common::NativeHandle;
71
72 template <typename Input>
73 using UnvalidatedConvertOutput =
74 std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
75
76 template <typename Type>
unvalidatedConvertVec(const std::vector<Type> & arguments)77 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
78 const std::vector<Type>& arguments) {
79 std::vector<UnvalidatedConvertOutput<Type>> canonical;
80 canonical.reserve(arguments.size());
81 for (const auto& argument : arguments) {
82 canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
83 }
84 return canonical;
85 }
86
87 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)88 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
89 const std::vector<Type>& arguments) {
90 return unvalidatedConvertVec(arguments);
91 }
92
93 template <typename Type>
validatedConvert(const Type & halObject)94 GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
95 auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
96 NN_TRY(aidl_hal::utils::compliantVersion(canonical));
97 return canonical;
98 }
99
100 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)101 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
102 const std::vector<Type>& arguments) {
103 std::vector<UnvalidatedConvertOutput<Type>> canonical;
104 canonical.reserve(arguments.size());
105 for (const auto& argument : arguments) {
106 canonical.push_back(NN_TRY(validatedConvert(argument)));
107 }
108 return canonical;
109 }
110
unvalidatedConvertHelper(const NativeHandle & aidlNativeHandle)111 GeneralResult<Handle> unvalidatedConvertHelper(const NativeHandle& aidlNativeHandle) {
112 std::vector<base::unique_fd> fds;
113 fds.reserve(aidlNativeHandle.fds.size());
114 for (const auto& fd : aidlNativeHandle.fds) {
115 auto duplicatedFd = NN_TRY(dupFd(fd.get()));
116 fds.emplace_back(duplicatedFd.release());
117 }
118
119 return Handle{.fds = std::move(fds), .ints = aidlNativeHandle.ints};
120 }
121
122 struct NativeHandleDeleter {
operator ()android::nn::__anon497330ac0211::NativeHandleDeleter123 void operator()(native_handle_t* handle) const {
124 if (handle) {
125 native_handle_close(handle);
126 native_handle_delete(handle);
127 }
128 }
129 };
130
131 using UniqueNativeHandle = std::unique_ptr<native_handle_t, NativeHandleDeleter>;
132
nativeHandleFromAidlHandle(const NativeHandle & handle)133 GeneralResult<UniqueNativeHandle> nativeHandleFromAidlHandle(const NativeHandle& handle) {
134 auto nativeHandle = UniqueNativeHandle(dupFromAidl(handle));
135 if (nativeHandle.get() == nullptr) {
136 return NN_ERROR() << "android::dupFromAidl failed to convert the common::NativeHandle to a "
137 "native_handle_t";
138 }
139 if (!std::all_of(nativeHandle->data + 0, nativeHandle->data + nativeHandle->numFds,
140 [](int fd) { return fd >= 0; })) {
141 return NN_ERROR() << "android::dupFromAidl returned an invalid native_handle_t";
142 }
143 return nativeHandle;
144 }
145
146 } // anonymous namespace
147
unvalidatedConvert(const aidl_hal::OperandType & operandType)148 GeneralResult<OperandType> unvalidatedConvert(const aidl_hal::OperandType& operandType) {
149 VERIFY_NON_NEGATIVE(underlyingType(operandType)) << "Negative operand types are not allowed.";
150 const auto canonical = static_cast<OperandType>(operandType);
151 if (canonical == OperandType::OEM || canonical == OperandType::TENSOR_OEM_BYTE) {
152 return NN_ERROR() << "Unable to convert invalid OperandType " << canonical;
153 }
154 return canonical;
155 }
156
unvalidatedConvert(const aidl_hal::OperationType & operationType)157 GeneralResult<OperationType> unvalidatedConvert(const aidl_hal::OperationType& operationType) {
158 VERIFY_NON_NEGATIVE(underlyingType(operationType))
159 << "Negative operation types are not allowed.";
160 const auto canonical = static_cast<OperationType>(operationType);
161 if (canonical == OperationType::OEM_OPERATION) {
162 return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION";
163 }
164 return canonical;
165 }
166
unvalidatedConvert(const aidl_hal::DeviceType & deviceType)167 GeneralResult<DeviceType> unvalidatedConvert(const aidl_hal::DeviceType& deviceType) {
168 return static_cast<DeviceType>(deviceType);
169 }
170
unvalidatedConvert(const aidl_hal::Priority & priority)171 GeneralResult<Priority> unvalidatedConvert(const aidl_hal::Priority& priority) {
172 return static_cast<Priority>(priority);
173 }
174
unvalidatedConvert(const aidl_hal::Capabilities & capabilities)175 GeneralResult<Capabilities> unvalidatedConvert(const aidl_hal::Capabilities& capabilities) {
176 const bool validOperandTypes = std::all_of(
177 capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
178 [](const aidl_hal::OperandPerformance& operandPerformance) {
179 return validatedConvert(operandPerformance.type).has_value();
180 });
181 if (!validOperandTypes) {
182 return NN_ERROR() << "Invalid OperandType when unvalidatedConverting OperandPerformance in "
183 "Capabilities";
184 }
185
186 auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
187 auto table = NN_TRY(hal::utils::makeGeneralFailure(
188 Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)),
189 nn::ErrorStatus::GENERAL_FAILURE));
190
191 return Capabilities{
192 .relaxedFloat32toFloat16PerformanceScalar = NN_TRY(
193 unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar)),
194 .relaxedFloat32toFloat16PerformanceTensor = NN_TRY(
195 unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor)),
196 .operandPerformance = std::move(table),
197 .ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance)),
198 .whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance)),
199 };
200 }
201
unvalidatedConvert(const aidl_hal::OperandPerformance & operandPerformance)202 GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
203 const aidl_hal::OperandPerformance& operandPerformance) {
204 return Capabilities::OperandPerformance{
205 .type = NN_TRY(unvalidatedConvert(operandPerformance.type)),
206 .info = NN_TRY(unvalidatedConvert(operandPerformance.info)),
207 };
208 }
209
unvalidatedConvert(const aidl_hal::PerformanceInfo & performanceInfo)210 GeneralResult<Capabilities::PerformanceInfo> unvalidatedConvert(
211 const aidl_hal::PerformanceInfo& performanceInfo) {
212 return Capabilities::PerformanceInfo{
213 .execTime = performanceInfo.execTime,
214 .powerUsage = performanceInfo.powerUsage,
215 };
216 }
217
unvalidatedConvert(const aidl_hal::DataLocation & location)218 GeneralResult<DataLocation> unvalidatedConvert(const aidl_hal::DataLocation& location) {
219 VERIFY_NON_NEGATIVE(location.poolIndex) << "DataLocation: pool index must not be negative";
220 VERIFY_NON_NEGATIVE(location.offset) << "DataLocation: offset must not be negative";
221 VERIFY_NON_NEGATIVE(location.length) << "DataLocation: length must not be negative";
222 VERIFY_NON_NEGATIVE(location.padding) << "DataLocation: padding must not be negative";
223 if (location.offset > std::numeric_limits<uint32_t>::max()) {
224 return NN_ERROR() << "DataLocation: offset must be <= std::numeric_limits<uint32_t>::max()";
225 }
226 if (location.length > std::numeric_limits<uint32_t>::max()) {
227 return NN_ERROR() << "DataLocation: length must be <= std::numeric_limits<uint32_t>::max()";
228 }
229 if (location.padding > std::numeric_limits<uint32_t>::max()) {
230 return NN_ERROR()
231 << "DataLocation: padding must be <= std::numeric_limits<uint32_t>::max()";
232 }
233 return DataLocation{
234 .poolIndex = static_cast<uint32_t>(location.poolIndex),
235 .offset = static_cast<uint32_t>(location.offset),
236 .length = static_cast<uint32_t>(location.length),
237 .padding = static_cast<uint32_t>(location.padding),
238 };
239 }
240
unvalidatedConvert(const aidl_hal::Operation & operation)241 GeneralResult<Operation> unvalidatedConvert(const aidl_hal::Operation& operation) {
242 return Operation{
243 .type = NN_TRY(unvalidatedConvert(operation.type)),
244 .inputs = NN_TRY(toUnsigned(operation.inputs)),
245 .outputs = NN_TRY(toUnsigned(operation.outputs)),
246 };
247 }
248
unvalidatedConvert(const aidl_hal::OperandLifeTime & operandLifeTime)249 GeneralResult<Operand::LifeTime> unvalidatedConvert(
250 const aidl_hal::OperandLifeTime& operandLifeTime) {
251 return static_cast<Operand::LifeTime>(operandLifeTime);
252 }
253
unvalidatedConvert(const aidl_hal::Operand & operand)254 GeneralResult<Operand> unvalidatedConvert(const aidl_hal::Operand& operand) {
255 return Operand{
256 .type = NN_TRY(unvalidatedConvert(operand.type)),
257 .dimensions = NN_TRY(toUnsigned(operand.dimensions)),
258 .scale = operand.scale,
259 .zeroPoint = operand.zeroPoint,
260 .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
261 .location = NN_TRY(unvalidatedConvert(operand.location)),
262 .extraParams = NN_TRY(unvalidatedConvert(operand.extraParams)),
263 };
264 }
265
unvalidatedConvert(const std::optional<aidl_hal::OperandExtraParams> & optionalExtraParams)266 GeneralResult<Operand::ExtraParams> unvalidatedConvert(
267 const std::optional<aidl_hal::OperandExtraParams>& optionalExtraParams) {
268 if (!optionalExtraParams.has_value()) {
269 return Operand::NoParams{};
270 }
271 const auto& extraParams = optionalExtraParams.value();
272 using Tag = aidl_hal::OperandExtraParams::Tag;
273 switch (extraParams.getTag()) {
274 case Tag::channelQuant:
275 return unvalidatedConvert(extraParams.get<Tag::channelQuant>());
276 case Tag::extension:
277 return extraParams.get<Tag::extension>();
278 }
279 return NN_ERROR() << "Unrecognized Operand::ExtraParams tag: "
280 << underlyingType(extraParams.getTag());
281 }
282
unvalidatedConvert(const aidl_hal::SymmPerChannelQuantParams & symmPerChannelQuantParams)283 GeneralResult<Operand::SymmPerChannelQuantParams> unvalidatedConvert(
284 const aidl_hal::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
285 VERIFY_NON_NEGATIVE(symmPerChannelQuantParams.channelDim)
286 << "Per-channel quantization channel dimension must not be negative.";
287 return Operand::SymmPerChannelQuantParams{
288 .scales = symmPerChannelQuantParams.scales,
289 .channelDim = static_cast<uint32_t>(symmPerChannelQuantParams.channelDim),
290 };
291 }
292
unvalidatedConvert(const aidl_hal::Model & model)293 GeneralResult<Model> unvalidatedConvert(const aidl_hal::Model& model) {
294 return Model{
295 .main = NN_TRY(unvalidatedConvert(model.main)),
296 .referenced = NN_TRY(unvalidatedConvert(model.referenced)),
297 .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
298 .pools = NN_TRY(unvalidatedConvert(model.pools)),
299 .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
300 .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
301 };
302 }
303
unvalidatedConvert(const aidl_hal::Subgraph & subgraph)304 GeneralResult<Model::Subgraph> unvalidatedConvert(const aidl_hal::Subgraph& subgraph) {
305 return Model::Subgraph{
306 .operands = NN_TRY(unvalidatedConvert(subgraph.operands)),
307 .operations = NN_TRY(unvalidatedConvert(subgraph.operations)),
308 .inputIndexes = NN_TRY(toUnsigned(subgraph.inputIndexes)),
309 .outputIndexes = NN_TRY(toUnsigned(subgraph.outputIndexes)),
310 };
311 }
312
unvalidatedConvert(const aidl_hal::ExtensionNameAndPrefix & extensionNameAndPrefix)313 GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
314 const aidl_hal::ExtensionNameAndPrefix& extensionNameAndPrefix) {
315 return Model::ExtensionNameAndPrefix{
316 .name = extensionNameAndPrefix.name,
317 .prefix = extensionNameAndPrefix.prefix,
318 };
319 }
320
unvalidatedConvert(const aidl_hal::Extension & extension)321 GeneralResult<Extension> unvalidatedConvert(const aidl_hal::Extension& extension) {
322 return Extension{
323 .name = extension.name,
324 .operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes)),
325 };
326 }
327
unvalidatedConvert(const aidl_hal::ExtensionOperandTypeInformation & operandTypeInformation)328 GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
329 const aidl_hal::ExtensionOperandTypeInformation& operandTypeInformation) {
330 VERIFY_NON_NEGATIVE(operandTypeInformation.byteSize)
331 << "Extension operand type byte size must not be negative";
332 return Extension::OperandTypeInformation{
333 .type = operandTypeInformation.type,
334 .isTensor = operandTypeInformation.isTensor,
335 .byteSize = static_cast<uint32_t>(operandTypeInformation.byteSize),
336 };
337 }
338
unvalidatedConvert(const aidl_hal::OutputShape & outputShape)339 GeneralResult<OutputShape> unvalidatedConvert(const aidl_hal::OutputShape& outputShape) {
340 return OutputShape{
341 .dimensions = NN_TRY(toUnsigned(outputShape.dimensions)),
342 .isSufficient = outputShape.isSufficient,
343 };
344 }
345
unvalidatedConvert(bool measureTiming)346 GeneralResult<MeasureTiming> unvalidatedConvert(bool measureTiming) {
347 return measureTiming ? MeasureTiming::YES : MeasureTiming::NO;
348 }
349
unvalidatedConvert(const aidl_hal::Memory & memory)350 GeneralResult<SharedMemory> unvalidatedConvert(const aidl_hal::Memory& memory) {
351 using Tag = aidl_hal::Memory::Tag;
352 switch (memory.getTag()) {
353 case Tag::ashmem: {
354 const auto& ashmem = memory.get<Tag::ashmem>();
355 VERIFY_NON_NEGATIVE(ashmem.size) << "Memory size must not be negative";
356 if (ashmem.size > std::numeric_limits<size_t>::max()) {
357 return NN_ERROR() << "Memory: size must be <= std::numeric_limits<size_t>::max()";
358 }
359
360 auto handle = Memory::Ashmem{
361 .fd = NN_TRY(dupFd(ashmem.fd.get())),
362 .size = static_cast<size_t>(ashmem.size),
363 };
364 return std::make_shared<const Memory>(Memory{.handle = std::move(handle)});
365 }
366 case Tag::mappableFile: {
367 const auto& mappableFile = memory.get<Tag::mappableFile>();
368 VERIFY_NON_NEGATIVE(mappableFile.length) << "Memory size must not be negative";
369 VERIFY_NON_NEGATIVE(mappableFile.offset) << "Memory offset must not be negative";
370 if (mappableFile.length > std::numeric_limits<size_t>::max()) {
371 return NN_ERROR() << "Memory: size must be <= std::numeric_limits<size_t>::max()";
372 }
373 if (mappableFile.offset > std::numeric_limits<size_t>::max()) {
374 return NN_ERROR() << "Memory: offset must be <= std::numeric_limits<size_t>::max()";
375 }
376
377 const size_t size = static_cast<size_t>(mappableFile.length);
378 const int prot = mappableFile.prot;
379 const int fd = mappableFile.fd.get();
380 const size_t offset = static_cast<size_t>(mappableFile.offset);
381
382 return createSharedMemoryFromFd(size, prot, fd, offset);
383 }
384 case Tag::hardwareBuffer: {
385 const auto& hardwareBuffer = memory.get<Tag::hardwareBuffer>();
386
387 const UniqueNativeHandle handle =
388 NN_TRY(nativeHandleFromAidlHandle(hardwareBuffer.handle));
389 const native_handle_t* nativeHandle = handle.get();
390
391 const AHardwareBuffer_Desc desc{
392 .width = static_cast<uint32_t>(hardwareBuffer.description.width),
393 .height = static_cast<uint32_t>(hardwareBuffer.description.height),
394 .layers = static_cast<uint32_t>(hardwareBuffer.description.layers),
395 .format = static_cast<uint32_t>(hardwareBuffer.description.format),
396 .usage = static_cast<uint64_t>(hardwareBuffer.description.usage),
397 .stride = static_cast<uint32_t>(hardwareBuffer.description.stride),
398 };
399 AHardwareBuffer* ahwb = nullptr;
400 const status_t status = AHardwareBuffer_createFromHandle(
401 &desc, nativeHandle, AHARDWAREBUFFER_CREATE_FROM_HANDLE_METHOD_CLONE, &ahwb);
402 if (status != NO_ERROR) {
403 return NN_ERROR() << "createFromHandle failed";
404 }
405
406 return createSharedMemoryFromAHWB(ahwb, /*takeOwnership=*/true);
407 }
408 }
409 return NN_ERROR() << "Unrecognized Memory::Tag: " << memory.getTag();
410 }
411
unvalidatedConvert(const aidl_hal::Timing & timing)412 GeneralResult<Timing> unvalidatedConvert(const aidl_hal::Timing& timing) {
413 if (timing.timeInDriverNs < -1) {
414 return NN_ERROR() << "Timing: timeInDriverNs must not be less than -1";
415 }
416 if (timing.timeOnDeviceNs < -1) {
417 return NN_ERROR() << "Timing: timeOnDeviceNs must not be less than -1";
418 }
419 constexpr auto convertTiming = [](int64_t halTiming) -> OptionalDuration {
420 if (halTiming == kNoTiming) {
421 return {};
422 }
423 return nn::Duration(static_cast<uint64_t>(halTiming));
424 };
425 return Timing{.timeOnDevice = convertTiming(timing.timeOnDeviceNs),
426 .timeInDriver = convertTiming(timing.timeInDriverNs)};
427 }
428
unvalidatedConvert(const std::vector<uint8_t> & operandValues)429 GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t>& operandValues) {
430 return Model::OperandValues(operandValues.data(), operandValues.size());
431 }
432
unvalidatedConvert(const aidl_hal::BufferDesc & bufferDesc)433 GeneralResult<BufferDesc> unvalidatedConvert(const aidl_hal::BufferDesc& bufferDesc) {
434 return BufferDesc{.dimensions = NN_TRY(toUnsigned(bufferDesc.dimensions))};
435 }
436
unvalidatedConvert(const aidl_hal::BufferRole & bufferRole)437 GeneralResult<BufferRole> unvalidatedConvert(const aidl_hal::BufferRole& bufferRole) {
438 VERIFY_NON_NEGATIVE(bufferRole.modelIndex) << "BufferRole: modelIndex must not be negative";
439 VERIFY_NON_NEGATIVE(bufferRole.ioIndex) << "BufferRole: ioIndex must not be negative";
440 return BufferRole{
441 .modelIndex = static_cast<uint32_t>(bufferRole.modelIndex),
442 .ioIndex = static_cast<uint32_t>(bufferRole.ioIndex),
443 .probability = bufferRole.probability,
444 };
445 }
446
unvalidatedConvert(const aidl_hal::Request & request)447 GeneralResult<Request> unvalidatedConvert(const aidl_hal::Request& request) {
448 return Request{
449 .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
450 .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
451 .pools = NN_TRY(unvalidatedConvert(request.pools)),
452 };
453 }
454
unvalidatedConvert(const aidl_hal::RequestArgument & argument)455 GeneralResult<Request::Argument> unvalidatedConvert(const aidl_hal::RequestArgument& argument) {
456 const auto lifetime = argument.hasNoValue ? Request::Argument::LifeTime::NO_VALUE
457 : Request::Argument::LifeTime::POOL;
458 return Request::Argument{
459 .lifetime = lifetime,
460 .location = NN_TRY(unvalidatedConvert(argument.location)),
461 .dimensions = NN_TRY(toUnsigned(argument.dimensions)),
462 };
463 }
464
unvalidatedConvert(const aidl_hal::RequestMemoryPool & memoryPool)465 GeneralResult<Request::MemoryPool> unvalidatedConvert(
466 const aidl_hal::RequestMemoryPool& memoryPool) {
467 using Tag = aidl_hal::RequestMemoryPool::Tag;
468 switch (memoryPool.getTag()) {
469 case Tag::pool:
470 return unvalidatedConvert(memoryPool.get<Tag::pool>());
471 case Tag::token: {
472 const auto token = memoryPool.get<Tag::token>();
473 VERIFY_NON_NEGATIVE(token) << "Memory pool token must not be negative";
474 return static_cast<Request::MemoryDomainToken>(token);
475 }
476 }
477 return NN_ERROR() << "Invalid Request::MemoryPool tag " << underlyingType(memoryPool.getTag());
478 }
479
unvalidatedConvert(const aidl_hal::ErrorStatus & status)480 GeneralResult<ErrorStatus> unvalidatedConvert(const aidl_hal::ErrorStatus& status) {
481 switch (status) {
482 case aidl_hal::ErrorStatus::NONE:
483 case aidl_hal::ErrorStatus::DEVICE_UNAVAILABLE:
484 case aidl_hal::ErrorStatus::GENERAL_FAILURE:
485 case aidl_hal::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
486 case aidl_hal::ErrorStatus::INVALID_ARGUMENT:
487 case aidl_hal::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
488 case aidl_hal::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
489 case aidl_hal::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
490 case aidl_hal::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
491 return static_cast<ErrorStatus>(status);
492 }
493 return NN_ERROR() << "Invalid ErrorStatus " << underlyingType(status);
494 }
495
unvalidatedConvert(const aidl_hal::ExecutionPreference & executionPreference)496 GeneralResult<ExecutionPreference> unvalidatedConvert(
497 const aidl_hal::ExecutionPreference& executionPreference) {
498 return static_cast<ExecutionPreference>(executionPreference);
499 }
500
unvalidatedConvert(const NativeHandle & aidlNativeHandle)501 GeneralResult<SharedHandle> unvalidatedConvert(const NativeHandle& aidlNativeHandle) {
502 return std::make_shared<const Handle>(NN_TRY(unvalidatedConvertHelper(aidlNativeHandle)));
503 }
504
unvalidatedConvert(const std::vector<aidl_hal::Operation> & operations)505 GeneralResult<std::vector<Operation>> unvalidatedConvert(
506 const std::vector<aidl_hal::Operation>& operations) {
507 return unvalidatedConvertVec(operations);
508 }
509
unvalidatedConvert(const ndk::ScopedFileDescriptor & syncFence)510 GeneralResult<SyncFence> unvalidatedConvert(const ndk::ScopedFileDescriptor& syncFence) {
511 auto duplicatedFd = NN_TRY(dupFd(syncFence.get()));
512 return SyncFence::create(std::move(duplicatedFd));
513 }
514
convert(const aidl_hal::Capabilities & capabilities)515 GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities) {
516 return validatedConvert(capabilities);
517 }
518
convert(const aidl_hal::DeviceType & deviceType)519 GeneralResult<DeviceType> convert(const aidl_hal::DeviceType& deviceType) {
520 return validatedConvert(deviceType);
521 }
522
convert(const aidl_hal::ErrorStatus & errorStatus)523 GeneralResult<ErrorStatus> convert(const aidl_hal::ErrorStatus& errorStatus) {
524 return validatedConvert(errorStatus);
525 }
526
convert(const aidl_hal::ExecutionPreference & executionPreference)527 GeneralResult<ExecutionPreference> convert(
528 const aidl_hal::ExecutionPreference& executionPreference) {
529 return validatedConvert(executionPreference);
530 }
531
convert(const aidl_hal::Memory & operand)532 GeneralResult<SharedMemory> convert(const aidl_hal::Memory& operand) {
533 return validatedConvert(operand);
534 }
535
convert(const aidl_hal::Model & model)536 GeneralResult<Model> convert(const aidl_hal::Model& model) {
537 return validatedConvert(model);
538 }
539
convert(const aidl_hal::OperandType & operandType)540 GeneralResult<OperandType> convert(const aidl_hal::OperandType& operandType) {
541 return validatedConvert(operandType);
542 }
543
convert(const aidl_hal::Priority & priority)544 GeneralResult<Priority> convert(const aidl_hal::Priority& priority) {
545 return validatedConvert(priority);
546 }
547
convert(const aidl_hal::Request & request)548 GeneralResult<Request> convert(const aidl_hal::Request& request) {
549 return validatedConvert(request);
550 }
551
convert(const aidl_hal::Timing & timing)552 GeneralResult<Timing> convert(const aidl_hal::Timing& timing) {
553 return validatedConvert(timing);
554 }
555
convert(const ndk::ScopedFileDescriptor & syncFence)556 GeneralResult<SyncFence> convert(const ndk::ScopedFileDescriptor& syncFence) {
557 return validatedConvert(syncFence);
558 }
559
convert(const std::vector<aidl_hal::Extension> & extension)560 GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension) {
561 return validatedConvert(extension);
562 }
563
convert(const std::vector<aidl_hal::Memory> & memories)564 GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) {
565 return validatedConvert(memories);
566 }
567
convert(const std::vector<aidl_hal::OutputShape> & outputShapes)568 GeneralResult<std::vector<OutputShape>> convert(
569 const std::vector<aidl_hal::OutputShape>& outputShapes) {
570 return validatedConvert(outputShapes);
571 }
572
toUnsigned(const std::vector<int32_t> & vec)573 GeneralResult<std::vector<uint32_t>> toUnsigned(const std::vector<int32_t>& vec) {
574 if (!std::all_of(vec.begin(), vec.end(), [](int32_t v) { return v >= 0; })) {
575 return NN_ERROR() << "Negative value passed to conversion from signed to unsigned";
576 }
577 return std::vector<uint32_t>(vec.begin(), vec.end());
578 }
579
580 } // namespace android::nn
581
582 namespace aidl::android::hardware::neuralnetworks::utils {
583 namespace {
584
// The AIDL type that unvalidatedConvert(Input) produces: the value held by the returned
// nn::GeneralResult, with reference and cv-qualifiers stripped.
template <typename Input>
using UnvalidatedConvertOutput = std::remove_cv_t<
        std::remove_reference_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>>;
588
589 template <typename Type>
unvalidatedConvertVec(const std::vector<Type> & arguments)590 nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
591 const std::vector<Type>& arguments) {
592 std::vector<UnvalidatedConvertOutput<Type>> halObject;
593 halObject.reserve(arguments.size());
594 for (const auto& argument : arguments) {
595 halObject.push_back(NN_TRY(unvalidatedConvert(argument)));
596 }
597 return halObject;
598 }
599
600 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)601 nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
602 const std::vector<Type>& arguments) {
603 return unvalidatedConvertVec(arguments);
604 }
605
606 template <typename Type>
validatedConvert(const Type & canonical)607 nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
608 NN_TRY(compliantVersion(canonical));
609 return utils::unvalidatedConvert(canonical);
610 }
611
612 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)613 nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
614 const std::vector<Type>& arguments) {
615 std::vector<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
616 for (size_t i = 0; i < arguments.size(); ++i) {
617 halObject[i] = NN_TRY(validatedConvert(arguments[i]));
618 }
619 return halObject;
620 }
621
/// Converts a canonical Handle into an AIDL common::NativeHandle. Each fd is duplicated, so the
/// source handle keeps its own descriptors. The ints are copied.
nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::Handle& handle) {
    common::NativeHandle aidlNativeHandle;
    aidlNativeHandle.ints = handle.ints;
    auto& aidlFds = aidlNativeHandle.fds;
    aidlFds.reserve(handle.fds.size());
    for (size_t i = 0; i < handle.fds.size(); ++i) {
        auto duplicate = NN_TRY(nn::dupFd(handle.fds[i].get()));
        aidlFds.emplace_back(duplicate.release());
    }
    return aidlNativeHandle;
}
632
// Combines several callables into a single overload set, for use with std::visit, e.g.
//     std::visit(overloaded{[](int) { ... }, [](const std::string&) { ... }}, value);
template <typename... Callables>
struct overloaded : Callables... {
    using Callables::operator()...;
};

// Deduction guide so that `overloaded{f, g}` deduces the callable types (C++17).
template <typename... Callables>
overloaded(Callables...) -> overloaded<Callables...>;
640
aidlHandleFromNativeHandle(const native_handle_t & nativeHandle)641 nn::GeneralResult<common::NativeHandle> aidlHandleFromNativeHandle(
642 const native_handle_t& nativeHandle) {
643 auto handle = ::android::dupToAidl(&nativeHandle);
644 if (!std::all_of(handle.fds.begin(), handle.fds.end(),
645 [](const ndk::ScopedFileDescriptor& fd) { return fd.get() >= 0; })) {
646 return NN_ERROR() << "android::dupToAidl returned an invalid common::NativeHandle";
647 }
648 return handle;
649 }
650
/// Converts canonical ashmem memory into the AIDL Memory::ashmem variant, duplicating the fd.
/// Where size_t can hold values larger than INT64_MAX, the size is range-checked first.
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::Ashmem& memory) {
    constexpr bool kSizeMayExceedInt64 =
            std::numeric_limits<size_t>::max() > std::numeric_limits<int64_t>::max();
    if constexpr (kSizeMayExceedInt64) {
        if (memory.size > std::numeric_limits<int64_t>::max()) {
            // The error builder is converted explicitly rather than implicitly.
            // NOTE(review): presumably because the implicit conversion to
            // GeneralResult<Memory> is ambiguous — confirm before simplifying.
            return (NN_ERROR()
                    << "Memory::Ashmem: size must be <= std::numeric_limits<int64_t>::max()")
                    .operator nn::GeneralResult<Memory>();
        }
    }

    auto duplicatedFd = NN_TRY(nn::dupFd(memory.fd));
    auto ashmem = common::Ashmem{
            .fd = ndk::ScopedFileDescriptor(duplicatedFd.release()),
            .size = static_cast<int64_t>(memory.size),
    };
    return Memory::make<Memory::Tag::ashmem>(std::move(ashmem));
}
669
/// Converts canonical fd-backed memory into the AIDL Memory::mappableFile variant, duplicating
/// the fd. Where size_t can hold values larger than INT64_MAX, size and then offset are
/// range-checked first.
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::Fd& memory) {
    constexpr bool kSizeMayExceedInt64 =
            std::numeric_limits<size_t>::max() > std::numeric_limits<int64_t>::max();
    if constexpr (kSizeMayExceedInt64) {
        constexpr auto kInt64Max = std::numeric_limits<int64_t>::max();
        if (memory.size > kInt64Max) {
            return (NN_ERROR() << "Memory::Fd: size must be <= std::numeric_limits<int64_t>::max()")
                    .operator nn::GeneralResult<Memory>();
        }
        if (memory.offset > kInt64Max) {
            return (NN_ERROR()
                    << "Memory::Fd: offset must be <= std::numeric_limits<int64_t>::max()")
                    .operator nn::GeneralResult<Memory>();
        }
    }

    auto duplicatedFd = NN_TRY(nn::dupFd(memory.fd));
    auto mappableFile = common::MappableFile{
            .length = static_cast<int64_t>(memory.size),
            .prot = memory.prot,
            .fd = ndk::ScopedFileDescriptor(duplicatedFd.release()),
            .offset = static_cast<int64_t>(memory.offset),
    };
    return Memory::make<Memory::Tag::mappableFile>(std::move(mappableFile));
}
695
/// Converts canonical hardware-buffer memory into the AIDL Memory::hardwareBuffer variant.
///
/// The buffer's native handle is duplicated into a common::NativeHandle. The buffer description
/// reported by AHardwareBuffer_describe is copied field by field. The call fails if the
/// AHardwareBuffer has no native handle, or if duplicating it yields a negative fd.
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::HardwareBuffer& memory) {
    const native_handle_t* nativeHandle = AHardwareBuffer_getNativeHandle(memory.handle.get());
    if (nativeHandle == nullptr) {
        return (NN_ERROR() << "unvalidatedConvert failed because AHardwareBuffer_getNativeHandle "
                              "returned nullptr")
                .operator nn::GeneralResult<Memory>();
    }

    auto handle = NN_TRY(aidlHandleFromNativeHandle(*nativeHandle));

    // Value-initialize so no field is left indeterminate, even if describe() leaves some unset.
    AHardwareBuffer_Desc desc{};
    AHardwareBuffer_describe(memory.handle.get(), &desc);

    // Deliberately non-const: it is moved into the HardwareBuffer below. Moving a const object
    // silently falls back to a copy.
    auto description = graphics::common::HardwareBufferDescription{
            .width = static_cast<int32_t>(desc.width),
            .height = static_cast<int32_t>(desc.height),
            .layers = static_cast<int32_t>(desc.layers),
            .format = static_cast<graphics::common::PixelFormat>(desc.format),
            .usage = static_cast<graphics::common::BufferUsage>(desc.usage),
            .stride = static_cast<int32_t>(desc.stride),
    };

    auto hardwareBuffer = graphics::common::HardwareBuffer{
            .description = std::move(description),
            .handle = std::move(handle),
    };
    return Memory::make<Memory::Tag::hardwareBuffer>(std::move(hardwareBuffer));
}
725
// nn::Memory::Unknown has no AIDL representation; converting it always yields an error.
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory::Unknown& /*memory*/) {
    // The error builder is converted by an explicit operator call rather than implicitly.
    return (NN_ERROR() << "Unable to convert Unknown memory type")
            .operator nn::GeneralResult<Memory>();
}
731
732 } // namespace
733
// Copies every byte of the cache token, in order, into a std::vector<uint8_t>.
nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(const nn::CacheToken& cacheToken) {
    std::vector<uint8_t> tokenBytes(cacheToken.begin(), cacheToken.end());
    return tokenBytes;
}
737
// Converts an nn::BufferDesc; fails if any dimension does not fit in int32_t (via toSigned).
nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc) {
    auto signedDimensions = NN_TRY(toSigned(bufferDesc.dimensions));
    return BufferDesc{.dimensions = std::move(signedDimensions)};
}
741
// Converts an nn::BufferRole. modelIndex and then ioIndex must each be
// <= std::numeric_limits<int32_t>::max(); probability is copied unchanged.
nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole) {
    const auto modelIndex = bufferRole.modelIndex;
    const auto ioIndex = bufferRole.ioIndex;
    VERIFY_LE_INT32_MAX(modelIndex)
            << "BufferRole: modelIndex must be <= std::numeric_limits<int32_t>::max()";
    VERIFY_LE_INT32_MAX(ioIndex)
            << "BufferRole: ioIndex must be <= std::numeric_limits<int32_t>::max()";
    BufferRole aidlRole{
            .modelIndex = static_cast<int32_t>(modelIndex),
            .ioIndex = static_cast<int32_t>(ioIndex),
            .probability = bufferRole.probability,
    };
    return aidlRole;
}
753
// The AIDL form of MeasureTiming is a bool that is true exactly when the value is YES.
nn::GeneralResult<bool> unvalidatedConvert(const nn::MeasureTiming& measureTiming) {
    const bool isYes = (measureTiming == nn::MeasureTiming::YES);
    return isYes;
}
757
// Converts the handle owned by sharedHandle. A null sharedHandle is a programming error and
// aborts via CHECK rather than returning an error.
nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::SharedHandle& sharedHandle) {
    CHECK(sharedHandle != nullptr);
    const auto& ownedHandle = *sharedHandle;
    return unvalidatedConvert(ownedHandle);
}
762
// Converts shared nn::Memory by dispatching on whichever alternative its handle variant holds.
// A null pointer produces an error.
nn::GeneralResult<Memory> unvalidatedConvert(const nn::SharedMemory& memory) {
    if (memory != nullptr) {
        const auto convertAlternative = [](const auto& alternative) {
            return unvalidatedConvert(alternative);
        };
        return std::visit(convertAlternative, memory->handle);
    }
    return (NN_ERROR() << "Unable to convert nullptr memory")
            .operator nn::GeneralResult<Memory>();
}
771
// Converts an nn::ErrorStatus. The statuses listed below are converted by value (static_cast);
// every other value collapses to GENERAL_FAILURE.
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus) {
    switch (errorStatus) {
        case nn::ErrorStatus::NONE:
        case nn::ErrorStatus::DEVICE_UNAVAILABLE:
        case nn::ErrorStatus::GENERAL_FAILURE:
        case nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case nn::ErrorStatus::INVALID_ARGUMENT:
        case nn::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
        case nn::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
            break;
        default:
            return ErrorStatus::GENERAL_FAILURE;
    }
    return static_cast<ErrorStatus>(errorStatus);
}
788
// Converts an nn::OutputShape; dimensions must fit in int32_t (via toSigned).
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape) {
    auto signedDimensions = NN_TRY(toSigned(outputShape.dimensions));
    return OutputShape{
            .dimensions = std::move(signedDimensions),
            .isSufficient = outputShape.isSufficient,
    };
}
793
// Converts by underlying value (static_cast); no range check is performed here.
nn::GeneralResult<ExecutionPreference> unvalidatedConvert(
        const nn::ExecutionPreference& executionPreference) {
    const auto aidlPreference = static_cast<ExecutionPreference>(executionPreference);
    return aidlPreference;
}
798
// Converts an nn::OperandType by value. The OEM types (OEM, TENSOR_OEM_BYTE) are rejected.
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    const bool isOemType = operandType == nn::OperandType::OEM ||
                           operandType == nn::OperandType::TENSOR_OEM_BYTE;
    if (!isOemType) {
        return static_cast<OperandType>(operandType);
    }
    return NN_ERROR() << "Unable to convert invalid OperandType " << operandType;
}
805
// Converts an operand lifetime by underlying value (static_cast); no range check here.
nn::GeneralResult<OperandLifeTime> unvalidatedConvert(
        const nn::Operand::LifeTime& operandLifeTime) {
    const auto aidlLifeTime = static_cast<OperandLifeTime>(operandLifeTime);
    return aidlLifeTime;
}
810
// Converts an nn::DataLocation. poolIndex must be <= std::numeric_limits<int32_t>::max();
// offset and length are widened to int64_t.
nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location) {
    const auto poolIndex = location.poolIndex;
    VERIFY_LE_INT32_MAX(poolIndex)
            << "DataLocation: pool index must be <= std::numeric_limits<int32_t>::max()";
    DataLocation aidlLocation{
            .poolIndex = static_cast<int32_t>(poolIndex),
            .offset = static_cast<int64_t>(location.offset),
            .length = static_cast<int64_t>(location.length),
    };
    return aidlLocation;
}
820
// Converts nn::Operand::ExtraParams into an optional AIDL OperandExtraParams:
//  - NoParams maps to std::nullopt;
//  - SymmPerChannelQuantParams maps to channelQuant; channelDim must fit in int32_t;
//  - ExtensionParams maps to extension unchanged.
nn::GeneralResult<std::optional<OperandExtraParams>> unvalidatedConvert(
        const nn::Operand::ExtraParams& extraParams) {
    using Result = nn::GeneralResult<std::optional<OperandExtraParams>>;

    const auto fromNoParams = [](const nn::Operand::NoParams&) -> Result { return std::nullopt; };

    const auto fromChannelQuant =
            [](const nn::Operand::SymmPerChannelQuantParams& quantParams) -> Result {
        if (quantParams.channelDim > std::numeric_limits<int32_t>::max()) {
            // Explicit conversion: std::optional in the success type confuses implicit conversion.
            return (NN_ERROR() << "symmPerChannelQuantParams.channelDim must be <= "
                                  "std::numeric_limits<int32_t>::max(), received: "
                               << quantParams.channelDim)
                    .operator nn::GeneralResult<std::optional<OperandExtraParams>>();
        }
        return OperandExtraParams::make<OperandExtraParams::Tag::channelQuant>(
                SymmPerChannelQuantParams{
                        .scales = quantParams.scales,
                        .channelDim = static_cast<int32_t>(quantParams.channelDim),
                });
    };

    const auto fromExtension = [](const nn::Operand::ExtensionParams& extension) -> Result {
        return OperandExtraParams::make<OperandExtraParams::Tag::extension>(extension);
    };

    return std::visit(overloaded{fromNoParams, fromChannelQuant, fromExtension}, extraParams);
}
856
// Converts an nn::Operand field by field. Sub-conversions run in declaration order; the first
// failure is returned. scale and zeroPoint are copied unchanged.
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    auto type = NN_TRY(unvalidatedConvert(operand.type));
    auto dimensions = NN_TRY(toSigned(operand.dimensions));
    auto lifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
    auto location = NN_TRY(unvalidatedConvert(operand.location));
    auto extraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
    return Operand{
            .type = type,
            .dimensions = std::move(dimensions),
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = lifetime,
            .location = std::move(location),
            .extraParams = std::move(extraParams),
    };
}
868
// Converts an nn::OperationType by value; OEM_OPERATION has no AIDL form and is rejected.
nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    const bool isOemOperation = (operationType == nn::OperationType::OEM_OPERATION);
    if (!isOemOperation) {
        return static_cast<OperationType>(operationType);
    }
    return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION";
}
875
// Converts an nn::Operation: its type, plus input/output operand indexes narrowed to int32_t.
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
    auto type = NN_TRY(unvalidatedConvert(operation.type));
    auto inputs = NN_TRY(toSigned(operation.inputs));
    auto outputs = NN_TRY(toSigned(operation.outputs));
    return Operation{
            .type = type,
            .inputs = std::move(inputs),
            .outputs = std::move(outputs),
    };
}
883
// Converts an nn::Model::Subgraph: operands, operations, and input/output index lists.
nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph) {
    auto operands = NN_TRY(unvalidatedConvert(subgraph.operands));
    auto operations = NN_TRY(unvalidatedConvert(subgraph.operations));
    auto inputIndexes = NN_TRY(toSigned(subgraph.inputIndexes));
    auto outputIndexes = NN_TRY(toSigned(subgraph.outputIndexes));
    return Subgraph{
            .operands = std::move(operands),
            .operations = std::move(operations),
            .inputIndexes = std::move(inputIndexes),
            .outputIndexes = std::move(outputIndexes),
    };
}
892
// Copies the raw bytes of the model's operand value storage into a std::vector<uint8_t>.
nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
        const nn::Model::OperandValues& operandValues) {
    const auto* const first = operandValues.data();
    const auto* const last = first + operandValues.size();
    return std::vector<uint8_t>(first, last);
}
897
// Copies an extension's name and prefix into the AIDL ExtensionNameAndPrefix.
nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
        const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix) {
    ExtensionNameAndPrefix aidlExtension{
            .name = extensionNameToPrefix.name,
            .prefix = extensionNameToPrefix.prefix,
    };
    return aidlExtension;
}
905
// Converts an nn::Model. Each component is converted in declaration order, stopping at the
// first failure; relaxComputationFloat32toFloat16 is copied unchanged.
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    auto mainSubgraph = NN_TRY(unvalidatedConvert(model.main));
    auto referencedSubgraphs = NN_TRY(unvalidatedConvert(model.referenced));
    auto operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
    auto pools = NN_TRY(unvalidatedConvert(model.pools));
    auto extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));
    return Model{
            .main = std::move(mainSubgraph),
            .referenced = std::move(referencedSubgraphs),
            .operandValues = std::move(operandValues),
            .pools = std::move(pools),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = std::move(extensionNameToPrefix),
    };
}
916
// Converts an nn::Priority by underlying value (static_cast); no range check here.
nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority) {
    const auto aidlPriority = static_cast<Priority>(priority);
    return aidlPriority;
}
920
// Converts an nn::Request: input arguments, output arguments, then memory pools.
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
    auto inputs = NN_TRY(unvalidatedConvert(request.inputs));
    auto outputs = NN_TRY(unvalidatedConvert(request.outputs));
    auto pools = NN_TRY(unvalidatedConvert(request.pools));
    return Request{
            .inputs = std::move(inputs),
            .outputs = std::move(outputs),
            .pools = std::move(pools),
    };
}
928
// Converts a request argument. POINTER lifetimes are rejected with INVALID_ARGUMENT; a NO_VALUE
// lifetime sets hasNoValue. The location and dimensions are converted in every other case.
nn::GeneralResult<RequestArgument> unvalidatedConvert(
        const nn::Request::Argument& requestArgument) {
    using LifeTime = nn::Request::Argument::LifeTime;
    const auto lifetime = requestArgument.lifetime;
    if (lifetime == LifeTime::POINTER) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
    }
    auto location = NN_TRY(unvalidatedConvert(requestArgument.location));
    auto dimensions = NN_TRY(toSigned(requestArgument.dimensions));
    return RequestArgument{
            .hasNoValue = (lifetime == LifeTime::NO_VALUE),
            .location = std::move(location),
            .dimensions = std::move(dimensions),
    };
}
942
// Converts a request memory pool:
//  - SharedMemory becomes a `pool` entry holding the converted Memory;
//  - MemoryDomainToken becomes a `token` entry holding the token's underlying value;
//  - SharedBuffer (IBuffer) cannot be represented and yields GENERAL_FAILURE.
nn::GeneralResult<RequestMemoryPool> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool) {
    using Result = nn::GeneralResult<RequestMemoryPool>;

    const auto fromMemory = [](const nn::SharedMemory& memory) -> Result {
        auto aidlMemory = NN_TRY(unvalidatedConvert(memory));
        return RequestMemoryPool::make<RequestMemoryPool::Tag::pool>(std::move(aidlMemory));
    };
    const auto fromToken = [](const nn::Request::MemoryDomainToken& token) -> Result {
        return RequestMemoryPool::make<RequestMemoryPool::Tag::token>(underlyingType(token));
    };
    const auto fromBuffer = [](const nn::SharedBuffer& /*buffer*/) -> Result {
        return (NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
                << "Unable to make memory pool from IBuffer")
                .operator nn::GeneralResult<RequestMemoryPool>();
    };

    return std::visit(overloaded{fromMemory, fromToken, fromBuffer}, memoryPool);
}
964
// Converts nn::Timing; each optional duration becomes nanoseconds (or kNoTiming when absent).
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
    const auto onDeviceNs = NN_TRY(unvalidatedConvert(timing.timeOnDevice));
    const auto inDriverNs = NN_TRY(unvalidatedConvert(timing.timeInDriver));
    return Timing{
            .timeOnDeviceNs = onDeviceNs,
            .timeInDriverNs = inDriverNs,
    };
}
971
// Converts a duration to a nanosecond count. Negative durations are rejected; non-negative
// counts are clamped to std::numeric_limits<int64_t>::max().
nn::GeneralResult<int64_t> unvalidatedConvert(const nn::Duration& duration) {
    if (duration >= nn::Duration::zero()) {
        constexpr std::chrono::nanoseconds::rep kInt64Limit = std::numeric_limits<int64_t>::max();
        const auto ticks = duration.count();
        return static_cast<int64_t>(ticks < kInt64Limit ? ticks : kInt64Limit);
    }
    return NN_ERROR() << "Unable to convert invalid (negative) duration";
}
980
// An absent duration maps to kNoTiming; a present one is converted as a Duration.
nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalDuration& optionalDuration) {
    if (optionalDuration) {
        return unvalidatedConvert(*optionalDuration);
    }
    return kNoTiming;
}
987
// An absent time point maps to kNoTiming; a present one is converted via its time since epoch.
nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalTimePoint& optionalTimePoint) {
    if (optionalTimePoint) {
        const auto sinceEpoch = optionalTimePoint->time_since_epoch();
        return unvalidatedConvert(sinceEpoch);
    }
    return kNoTiming;
}
994
// Duplicates the fence's file descriptor and transfers ownership of the copy into a
// ScopedFileDescriptor; the SyncFence keeps its own descriptor.
nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvert(const nn::SyncFence& syncFence) {
    const auto fenceFd = syncFence.getFd();
    auto ownedCopy = NN_TRY(nn::dupFd(fenceFd));
    return ndk::ScopedFileDescriptor(ownedCopy.release());
}
999
// Converts a cache handle into a ScopedFileDescriptor that owns a duplicate of the handle's fd.
// A cache handle must be non-null, contain no ints, and contain exactly one fd. Any violation is
// reported as an error instead of proceeding; previously the errors were built and discarded,
// so an fd-less handle reached fds.front() (undefined behavior).
nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvertCache(
        const nn::SharedHandle& handle) {
    if (handle == nullptr) {
        return NN_ERROR() << "Cache handle must not be nullptr";
    }
    if (handle->ints.size() != 0) {
        return NN_ERROR() << "Cache handle must not contain ints";
    }
    if (handle->fds.size() != 1) {
        return NN_ERROR() << "Cache handle must contain exactly one fd but contains "
                          << handle->fds.size();
    }
    auto duplicatedFd = NN_TRY(nn::dupFd(handle->fds.front().get()));
    return ndk::ScopedFileDescriptor(duplicatedFd.release());
}
1012
// Public entry point for nn::CacheToken; delegates to validatedConvert.
nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken) {
    auto converted = validatedConvert(cacheToken);
    return converted;
}
1016
// Public entry point for nn::BufferDesc; delegates to validatedConvert.
nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
    auto converted = validatedConvert(bufferDesc);
    return converted;
}
1020
// Public entry point for nn::MeasureTiming; delegates to validatedConvert.
nn::GeneralResult<bool> convert(const nn::MeasureTiming& measureTiming) {
    auto converted = validatedConvert(measureTiming);
    return converted;
}
1024
// Public entry point for nn::SharedMemory; delegates to validatedConvert.
nn::GeneralResult<Memory> convert(const nn::SharedMemory& memory) {
    auto converted = validatedConvert(memory);
    return converted;
}
1028
// Public entry point for nn::ErrorStatus; delegates to validatedConvert.
nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus) {
    auto converted = validatedConvert(errorStatus);
    return converted;
}
1032
// Public entry point for nn::ExecutionPreference; delegates to validatedConvert.
nn::GeneralResult<ExecutionPreference> convert(const nn::ExecutionPreference& executionPreference) {
    auto converted = validatedConvert(executionPreference);
    return converted;
}
1036
// Public entry point for nn::Model; delegates to validatedConvert.
nn::GeneralResult<Model> convert(const nn::Model& model) {
    auto converted = validatedConvert(model);
    return converted;
}
1040
// Public entry point for nn::Priority; delegates to validatedConvert.
nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
    auto converted = validatedConvert(priority);
    return converted;
}
1044
// Public entry point for nn::Request; delegates to validatedConvert.
nn::GeneralResult<Request> convert(const nn::Request& request) {
    auto converted = validatedConvert(request);
    return converted;
}
1048
// Public entry point for nn::Timing; delegates to validatedConvert.
nn::GeneralResult<Timing> convert(const nn::Timing& timing) {
    auto converted = validatedConvert(timing);
    return converted;
}
1052
// Public entry point for nn::OptionalDuration; delegates to validatedConvert.
nn::GeneralResult<int64_t> convert(const nn::OptionalDuration& optionalDuration) {
    auto converted = validatedConvert(optionalDuration);
    return converted;
}
1056
// Public entry point for nn::OptionalTimePoint; delegates to validatedConvert.
// The parameter was previously misnamed `outputShapes` (a copy-paste leftover); C++ parameter
// names are not part of the interface, so renaming it is caller-transparent.
nn::GeneralResult<int64_t> convert(const nn::OptionalTimePoint& optionalTimePoint) {
    return validatedConvert(optionalTimePoint);
}
1060
convert(const std::vector<nn::BufferRole> & bufferRoles)1061 nn::GeneralResult<std::vector<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles) {
1062 return validatedConvert(bufferRoles);
1063 }
1064
convert(const std::vector<nn::OutputShape> & outputShapes)1065 nn::GeneralResult<std::vector<OutputShape>> convert(
1066 const std::vector<nn::OutputShape>& outputShapes) {
1067 return validatedConvert(outputShapes);
1068 }
1069
convert(const std::vector<nn::SharedHandle> & cacheHandles)1070 nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
1071 const std::vector<nn::SharedHandle>& cacheHandles) {
1072 const auto version = NN_TRY(hal::utils::makeGeneralFailure(nn::validate(cacheHandles)));
1073 if (version > kVersion) {
1074 return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
1075 }
1076 std::vector<ndk::ScopedFileDescriptor> cacheFds;
1077 cacheFds.reserve(cacheHandles.size());
1078 for (const auto& cacheHandle : cacheHandles) {
1079 cacheFds.push_back(NN_TRY(unvalidatedConvertCache(cacheHandle)));
1080 }
1081 return cacheFds;
1082 }
1083
convert(const std::vector<nn::SyncFence> & syncFences)1084 nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
1085 const std::vector<nn::SyncFence>& syncFences) {
1086 return validatedConvert(syncFences);
1087 }
1088
toSigned(const std::vector<uint32_t> & vec)1089 nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec) {
1090 if (!std::all_of(vec.begin(), vec.end(),
1091 [](uint32_t v) { return v <= std::numeric_limits<int32_t>::max(); })) {
1092 return NN_ERROR() << "Vector contains a value that doesn't fit into int32_t.";
1093 }
1094 return std::vector<int32_t>(vec.begin(), vec.end());
1095 }
1096
1097 } // namespace aidl::android::hardware::neuralnetworks::utils
1098