1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "Conversions.h"
18
19 #include <android-base/logging.h>
20 #include <android/hardware/neuralnetworks/1.3/types.h>
21 #include <nnapi/OperandTypes.h>
22 #include <nnapi/OperationTypes.h>
23 #include <nnapi/Result.h>
24 #include <nnapi/SharedMemory.h>
25 #include <nnapi/TypeUtils.h>
26 #include <nnapi/Types.h>
27 #include <nnapi/Validation.h>
28 #include <nnapi/hal/1.0/Conversions.h>
29 #include <nnapi/hal/1.2/Conversions.h>
30 #include <nnapi/hal/CommonUtils.h>
31 #include <nnapi/hal/HandleError.h>
32
33 #include <algorithm>
34 #include <chrono>
35 #include <functional>
36 #include <iterator>
37 #include <limits>
38 #include <type_traits>
39 #include <utility>
40
41 #include "Utils.h"
42
43 namespace {
44
// Converts an unsigned nanosecond count into std::chrono::nanoseconds, saturating at
// std::chrono::nanoseconds::max() when the count is larger than the duration can represent.
// constexpr/noexcept so the clamp can also be evaluated in constant expressions.
constexpr std::chrono::nanoseconds makeNanosFromUint64(uint64_t nanoseconds) noexcept {
    using Rep = std::chrono::nanoseconds::rep;
    // Both operands are compared in the common type of Rep and uint64_t.
    using CommonType = std::common_type_t<Rep, uint64_t>;
    constexpr auto kMaxCount = static_cast<CommonType>(std::chrono::nanoseconds::max().count());
    const CommonType count = std::min<CommonType>(kMaxCount, nanoseconds);
    return std::chrono::nanoseconds{static_cast<Rep>(count)};
}
51
// Converts std::chrono::nanoseconds into an unsigned nanosecond count. Negative durations map
// to 0; otherwise the count is clamped to the uint64_t range.
// constexpr/noexcept so the conversion can also be evaluated in constant expressions.
constexpr uint64_t makeUint64FromNanos(std::chrono::nanoseconds nanoseconds) noexcept {
    if (nanoseconds < std::chrono::nanoseconds::zero()) {
        return 0;
    }
    // Both operands are compared in the common type of the duration's rep and uint64_t.
    using CommonType = std::common_type_t<std::chrono::nanoseconds::rep, uint64_t>;
    constexpr CommonType kMaxCount = std::numeric_limits<uint64_t>::max();
    const CommonType count =
            std::min<CommonType>(kMaxCount, static_cast<CommonType>(nanoseconds.count()));
    return static_cast<uint64_t>(count);
}
61
// Returns `value` converted to the underlying integral type of its enumeration type.
template <typename Type>
constexpr std::underlying_type_t<Type> underlyingType(Type value) noexcept {
    using Underlying = std::underlying_type_t<Type>;
    return static_cast<Underlying>(value);
}
66
67 } // namespace
68
69 namespace android::nn {
70 namespace {
71
72 using hardware::hidl_vec;
73
74 template <typename Input>
75 using UnvalidatedConvertOutput =
76 std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
77
78 template <typename Type>
unvalidatedConvert(const hidl_vec<Type> & arguments)79 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
80 const hidl_vec<Type>& arguments) {
81 std::vector<UnvalidatedConvertOutput<Type>> canonical;
82 canonical.reserve(arguments.size());
83 for (const auto& argument : arguments) {
84 canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
85 }
86 return canonical;
87 }
88
89 template <typename Type>
validatedConvert(const Type & halObject)90 GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
91 auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
92 NN_TRY(hal::V1_3::utils::compliantVersion(canonical));
93 return canonical;
94 }
95
96 template <typename Type>
validatedConvert(const hidl_vec<Type> & arguments)97 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
98 const hidl_vec<Type>& arguments) {
99 std::vector<UnvalidatedConvertOutput<Type>> canonical;
100 canonical.reserve(arguments.size());
101 for (const auto& argument : arguments) {
102 canonical.push_back(NN_TRY(validatedConvert(argument)));
103 }
104 return canonical;
105 }
106
107 } // anonymous namespace
108
unvalidatedConvert(const hal::V1_3::OperandType & operandType)109 GeneralResult<OperandType> unvalidatedConvert(const hal::V1_3::OperandType& operandType) {
110 return static_cast<OperandType>(operandType);
111 }
112
unvalidatedConvert(const hal::V1_3::OperationType & operationType)113 GeneralResult<OperationType> unvalidatedConvert(const hal::V1_3::OperationType& operationType) {
114 return static_cast<OperationType>(operationType);
115 }
116
unvalidatedConvert(const hal::V1_3::Priority & priority)117 GeneralResult<Priority> unvalidatedConvert(const hal::V1_3::Priority& priority) {
118 return static_cast<Priority>(priority);
119 }
120
unvalidatedConvert(const hal::V1_3::Capabilities & capabilities)121 GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_3::Capabilities& capabilities) {
122 const bool validOperandTypes = std::all_of(
123 capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
124 [](const hal::V1_3::Capabilities::OperandPerformance& operandPerformance) {
125 return validatedConvert(operandPerformance.type).has_value();
126 });
127 if (!validOperandTypes) {
128 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
129 << "Invalid OperandType when unvalidatedConverting OperandPerformance in "
130 "Capabilities";
131 }
132
133 auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
134 auto table = NN_TRY(hal::utils::makeGeneralFailure(
135 Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)),
136 nn::ErrorStatus::GENERAL_FAILURE));
137
138 return Capabilities{
139 .relaxedFloat32toFloat16PerformanceScalar = NN_TRY(
140 unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar)),
141 .relaxedFloat32toFloat16PerformanceTensor = NN_TRY(
142 unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor)),
143 .operandPerformance = std::move(table),
144 .ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance)),
145 .whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance)),
146 };
147 }
148
unvalidatedConvert(const hal::V1_3::Capabilities::OperandPerformance & operandPerformance)149 GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
150 const hal::V1_3::Capabilities::OperandPerformance& operandPerformance) {
151 return Capabilities::OperandPerformance{
152 .type = NN_TRY(unvalidatedConvert(operandPerformance.type)),
153 .info = NN_TRY(unvalidatedConvert(operandPerformance.info)),
154 };
155 }
156
unvalidatedConvert(const hal::V1_3::Operation & operation)157 GeneralResult<Operation> unvalidatedConvert(const hal::V1_3::Operation& operation) {
158 return Operation{
159 .type = NN_TRY(unvalidatedConvert(operation.type)),
160 .inputs = operation.inputs,
161 .outputs = operation.outputs,
162 };
163 }
164
unvalidatedConvert(const hal::V1_3::OperandLifeTime & operandLifeTime)165 GeneralResult<Operand::LifeTime> unvalidatedConvert(
166 const hal::V1_3::OperandLifeTime& operandLifeTime) {
167 return static_cast<Operand::LifeTime>(operandLifeTime);
168 }
169
unvalidatedConvert(const hal::V1_3::Operand & operand)170 GeneralResult<Operand> unvalidatedConvert(const hal::V1_3::Operand& operand) {
171 return Operand{
172 .type = NN_TRY(unvalidatedConvert(operand.type)),
173 .dimensions = operand.dimensions,
174 .scale = operand.scale,
175 .zeroPoint = operand.zeroPoint,
176 .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
177 .location = NN_TRY(unvalidatedConvert(operand.location)),
178 .extraParams = NN_TRY(unvalidatedConvert(operand.extraParams)),
179 };
180 }
181
unvalidatedConvert(const hal::V1_3::Model & model)182 GeneralResult<Model> unvalidatedConvert(const hal::V1_3::Model& model) {
183 return Model{
184 .main = NN_TRY(unvalidatedConvert(model.main)),
185 .referenced = NN_TRY(unvalidatedConvert(model.referenced)),
186 .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
187 .pools = NN_TRY(unvalidatedConvert(model.pools)),
188 .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
189 .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
190 };
191 }
192
unvalidatedConvert(const hal::V1_3::Subgraph & subgraph)193 GeneralResult<Model::Subgraph> unvalidatedConvert(const hal::V1_3::Subgraph& subgraph) {
194 auto operations = NN_TRY(unvalidatedConvert(subgraph.operations));
195
196 // Verify number of consumers.
197 const auto numberOfConsumers =
198 NN_TRY(hal::utils::countNumberOfConsumers(subgraph.operands.size(), operations));
199 CHECK(subgraph.operands.size() == numberOfConsumers.size());
200 for (size_t i = 0; i < subgraph.operands.size(); ++i) {
201 if (subgraph.operands[i].numberOfConsumers != numberOfConsumers[i]) {
202 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
203 << "Invalid numberOfConsumers for operand " << i << ", expected "
204 << numberOfConsumers[i] << " but found "
205 << subgraph.operands[i].numberOfConsumers;
206 }
207 }
208
209 return Model::Subgraph{
210 .operands = NN_TRY(unvalidatedConvert(subgraph.operands)),
211 .operations = std::move(operations),
212 .inputIndexes = subgraph.inputIndexes,
213 .outputIndexes = subgraph.outputIndexes,
214 };
215 }
216
unvalidatedConvert(const hal::V1_3::BufferDesc & bufferDesc)217 GeneralResult<BufferDesc> unvalidatedConvert(const hal::V1_3::BufferDesc& bufferDesc) {
218 return BufferDesc{.dimensions = bufferDesc.dimensions};
219 }
220
unvalidatedConvert(const hal::V1_3::BufferRole & bufferRole)221 GeneralResult<BufferRole> unvalidatedConvert(const hal::V1_3::BufferRole& bufferRole) {
222 return BufferRole{
223 .modelIndex = bufferRole.modelIndex,
224 .ioIndex = bufferRole.ioIndex,
225 .probability = bufferRole.frequency,
226 };
227 }
228
unvalidatedConvert(const hal::V1_3::Request & request)229 GeneralResult<Request> unvalidatedConvert(const hal::V1_3::Request& request) {
230 return Request{
231 .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
232 .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
233 .pools = NN_TRY(unvalidatedConvert(request.pools)),
234 };
235 }
236
unvalidatedConvert(const hal::V1_3::Request::MemoryPool & memoryPool)237 GeneralResult<Request::MemoryPool> unvalidatedConvert(
238 const hal::V1_3::Request::MemoryPool& memoryPool) {
239 using Discriminator = hal::V1_3::Request::MemoryPool::hidl_discriminator;
240 switch (memoryPool.getDiscriminator()) {
241 case Discriminator::hidlMemory:
242 return hal::utils::createSharedMemoryFromHidlMemory(memoryPool.hidlMemory());
243 case Discriminator::token:
244 return static_cast<Request::MemoryDomainToken>(memoryPool.token());
245 }
246 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
247 << "Invalid Request::MemoryPool discriminator "
248 << underlyingType(memoryPool.getDiscriminator());
249 }
250
unvalidatedConvert(const hal::V1_3::OptionalTimePoint & optionalTimePoint)251 GeneralResult<OptionalTimePoint> unvalidatedConvert(
252 const hal::V1_3::OptionalTimePoint& optionalTimePoint) {
253 using Discriminator = hal::V1_3::OptionalTimePoint::hidl_discriminator;
254 switch (optionalTimePoint.getDiscriminator()) {
255 case Discriminator::none:
256 return {};
257 case Discriminator::nanosecondsSinceEpoch: {
258 const auto currentSteadyTime = std::chrono::steady_clock::now();
259 const auto currentBootTime = Clock::now();
260
261 const auto timeSinceEpoch =
262 makeNanosFromUint64(optionalTimePoint.nanosecondsSinceEpoch());
263 const auto steadyTimePoint = std::chrono::steady_clock::time_point{timeSinceEpoch};
264
265 // Both steadyTimePoint and currentSteadyTime are guaranteed to be non-negative, so this
266 // subtraction will never overflow or underflow.
267 const auto timeRemaining = steadyTimePoint - currentSteadyTime;
268
269 // currentBootTime is guaranteed to be non-negative, so this code only protects against
270 // an overflow.
271 nn::TimePoint bootTimePoint;
272 constexpr auto kZeroNano = std::chrono::nanoseconds::zero();
273 constexpr auto kMaxTime = nn::TimePoint::max();
274 if (timeRemaining > kZeroNano && currentBootTime > kMaxTime - timeRemaining) {
275 bootTimePoint = kMaxTime;
276 } else {
277 bootTimePoint = currentBootTime + timeRemaining;
278 }
279
280 constexpr auto kZeroTime = nn::TimePoint{};
281 return std::max(bootTimePoint, kZeroTime);
282 }
283 }
284 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
285 << "Invalid OptionalTimePoint discriminator "
286 << underlyingType(optionalTimePoint.getDiscriminator());
287 }
288
unvalidatedConvert(const hal::V1_3::OptionalTimeoutDuration & optionalTimeoutDuration)289 GeneralResult<OptionalDuration> unvalidatedConvert(
290 const hal::V1_3::OptionalTimeoutDuration& optionalTimeoutDuration) {
291 using Discriminator = hal::V1_3::OptionalTimeoutDuration::hidl_discriminator;
292 switch (optionalTimeoutDuration.getDiscriminator()) {
293 case Discriminator::none:
294 return {};
295 case Discriminator::nanoseconds:
296 return Duration(optionalTimeoutDuration.nanoseconds());
297 }
298 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
299 << "Invalid OptionalTimeoutDuration discriminator "
300 << underlyingType(optionalTimeoutDuration.getDiscriminator());
301 }
302
unvalidatedConvert(const hal::V1_3::ErrorStatus & status)303 GeneralResult<ErrorStatus> unvalidatedConvert(const hal::V1_3::ErrorStatus& status) {
304 switch (status) {
305 case hal::V1_3::ErrorStatus::NONE:
306 case hal::V1_3::ErrorStatus::DEVICE_UNAVAILABLE:
307 case hal::V1_3::ErrorStatus::GENERAL_FAILURE:
308 case hal::V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
309 case hal::V1_3::ErrorStatus::INVALID_ARGUMENT:
310 case hal::V1_3::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
311 case hal::V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
312 case hal::V1_3::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
313 case hal::V1_3::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
314 return static_cast<ErrorStatus>(status);
315 }
316 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
317 << "Invalid ErrorStatus " << underlyingType(status);
318 }
319
convert(const hal::V1_3::Priority & priority)320 GeneralResult<Priority> convert(const hal::V1_3::Priority& priority) {
321 return validatedConvert(priority);
322 }
323
convert(const hal::V1_3::Capabilities & capabilities)324 GeneralResult<Capabilities> convert(const hal::V1_3::Capabilities& capabilities) {
325 return validatedConvert(capabilities);
326 }
327
convert(const hal::V1_3::Model & model)328 GeneralResult<Model> convert(const hal::V1_3::Model& model) {
329 return validatedConvert(model);
330 }
331
convert(const hal::V1_3::BufferDesc & bufferDesc)332 GeneralResult<BufferDesc> convert(const hal::V1_3::BufferDesc& bufferDesc) {
333 return validatedConvert(bufferDesc);
334 }
335
convert(const hal::V1_3::Request & request)336 GeneralResult<Request> convert(const hal::V1_3::Request& request) {
337 return validatedConvert(request);
338 }
339
convert(const hal::V1_3::OptionalTimePoint & optionalTimePoint)340 GeneralResult<OptionalTimePoint> convert(const hal::V1_3::OptionalTimePoint& optionalTimePoint) {
341 return validatedConvert(optionalTimePoint);
342 }
343
convert(const hal::V1_3::OptionalTimeoutDuration & optionalTimeoutDuration)344 GeneralResult<OptionalDuration> convert(
345 const hal::V1_3::OptionalTimeoutDuration& optionalTimeoutDuration) {
346 return validatedConvert(optionalTimeoutDuration);
347 }
348
convert(const hal::V1_3::ErrorStatus & errorStatus)349 GeneralResult<ErrorStatus> convert(const hal::V1_3::ErrorStatus& errorStatus) {
350 return validatedConvert(errorStatus);
351 }
352
convert(const hardware::hidl_handle & handle)353 GeneralResult<SharedHandle> convert(const hardware::hidl_handle& handle) {
354 return validatedConvert(handle);
355 }
356
convert(const hardware::hidl_vec<hal::V1_3::BufferRole> & bufferRoles)357 GeneralResult<std::vector<BufferRole>> convert(
358 const hardware::hidl_vec<hal::V1_3::BufferRole>& bufferRoles) {
359 return validatedConvert(bufferRoles);
360 }
361
362 } // namespace android::nn
363
364 namespace android::hardware::neuralnetworks::V1_3::utils {
365 namespace {
366
367 using utils::unvalidatedConvert;
368
// Forwards PerformanceInfo conversion to the 1.0 utilities.
nn::GeneralResult<V1_0::PerformanceInfo> unvalidatedConvert(
        const nn::Capabilities::PerformanceInfo& performanceInfo) {
    auto halObject = V1_0::utils::unvalidatedConvert(performanceInfo);
    return halObject;
}
373
// Forwards DataLocation conversion to the 1.0 utilities.
nn::GeneralResult<V1_0::DataLocation> unvalidatedConvert(const nn::DataLocation& dataLocation) {
    auto halObject = V1_0::utils::unvalidatedConvert(dataLocation);
    return halObject;
}
377
// Forwards Model::OperandValues conversion to the 1.0 utilities.
nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
        const nn::Model::OperandValues& operandValues) {
    auto halObject = V1_0::utils::unvalidatedConvert(operandValues);
    return halObject;
}
382
// Forwards SharedHandle conversion to the 1.2 utilities.
nn::GeneralResult<hidl_handle> unvalidatedConvert(const nn::SharedHandle& handle) {
    auto halObject = V1_2::utils::unvalidatedConvert(handle);
    return halObject;
}
386
// Forwards SharedMemory conversion to the 1.0 utilities.
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
    auto halObject = V1_0::utils::unvalidatedConvert(memory);
    return halObject;
}
390
// Forwards Request::Argument conversion to the 1.0 utilities.
nn::GeneralResult<V1_0::RequestArgument> unvalidatedConvert(const nn::Request::Argument& argument) {
    auto halObject = V1_0::utils::unvalidatedConvert(argument);
    return halObject;
}
394
// Forwards Operand::ExtraParams conversion to the 1.2 utilities.
nn::GeneralResult<V1_2::Operand::ExtraParams> unvalidatedConvert(
        const nn::Operand::ExtraParams& extraParams) {
    auto halObject = V1_2::utils::unvalidatedConvert(extraParams);
    return halObject;
}
399
// Forwards Model::ExtensionNameAndPrefix conversion to the 1.2 utilities.
nn::GeneralResult<V1_2::Model::ExtensionNameAndPrefix> unvalidatedConvert(
        const nn::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
    auto halObject = V1_2::utils::unvalidatedConvert(extensionNameAndPrefix);
    return halObject;
}
404
// Type produced by unvalidatedConvert() for an argument of type `Input`: the decayed type of
// `.value()` on the call's result.
template <typename Input>
using UnvalidatedConvertOutput =
        typename std::decay<decltype(unvalidatedConvert(std::declval<Input>()).value())>::type;
408
409 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)410 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
411 const std::vector<Type>& arguments) {
412 hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
413 for (size_t i = 0; i < arguments.size(); ++i) {
414 halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
415 }
416 return halObject;
417 }
418
// Builds a Request::MemoryPool holding the hidl_memory converted from `memory`. A conversion
// failure is propagated to the caller.
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::SharedMemory& memory) {
    Request::MemoryPool pool;
    auto hidlMemory = NN_TRY(unvalidatedConvert(memory));
    pool.hidlMemory(std::move(hidlMemory));
    return pool;
}
424
// Builds a Request::MemoryPool holding the underlying integral value of the memory domain token.
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::Request::MemoryDomainToken& token) {
    const auto rawToken = underlyingType(token);
    Request::MemoryPool pool;
    pool.token(rawToken);
    return pool;
}
430
// A SharedBuffer cannot be turned into a Request::MemoryPool; always returns GENERAL_FAILURE.
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::SharedBuffer& /*buffer*/) {
    constexpr auto kStatus = nn::ErrorStatus::GENERAL_FAILURE;
    return NN_ERROR(kStatus) << "Unable to make memory pool from IBuffer";
}
434
435 using utils::unvalidatedConvert;
436
437 template <typename Type>
validatedConvert(const Type & canonical)438 nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
439 NN_TRY(compliantVersion(canonical));
440 return unvalidatedConvert(canonical);
441 }
442
443 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)444 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
445 const std::vector<Type>& arguments) {
446 hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
447 for (size_t i = 0; i < arguments.size(); ++i) {
448 halObject[i] = NN_TRY(validatedConvert(arguments[i]));
449 }
450 return halObject;
451 }
452
453 } // anonymous namespace
454
// Maps a canonical OperandType onto the HAL 1.3 OperandType with a static_cast; the value is
// not range-checked here.
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    const auto halType = static_cast<OperandType>(operandType);
    return halType;
}
458
// Maps a canonical OperationType onto the HAL 1.3 OperationType with a static_cast; the value
// is not range-checked here.
nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    const auto halType = static_cast<OperationType>(operationType);
    return halType;
}
462
// Maps a canonical Priority onto the HAL 1.3 Priority with a static_cast; the value is not
// range-checked here.
nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority) {
    const auto halPriority = static_cast<Priority>(priority);
    return halPriority;
}
466
// Converts canonical Capabilities to HAL 1.3 Capabilities.
//
// Operand performance entries whose operand type fails compliantVersion() are silently left out
// of the HAL table; the remaining entries and the other performance fields are converted.
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities) {
    const auto& canonicalEntries = capabilities.operandPerformance.asVector();

    // Keep only the entries whose operand type passes the compliance check.
    std::vector<nn::Capabilities::OperandPerformance> compliantEntries;
    compliantEntries.reserve(canonicalEntries.size());
    for (const nn::Capabilities::OperandPerformance& entry : canonicalEntries) {
        if (compliantVersion(entry.type).has_value()) {
            compliantEntries.push_back(entry);
        }
    }

    // Converted in field order so that the first failing field determines the returned error.
    auto relaxedScalar =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
    auto relaxedTensor =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
    auto halEntries = NN_TRY(unvalidatedConvert(compliantEntries));
    auto ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance));
    auto whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance));

    return Capabilities{
            .relaxedFloat32toFloat16PerformanceScalar = std::move(relaxedScalar),
            .relaxedFloat32toFloat16PerformanceTensor = std::move(relaxedTensor),
            .operandPerformance = std::move(halEntries),
            .ifPerformance = std::move(ifPerformance),
            .whilePerformance = std::move(whilePerformance),
    };
}
487
// Converts one canonical OperandPerformance entry by converting its `type` and then its `info`.
nn::GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
        const nn::Capabilities::OperandPerformance& operandPerformance) {
    auto type = NN_TRY(unvalidatedConvert(operandPerformance.type));
    auto info = NN_TRY(unvalidatedConvert(operandPerformance.info));
    return Capabilities::OperandPerformance{
            .type = std::move(type),
            .info = std::move(info),
    };
}
495
// Converts a canonical Operation: the operation type is converted, while the input and output
// index lists are copied as-is.
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
    auto type = NN_TRY(unvalidatedConvert(operation.type));
    return Operation{
            .type = std::move(type),
            .inputs = operation.inputs,
            .outputs = operation.outputs,
    };
}
503
// Converts a canonical Operand::LifeTime to the HAL 1.3 OperandLifeTime. POINTER is rejected
// with INVALID_ARGUMENT; every other value is mapped with a static_cast.
nn::GeneralResult<OperandLifeTime> unvalidatedConvert(
        const nn::Operand::LifeTime& operandLifeTime) {
    const bool isPointerBased = (operandLifeTime == nn::Operand::LifeTime::POINTER);
    if (isPointerBased) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
    }
    const auto halLifeTime = static_cast<OperandLifeTime>(operandLifeTime);
    return halLifeTime;
}
512
// Converts a canonical Operand to the HAL 1.3 Operand. Type, lifetime, location and extra
// parameters go through their own conversions; dimensions, scale and zero point are copied.
// numberOfConsumers is always written as 0 by this function.
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    // Converted in field order so that the first failing field determines the returned error.
    auto type = NN_TRY(unvalidatedConvert(operand.type));
    auto lifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
    auto location = NN_TRY(unvalidatedConvert(operand.location));
    auto extraParams = NN_TRY(unvalidatedConvert(operand.extraParams));

    return Operand{
            .type = std::move(type),
            .dimensions = operand.dimensions,
            .numberOfConsumers = 0,
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = std::move(lifetime),
            .location = std::move(location),
            .extraParams = std::move(extraParams),
    };
}
525
// Converts a canonical Model to the HAL 1.3 Model. A model for which
// hal::utils::hasNoPointerData() is false is rejected with INVALID_ARGUMENT; otherwise each
// field is converted and the relaxed-computation flag is copied.
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    const bool pointerFree = hal::utils::hasNoPointerData(model);
    if (!pointerFree) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    // Converted in field order so that the first failing field determines the returned error.
    auto mainSubgraph = NN_TRY(unvalidatedConvert(model.main));
    auto referencedSubgraphs = NN_TRY(unvalidatedConvert(model.referenced));
    auto operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
    auto pools = NN_TRY(unvalidatedConvert(model.pools));
    auto extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));

    return Model{
            .main = std::move(mainSubgraph),
            .referenced = std::move(referencedSubgraphs),
            .operandValues = std::move(operandValues),
            .pools = std::move(pools),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = std::move(extensionNameToPrefix),
    };
}
541
// Converts a canonical Model::Subgraph to the HAL 1.3 Subgraph. After the operands are
// converted, each operand's numberOfConsumers is filled in from the counts computed over the
// canonical operations.
nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph) {
    auto operands = NN_TRY(unvalidatedConvert(subgraph.operands));

    // Update number of consumers.
    const auto numberOfConsumers =
            NN_TRY(hal::utils::countNumberOfConsumers(operands.size(), subgraph.operations));
    CHECK(operands.size() == numberOfConsumers.size());
    size_t index = 0;
    for (auto& operand : operands) {
        operand.numberOfConsumers = numberOfConsumers[index];
        ++index;
    }

    auto operations = NN_TRY(unvalidatedConvert(subgraph.operations));
    return Subgraph{
            .operands = std::move(operands),
            .operations = std::move(operations),
            .inputIndexes = subgraph.inputIndexes,
            .outputIndexes = subgraph.outputIndexes,
    };
}
560
// Converts a canonical BufferDesc by copying its dimensions.
nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc) {
    const auto& dimensions = bufferDesc.dimensions;
    return BufferDesc{.dimensions = dimensions};
}
564
// Converts a canonical BufferRole. The indexes are copied; the canonical `probability` field
// becomes the HAL `frequency` field.
nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole) {
    const auto& [modelIndex, ioIndex, probability] = bufferRole;
    return BufferRole{
            .modelIndex = modelIndex,
            .ioIndex = ioIndex,
            .frequency = probability,
    };
}
572
// Converts a canonical Request to the HAL 1.3 Request. A request for which
// hal::utils::hasNoPointerData() is false is rejected with INVALID_ARGUMENT; otherwise inputs,
// outputs and pools are converted in that order.
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
    const bool pointerFree = hal::utils::hasNoPointerData(request);
    if (!pointerFree) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    auto inputs = NN_TRY(unvalidatedConvert(request.inputs));
    auto outputs = NN_TRY(unvalidatedConvert(request.outputs));
    auto pools = NN_TRY(unvalidatedConvert(request.pools));
    return Request{
            .inputs = std::move(inputs),
            .outputs = std::move(outputs),
            .pools = std::move(pools),
    };
}
585
// Converts a canonical Request::MemoryPool by visiting it and dispatching the held alternative
// to the matching makeMemoryPool() overload.
nn::GeneralResult<Request::MemoryPool> unvalidatedConvert(
        const nn::Request::MemoryPool& memoryPool) {
    const auto toHalPool = [](const auto& alternative) { return makeMemoryPool(alternative); };
    return std::visit(toHalPool, memoryPool);
}
590
// Converts a canonical OptionalTimePoint to the HAL 1.3 OptionalTimePoint.
//
// An empty optional yields a default-constructed HAL object. Otherwise the time remaining until
// the canonical time point (relative to nn::Clock) is re-applied to the current
// std::chrono::steady_clock time, saturating at the steady clock's maximum time point, and the
// result is stored as nanoseconds since the steady clock's epoch. A time point earlier than a
// default-constructed nn::TimePoint is rejected with an error.
nn::GeneralResult<OptionalTimePoint> unvalidatedConvert(
        const nn::OptionalTimePoint& optionalTimePoint) {
    // Read both clocks first so the two readings are taken as close together as possible.
    const auto steadyNow = std::chrono::steady_clock::now();
    const auto bootNow = nn::Clock::now();

    OptionalTimePoint ret;
    if (!optionalTimePoint.has_value()) {
        return ret;
    }

    const auto bootDeadline = optionalTimePoint.value();
    if (bootDeadline < nn::TimePoint{}) {
        return NN_ERROR() << "Trying to cast invalid time point";
    }

    // Both bootDeadline and bootNow are guaranteed to be non-negative, so this subtraction will
    // never overflow or underflow.
    const auto remaining = bootDeadline - bootNow;

    // steadyNow is guaranteed to be non-negative, so only an overflow has to be guarded against
    // when adding the remaining time.
    constexpr auto kZeroNano = std::chrono::nanoseconds::zero();
    constexpr auto kMaxTime = std::chrono::steady_clock::time_point::max();
    std::chrono::steady_clock::time_point steadyDeadline;
    const bool wouldOverflow = remaining > kZeroNano && steadyNow > kMaxTime - remaining;
    if (wouldOverflow) {
        steadyDeadline = kMaxTime;
    } else {
        steadyDeadline = steadyNow + remaining;
    }

    const uint64_t count = makeUint64FromNanos(steadyDeadline.time_since_epoch());
    ret.nanosecondsSinceEpoch(count);
    return ret;
}
624
// Converts a canonical OptionalDuration to the HAL 1.3 OptionalTimeoutDuration. An empty
// optional yields a default-constructed HAL object; otherwise the duration's raw count is stored
// as the nanoseconds value.
nn::GeneralResult<OptionalTimeoutDuration> unvalidatedConvert(
        const nn::OptionalDuration& optionalTimeoutDuration) {
    OptionalTimeoutDuration ret;
    if (!optionalTimeoutDuration.has_value()) {
        return ret;
    }
    // NOTE(review): the count is stored without clamping; if the duration's rep is signed, a
    // negative count would not be preserved -- confirm.
    const auto count = optionalTimeoutDuration.value().count();
    ret.nanoseconds(count);
    return ret;
}
634
// Converts a canonical ErrorStatus to the HAL 1.3 ErrorStatus. Each enumerator listed below is
// mapped with a static_cast; any other value is reported to the HAL as GENERAL_FAILURE (as a
// successful result, not as a conversion error).
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus) {
    using CanonicalStatus = nn::ErrorStatus;
    switch (errorStatus) {
        case CanonicalStatus::NONE:
        case CanonicalStatus::DEVICE_UNAVAILABLE:
        case CanonicalStatus::GENERAL_FAILURE:
        case CanonicalStatus::OUTPUT_INSUFFICIENT_SIZE:
        case CanonicalStatus::INVALID_ARGUMENT:
        case CanonicalStatus::MISSED_DEADLINE_TRANSIENT:
        case CanonicalStatus::MISSED_DEADLINE_PERSISTENT:
        case CanonicalStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case CanonicalStatus::RESOURCE_EXHAUSTED_PERSISTENT: {
            const auto halStatus = static_cast<ErrorStatus>(errorStatus);
            return halStatus;
        }
        default:
            return ErrorStatus::GENERAL_FAILURE;
    }
}
651
// Converts a canonical Priority via validatedConvert().
nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
    auto halObject = validatedConvert(priority);
    return halObject;
}
655
// Converts canonical Capabilities via validatedConvert().
nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
    auto halObject = validatedConvert(capabilities);
    return halObject;
}
659
// Converts a canonical Model via validatedConvert().
nn::GeneralResult<Model> convert(const nn::Model& model) {
    auto halObject = validatedConvert(model);
    return halObject;
}
663
// Converts a canonical BufferDesc via validatedConvert().
nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
    auto halObject = validatedConvert(bufferDesc);
    return halObject;
}
667
// Converts a canonical Request via validatedConvert().
nn::GeneralResult<Request> convert(const nn::Request& request) {
    auto halObject = validatedConvert(request);
    return halObject;
}
671
// Converts a canonical OptionalTimePoint via validatedConvert().
nn::GeneralResult<OptionalTimePoint> convert(const nn::OptionalTimePoint& optionalTimePoint) {
    auto halObject = validatedConvert(optionalTimePoint);
    return halObject;
}
675
// Converts a canonical OptionalDuration via validatedConvert().
nn::GeneralResult<OptionalTimeoutDuration> convert(
        const nn::OptionalDuration& optionalTimeoutDuration) {
    auto halObject = validatedConvert(optionalTimeoutDuration);
    return halObject;
}
680
// Converts a canonical ErrorStatus via validatedConvert().
nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus) {
    auto halObject = validatedConvert(errorStatus);
    return halObject;
}
684
// Converts a canonical SharedHandle to a hidl_handle via validatedConvert().
nn::GeneralResult<hidl_handle> convert(const nn::SharedHandle& handle) {
    auto halObject = validatedConvert(handle);
    return halObject;
}
688
// Converts a canonical SharedMemory to a hidl_memory via validatedConvert().
nn::GeneralResult<hidl_memory> convert(const nn::SharedMemory& memory) {
    auto halObject = validatedConvert(memory);
    return halObject;
}
692
convert(const std::vector<nn::BufferRole> & bufferRoles)693 nn::GeneralResult<hidl_vec<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles) {
694 return validatedConvert(bufferRoles);
695 }
696
// Forwards DeviceStatus conversion to the 1.2 utilities.
nn::GeneralResult<V1_0::DeviceStatus> convert(const nn::DeviceStatus& deviceStatus) {
    auto halObject = V1_2::utils::convert(deviceStatus);
    return halObject;
}
700
// Forwards ExecutionPreference conversion to the 1.2 utilities.
nn::GeneralResult<V1_1::ExecutionPreference> convert(
        const nn::ExecutionPreference& executionPreference) {
    auto halObject = V1_2::utils::convert(executionPreference);
    return halObject;
}
705
convert(const std::vector<nn::Extension> & extensions)706 nn::GeneralResult<hidl_vec<V1_2::Extension>> convert(const std::vector<nn::Extension>& extensions) {
707 return V1_2::utils::convert(extensions);
708 }
709
convert(const std::vector<nn::SharedHandle> & handles)710 nn::GeneralResult<hidl_vec<hidl_handle>> convert(const std::vector<nn::SharedHandle>& handles) {
711 return V1_2::utils::convert(handles);
712 }
713
convert(const std::vector<nn::OutputShape> & outputShapes)714 nn::GeneralResult<hidl_vec<V1_2::OutputShape>> convert(
715 const std::vector<nn::OutputShape>& outputShapes) {
716 return V1_2::utils::convert(outputShapes);
717 }
718
// Forwards DeviceType conversion to the 1.2 utilities.
nn::GeneralResult<V1_2::DeviceType> convert(const nn::DeviceType& deviceType) {
    auto halObject = V1_2::utils::convert(deviceType);
    return halObject;
}
722
// Forwards MeasureTiming conversion to the 1.2 utilities.
nn::GeneralResult<V1_2::MeasureTiming> convert(const nn::MeasureTiming& measureTiming) {
    auto halObject = V1_2::utils::convert(measureTiming);
    return halObject;
}
726
// Forwards Timing conversion to the 1.2 utilities.
nn::GeneralResult<V1_2::Timing> convert(const nn::Timing& timing) {
    auto halObject = V1_2::utils::convert(timing);
    return halObject;
}
730
731 } // namespace android::hardware::neuralnetworks::V1_3::utils
732