1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Conversions.h"
18 
19 #include <android-base/logging.h>
20 #include <android/hardware/neuralnetworks/1.3/types.h>
21 #include <nnapi/OperandTypes.h>
22 #include <nnapi/OperationTypes.h>
23 #include <nnapi/Result.h>
24 #include <nnapi/SharedMemory.h>
25 #include <nnapi/TypeUtils.h>
26 #include <nnapi/Types.h>
27 #include <nnapi/Validation.h>
28 #include <nnapi/hal/1.0/Conversions.h>
29 #include <nnapi/hal/1.2/Conversions.h>
30 #include <nnapi/hal/CommonUtils.h>
31 #include <nnapi/hal/HandleError.h>
32 
33 #include <algorithm>
34 #include <chrono>
35 #include <functional>
36 #include <iterator>
37 #include <limits>
38 #include <type_traits>
39 #include <utility>
40 
41 #include "Utils.h"
42 
43 namespace {
44 
std::chrono::nanoseconds makeNanosFromUint64(uint64_t nanoseconds) {
    // Saturating conversion from an unsigned 64-bit count: anything that does not fit in
    // std::chrono::nanoseconds becomes nanoseconds::max() instead of wrapping around.
    using Rep = std::chrono::nanoseconds::rep;
    using Wide = std::common_type_t<Rep, uint64_t>;
    const Wide limit = static_cast<Wide>(std::chrono::nanoseconds::max().count());
    const Wide input = static_cast<Wide>(nanoseconds);
    const Wide clamped = input < limit ? input : limit;
    return std::chrono::nanoseconds{static_cast<Rep>(clamped)};
}
51 
uint64_t makeUint64FromNanos(std::chrono::nanoseconds nanoseconds) {
    // Saturating conversion to an unsigned 64-bit count: negative durations become 0.
    const auto ticks = nanoseconds.count();
    if (ticks < 0) {
        return 0;
    }
    using Wide = std::common_type_t<std::chrono::nanoseconds::rep, uint64_t>;
    const Wide value = static_cast<Wide>(ticks);
    // Upper clamp is defensive: with a 64-bit signed rep a non-negative count always fits.
    const Wide limit = std::numeric_limits<uint64_t>::max();
    return static_cast<uint64_t>(value < limit ? value : limit);
}
61 
template <typename Type>
constexpr std::underlying_type_t<Type> underlyingType(Type value) {
    // Exposes the raw integral value of an enumeration, e.g. so an unrecognized enumerator can be
    // streamed into an error message.
    using Underlying = std::underlying_type_t<Type>;
    const Underlying raw = static_cast<Underlying>(value);
    return raw;
}
66 
67 }  // namespace
68 
69 namespace android::nn {
70 namespace {
71 
72 using hardware::hidl_vec;
73 
// Canonical value type produced by the unvalidatedConvert overload selected for `Input`: the
// decayed type of `.value()` on that overload's result. Used to name the element type of the
// hidl_vec conversions below.
template <typename Input>
using UnvalidatedConvertOutput =
        std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
77 
78 template <typename Type>
unvalidatedConvert(const hidl_vec<Type> & arguments)79 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
80         const hidl_vec<Type>& arguments) {
81     std::vector<UnvalidatedConvertOutput<Type>> canonical;
82     canonical.reserve(arguments.size());
83     for (const auto& argument : arguments) {
84         canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
85     }
86     return canonical;
87 }
88 
89 template <typename Type>
validatedConvert(const Type & halObject)90 GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
91     auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
92     NN_TRY(hal::V1_3::utils::compliantVersion(canonical));
93     return canonical;
94 }
95 
96 template <typename Type>
validatedConvert(const hidl_vec<Type> & arguments)97 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
98         const hidl_vec<Type>& arguments) {
99     std::vector<UnvalidatedConvertOutput<Type>> canonical;
100     canonical.reserve(arguments.size());
101     for (const auto& argument : arguments) {
102         canonical.push_back(NN_TRY(validatedConvert(argument)));
103     }
104     return canonical;
105 }
106 
107 }  // anonymous namespace
108 
unvalidatedConvert(const hal::V1_3::OperandType & operandType)109 GeneralResult<OperandType> unvalidatedConvert(const hal::V1_3::OperandType& operandType) {
110     return static_cast<OperandType>(operandType);
111 }
112 
unvalidatedConvert(const hal::V1_3::OperationType & operationType)113 GeneralResult<OperationType> unvalidatedConvert(const hal::V1_3::OperationType& operationType) {
114     return static_cast<OperationType>(operationType);
115 }
116 
unvalidatedConvert(const hal::V1_3::Priority & priority)117 GeneralResult<Priority> unvalidatedConvert(const hal::V1_3::Priority& priority) {
118     return static_cast<Priority>(priority);
119 }
120 
unvalidatedConvert(const hal::V1_3::Capabilities & capabilities)121 GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_3::Capabilities& capabilities) {
122     const bool validOperandTypes = std::all_of(
123             capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
124             [](const hal::V1_3::Capabilities::OperandPerformance& operandPerformance) {
125                 return validatedConvert(operandPerformance.type).has_value();
126             });
127     if (!validOperandTypes) {
128         return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
129                << "Invalid OperandType when unvalidatedConverting OperandPerformance in "
130                   "Capabilities";
131     }
132 
133     auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
134     auto table = NN_TRY(hal::utils::makeGeneralFailure(
135             Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)),
136             nn::ErrorStatus::GENERAL_FAILURE));
137 
138     return Capabilities{
139             .relaxedFloat32toFloat16PerformanceScalar = NN_TRY(
140                     unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar)),
141             .relaxedFloat32toFloat16PerformanceTensor = NN_TRY(
142                     unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor)),
143             .operandPerformance = std::move(table),
144             .ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance)),
145             .whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance)),
146     };
147 }
148 
unvalidatedConvert(const hal::V1_3::Capabilities::OperandPerformance & operandPerformance)149 GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
150         const hal::V1_3::Capabilities::OperandPerformance& operandPerformance) {
151     return Capabilities::OperandPerformance{
152             .type = NN_TRY(unvalidatedConvert(operandPerformance.type)),
153             .info = NN_TRY(unvalidatedConvert(operandPerformance.info)),
154     };
155 }
156 
unvalidatedConvert(const hal::V1_3::Operation & operation)157 GeneralResult<Operation> unvalidatedConvert(const hal::V1_3::Operation& operation) {
158     return Operation{
159             .type = NN_TRY(unvalidatedConvert(operation.type)),
160             .inputs = operation.inputs,
161             .outputs = operation.outputs,
162     };
163 }
164 
unvalidatedConvert(const hal::V1_3::OperandLifeTime & operandLifeTime)165 GeneralResult<Operand::LifeTime> unvalidatedConvert(
166         const hal::V1_3::OperandLifeTime& operandLifeTime) {
167     return static_cast<Operand::LifeTime>(operandLifeTime);
168 }
169 
unvalidatedConvert(const hal::V1_3::Operand & operand)170 GeneralResult<Operand> unvalidatedConvert(const hal::V1_3::Operand& operand) {
171     return Operand{
172             .type = NN_TRY(unvalidatedConvert(operand.type)),
173             .dimensions = operand.dimensions,
174             .scale = operand.scale,
175             .zeroPoint = operand.zeroPoint,
176             .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
177             .location = NN_TRY(unvalidatedConvert(operand.location)),
178             .extraParams = NN_TRY(unvalidatedConvert(operand.extraParams)),
179     };
180 }
181 
unvalidatedConvert(const hal::V1_3::Model & model)182 GeneralResult<Model> unvalidatedConvert(const hal::V1_3::Model& model) {
183     return Model{
184             .main = NN_TRY(unvalidatedConvert(model.main)),
185             .referenced = NN_TRY(unvalidatedConvert(model.referenced)),
186             .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
187             .pools = NN_TRY(unvalidatedConvert(model.pools)),
188             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
189             .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
190     };
191 }
192 
unvalidatedConvert(const hal::V1_3::Subgraph & subgraph)193 GeneralResult<Model::Subgraph> unvalidatedConvert(const hal::V1_3::Subgraph& subgraph) {
194     auto operations = NN_TRY(unvalidatedConvert(subgraph.operations));
195 
196     // Verify number of consumers.
197     const auto numberOfConsumers =
198             NN_TRY(hal::utils::countNumberOfConsumers(subgraph.operands.size(), operations));
199     CHECK(subgraph.operands.size() == numberOfConsumers.size());
200     for (size_t i = 0; i < subgraph.operands.size(); ++i) {
201         if (subgraph.operands[i].numberOfConsumers != numberOfConsumers[i]) {
202             return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
203                    << "Invalid numberOfConsumers for operand " << i << ", expected "
204                    << numberOfConsumers[i] << " but found "
205                    << subgraph.operands[i].numberOfConsumers;
206         }
207     }
208 
209     return Model::Subgraph{
210             .operands = NN_TRY(unvalidatedConvert(subgraph.operands)),
211             .operations = std::move(operations),
212             .inputIndexes = subgraph.inputIndexes,
213             .outputIndexes = subgraph.outputIndexes,
214     };
215 }
216 
unvalidatedConvert(const hal::V1_3::BufferDesc & bufferDesc)217 GeneralResult<BufferDesc> unvalidatedConvert(const hal::V1_3::BufferDesc& bufferDesc) {
218     return BufferDesc{.dimensions = bufferDesc.dimensions};
219 }
220 
unvalidatedConvert(const hal::V1_3::BufferRole & bufferRole)221 GeneralResult<BufferRole> unvalidatedConvert(const hal::V1_3::BufferRole& bufferRole) {
222     return BufferRole{
223             .modelIndex = bufferRole.modelIndex,
224             .ioIndex = bufferRole.ioIndex,
225             .probability = bufferRole.frequency,
226     };
227 }
228 
unvalidatedConvert(const hal::V1_3::Request & request)229 GeneralResult<Request> unvalidatedConvert(const hal::V1_3::Request& request) {
230     return Request{
231             .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
232             .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
233             .pools = NN_TRY(unvalidatedConvert(request.pools)),
234     };
235 }
236 
unvalidatedConvert(const hal::V1_3::Request::MemoryPool & memoryPool)237 GeneralResult<Request::MemoryPool> unvalidatedConvert(
238         const hal::V1_3::Request::MemoryPool& memoryPool) {
239     using Discriminator = hal::V1_3::Request::MemoryPool::hidl_discriminator;
240     switch (memoryPool.getDiscriminator()) {
241         case Discriminator::hidlMemory:
242             return hal::utils::createSharedMemoryFromHidlMemory(memoryPool.hidlMemory());
243         case Discriminator::token:
244             return static_cast<Request::MemoryDomainToken>(memoryPool.token());
245     }
246     return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
247            << "Invalid Request::MemoryPool discriminator "
248            << underlyingType(memoryPool.getDiscriminator());
249 }
250 
// Converts a HAL OptionalTimePoint into a canonical OptionalTimePoint.
//
// The HAL value is interpreted as nanoseconds since the std::chrono::steady_clock epoch, while
// the canonical value is a point on nn::Clock. Because the two clocks need not share an epoch, the
// deadline is re-based: the time remaining until the deadline is measured against
// steady_clock::now() and then added to Clock::now().
//
// Returns:
//   - a value-initialized (empty) OptionalTimePoint for the `none` alternative;
//   - the re-based time point, saturated at nn::TimePoint::max() on overflow and clamped to be no
//     earlier than nn::TimePoint{} (the Clock epoch);
//   - a GENERAL_FAILURE error for an unrecognized discriminator.
//
// NOTE(review): the two now() samples are taken one after the other, so the re-based deadline can
// be off by the (small) delay between them.
GeneralResult<OptionalTimePoint> unvalidatedConvert(
        const hal::V1_3::OptionalTimePoint& optionalTimePoint) {
    using Discriminator = hal::V1_3::OptionalTimePoint::hidl_discriminator;
    switch (optionalTimePoint.getDiscriminator()) {
        case Discriminator::none:
            return {};
        case Discriminator::nanosecondsSinceEpoch: {
            const auto currentSteadyTime = std::chrono::steady_clock::now();
            const auto currentBootTime = Clock::now();

            // makeNanosFromUint64 saturates counts that do not fit in std::chrono::nanoseconds.
            const auto timeSinceEpoch =
                    makeNanosFromUint64(optionalTimePoint.nanosecondsSinceEpoch());
            const auto steadyTimePoint = std::chrono::steady_clock::time_point{timeSinceEpoch};

            // Both steadyTimePoint and currentSteadyTime are guaranteed to be non-negative, so this
            // subtraction will never overflow or underflow.
            const auto timeRemaining = steadyTimePoint - currentSteadyTime;

            // currentBootTime is guaranteed to be non-negative, so this code only protects against
            // an overflow.
            nn::TimePoint bootTimePoint;
            constexpr auto kZeroNano = std::chrono::nanoseconds::zero();
            constexpr auto kMaxTime = nn::TimePoint::max();
            if (timeRemaining > kZeroNano && currentBootTime > kMaxTime - timeRemaining) {
                bootTimePoint = kMaxTime;
            } else {
                bootTimePoint = currentBootTime + timeRemaining;
            }

            // A sufficiently negative timeRemaining can land before the Clock epoch; clamp it.
            constexpr auto kZeroTime = nn::TimePoint{};
            return std::max(bootTimePoint, kZeroTime);
        }
    }
    return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
           << "Invalid OptionalTimePoint discriminator "
           << underlyingType(optionalTimePoint.getDiscriminator());
}
288 
unvalidatedConvert(const hal::V1_3::OptionalTimeoutDuration & optionalTimeoutDuration)289 GeneralResult<OptionalDuration> unvalidatedConvert(
290         const hal::V1_3::OptionalTimeoutDuration& optionalTimeoutDuration) {
291     using Discriminator = hal::V1_3::OptionalTimeoutDuration::hidl_discriminator;
292     switch (optionalTimeoutDuration.getDiscriminator()) {
293         case Discriminator::none:
294             return {};
295         case Discriminator::nanoseconds:
296             return Duration(optionalTimeoutDuration.nanoseconds());
297     }
298     return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
299            << "Invalid OptionalTimeoutDuration discriminator "
300            << underlyingType(optionalTimeoutDuration.getDiscriminator());
301 }
302 
unvalidatedConvert(const hal::V1_3::ErrorStatus & status)303 GeneralResult<ErrorStatus> unvalidatedConvert(const hal::V1_3::ErrorStatus& status) {
304     switch (status) {
305         case hal::V1_3::ErrorStatus::NONE:
306         case hal::V1_3::ErrorStatus::DEVICE_UNAVAILABLE:
307         case hal::V1_3::ErrorStatus::GENERAL_FAILURE:
308         case hal::V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
309         case hal::V1_3::ErrorStatus::INVALID_ARGUMENT:
310         case hal::V1_3::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
311         case hal::V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
312         case hal::V1_3::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
313         case hal::V1_3::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
314             return static_cast<ErrorStatus>(status);
315     }
316     return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
317            << "Invalid ErrorStatus " << underlyingType(status);
318 }
319 
// Public HAL 1.3 -> canonical entry points. Each one forwards to the validatedConvert helper in
// this file's anonymous namespace, which performs the unvalidated conversion and then checks the
// canonical result with hal::V1_3::utils::compliantVersion, returning an error if either step
// fails.
GeneralResult<Priority> convert(const hal::V1_3::Priority& priority) {
    return validatedConvert(priority);
}

GeneralResult<Capabilities> convert(const hal::V1_3::Capabilities& capabilities) {
    return validatedConvert(capabilities);
}

GeneralResult<Model> convert(const hal::V1_3::Model& model) {
    return validatedConvert(model);
}

GeneralResult<BufferDesc> convert(const hal::V1_3::BufferDesc& bufferDesc) {
    return validatedConvert(bufferDesc);
}

GeneralResult<Request> convert(const hal::V1_3::Request& request) {
    return validatedConvert(request);
}

GeneralResult<OptionalTimePoint> convert(const hal::V1_3::OptionalTimePoint& optionalTimePoint) {
    return validatedConvert(optionalTimePoint);
}

GeneralResult<OptionalDuration> convert(
        const hal::V1_3::OptionalTimeoutDuration& optionalTimeoutDuration) {
    return validatedConvert(optionalTimeoutDuration);
}

GeneralResult<ErrorStatus> convert(const hal::V1_3::ErrorStatus& errorStatus) {
    return validatedConvert(errorStatus);
}

GeneralResult<SharedHandle> convert(const hardware::hidl_handle& handle) {
    return validatedConvert(handle);
}

// Element-wise: every buffer role is converted and validated individually.
GeneralResult<std::vector<BufferRole>> convert(
        const hardware::hidl_vec<hal::V1_3::BufferRole>& bufferRoles) {
    return validatedConvert(bufferRoles);
}
361 
362 }  // namespace android::nn
363 
364 namespace android::hardware::neuralnetworks::V1_3::utils {
365 namespace {
366 
367 using utils::unvalidatedConvert;
368 
// Forwarders for canonical types whose HAL representation is an earlier-version HAL type (see the
// V1_0/V1_2 return types). Defining them in this namespace presumably lets the unqualified
// unvalidatedConvert calls in this file (including the generic std::vector overload below) find
// them, since the V1_0/V1_2 utils namespaces are not searched by unqualified lookup from here.
nn::GeneralResult<V1_0::PerformanceInfo> unvalidatedConvert(
        const nn::Capabilities::PerformanceInfo& performanceInfo) {
    return V1_0::utils::unvalidatedConvert(performanceInfo);
}

nn::GeneralResult<V1_0::DataLocation> unvalidatedConvert(const nn::DataLocation& dataLocation) {
    return V1_0::utils::unvalidatedConvert(dataLocation);
}

nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
        const nn::Model::OperandValues& operandValues) {
    return V1_0::utils::unvalidatedConvert(operandValues);
}

nn::GeneralResult<hidl_handle> unvalidatedConvert(const nn::SharedHandle& handle) {
    return V1_2::utils::unvalidatedConvert(handle);
}

nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
    return V1_0::utils::unvalidatedConvert(memory);
}

nn::GeneralResult<V1_0::RequestArgument> unvalidatedConvert(const nn::Request::Argument& argument) {
    return V1_0::utils::unvalidatedConvert(argument);
}

nn::GeneralResult<V1_2::Operand::ExtraParams> unvalidatedConvert(
        const nn::Operand::ExtraParams& extraParams) {
    return V1_2::utils::unvalidatedConvert(extraParams);
}

nn::GeneralResult<V1_2::Model::ExtensionNameAndPrefix> unvalidatedConvert(
        const nn::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
    return V1_2::utils::unvalidatedConvert(extensionNameAndPrefix);
}
404 
// HAL value type produced by the unvalidatedConvert overload selected for canonical type
// `Input`: the decayed type of `.value()` on that overload's result. Used to name the element
// type of the std::vector -> hidl_vec conversions below.
template <typename Input>
using UnvalidatedConvertOutput =
        std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
408 
409 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)410 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
411         const std::vector<Type>& arguments) {
412     hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
413     for (size_t i = 0; i < arguments.size(); ++i) {
414         halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
415     }
416     return halObject;
417 }
418 
// One overload per alternative of nn::Request::MemoryPool; unvalidatedConvert dispatches to them
// with std::visit.

// Shared memory becomes the hidlMemory alternative of the HAL pool.
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::SharedMemory& memory) {
    auto hidlMemory = NN_TRY(unvalidatedConvert(memory));
    Request::MemoryPool pool;
    pool.hidlMemory(std::move(hidlMemory));
    return pool;
}

// A memory-domain token is stored through its raw integral value.
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::Request::MemoryDomainToken& token) {
    const auto rawToken = underlyingType(token);
    Request::MemoryPool pool;
    pool.token(rawToken);
    return pool;
}

// IBuffer-backed pools are rejected.
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::SharedBuffer& /*buffer*/) {
    return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "Unable to make memory pool from IBuffer";
}
434 
435 using utils::unvalidatedConvert;
436 
437 template <typename Type>
validatedConvert(const Type & canonical)438 nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
439     NN_TRY(compliantVersion(canonical));
440     return unvalidatedConvert(canonical);
441 }
442 
443 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)444 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
445         const std::vector<Type>& arguments) {
446     hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
447     for (size_t i = 0; i < arguments.size(); ++i) {
448         halObject[i] = NN_TRY(validatedConvert(arguments[i]));
449     }
450     return halObject;
451 }
452 
453 }  // anonymous namespace
454 
// Canonical -> HAL enum conversions: the numeric value is reinterpreted as the HAL enumerator
// without range checking (the convert() entry points run compliantVersion first).
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    const auto halValue = static_cast<OperandType>(operandType);
    return halValue;
}

nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    const auto halValue = static_cast<OperationType>(operationType);
    return halValue;
}

nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority) {
    const auto halValue = static_cast<Priority>(priority);
    return halValue;
}
466 
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities) {
    // Entries whose operand type is not compliant with this HAL version are dropped from the
    // table rather than failing the whole conversion.
    const auto& canonicalEntries = capabilities.operandPerformance.asVector();
    std::vector<nn::Capabilities::OperandPerformance> compliantEntries;
    compliantEntries.reserve(canonicalEntries.size());
    for (const auto& entry : canonicalEntries) {
        if (compliantVersion(entry.type).has_value()) {
            compliantEntries.push_back(entry);
        }
    }

    auto scalarPerformance =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
    auto tensorPerformance =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
    auto operandPerformance = NN_TRY(unvalidatedConvert(compliantEntries));
    auto ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance));
    auto whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance));

    return Capabilities{
            .relaxedFloat32toFloat16PerformanceScalar = scalarPerformance,
            .relaxedFloat32toFloat16PerformanceTensor = tensorPerformance,
            .operandPerformance = std::move(operandPerformance),
            .ifPerformance = ifPerformance,
            .whilePerformance = whilePerformance,
    };
}
487 
nn::GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
        const nn::Capabilities::OperandPerformance& operandPerformance) {
    // Field-wise conversion: operand type first, then its performance numbers.
    auto halType = NN_TRY(unvalidatedConvert(operandPerformance.type));
    auto halInfo = NN_TRY(unvalidatedConvert(operandPerformance.info));
    return Capabilities::OperandPerformance{
            .type = halType,
            .info = halInfo,
    };
}
495 
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
    // Only the opcode needs conversion; the operand index lists are copied as-is.
    auto halType = NN_TRY(unvalidatedConvert(operation.type));
    return Operation{
            .type = halType,
            .inputs = operation.inputs,
            .outputs = operation.outputs,
    };
}
503 
nn::GeneralResult<OperandLifeTime> unvalidatedConvert(
        const nn::Operand::LifeTime& operandLifeTime) {
    // POINTER lifetimes are rejected; every other value is reinterpreted numerically.
    const bool isPointer = operandLifeTime == nn::Operand::LifeTime::POINTER;
    if (!isPointer) {
        return static_cast<OperandLifeTime>(operandLifeTime);
    }
    return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
           << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
}
512 
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    // Convert the fallible fields (in declaration order), then assemble the HAL operand.
    // numberOfConsumers is left at 0 here; the subgraph conversion fills in the real counts.
    auto halType = NN_TRY(unvalidatedConvert(operand.type));
    auto halLifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
    auto halLocation = NN_TRY(unvalidatedConvert(operand.location));
    auto halExtraParams = NN_TRY(unvalidatedConvert(operand.extraParams));

    return Operand{
            .type = halType,
            .dimensions = operand.dimensions,
            .numberOfConsumers = 0,
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = halLifetime,
            .location = halLocation,
            .extraParams = std::move(halExtraParams),
    };
}
525 
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    // Models containing pointer-based memory are rejected before any conversion work is done.
    const bool hasPointerData = !hal::utils::hasNoPointerData(model);
    if (hasPointerData) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    auto mainSubgraph = NN_TRY(unvalidatedConvert(model.main));
    auto referencedSubgraphs = NN_TRY(unvalidatedConvert(model.referenced));
    auto operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
    auto pools = NN_TRY(unvalidatedConvert(model.pools));
    auto extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));

    return Model{
            .main = std::move(mainSubgraph),
            .referenced = std::move(referencedSubgraphs),
            .operandValues = std::move(operandValues),
            .pools = std::move(pools),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = std::move(extensionNameToPrefix),
    };
}
541 
nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph) {
    auto halOperands = NN_TRY(unvalidatedConvert(subgraph.operands));

    // The per-operand conversion sets numberOfConsumers to 0; derive the real counts from the
    // operations and store them on the HAL operands.
    const auto consumerCounts =
            NN_TRY(hal::utils::countNumberOfConsumers(halOperands.size(), subgraph.operations));
    CHECK(halOperands.size() == consumerCounts.size());
    for (size_t index = 0; index < halOperands.size(); ++index) {
        halOperands[index].numberOfConsumers = consumerCounts[index];
    }

    auto halOperations = NN_TRY(unvalidatedConvert(subgraph.operations));
    return Subgraph{
            .operands = std::move(halOperands),
            .operations = std::move(halOperations),
            .inputIndexes = subgraph.inputIndexes,
            .outputIndexes = subgraph.outputIndexes,
    };
}
560 
nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc) {
    // Only the dimensions are carried over; they are copied into the HAL descriptor.
    BufferDesc halDesc{.dimensions = bufferDesc.dimensions};
    return halDesc;
}
564 
nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole) {
    // The canonical field `probability` is stored in the HAL field `frequency`.
    const auto frequency = bufferRole.probability;
    return BufferRole{
            .modelIndex = bufferRole.modelIndex,
            .ioIndex = bufferRole.ioIndex,
            .frequency = frequency,
    };
}
572 
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
    // Requests containing pointer-based memory are rejected before any conversion work is done.
    const bool hasPointerData = !hal::utils::hasNoPointerData(request);
    if (hasPointerData) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    auto inputs = NN_TRY(unvalidatedConvert(request.inputs));
    auto outputs = NN_TRY(unvalidatedConvert(request.outputs));
    auto pools = NN_TRY(unvalidatedConvert(request.pools));
    return Request{
            .inputs = std::move(inputs),
            .outputs = std::move(outputs),
            .pools = std::move(pools),
    };
}
585 
nn::GeneralResult<Request::MemoryPool> unvalidatedConvert(
        const nn::Request::MemoryPool& memoryPool) {
    // Dispatch on the active alternative to the matching makeMemoryPool overload.
    const auto toHalPool = [](const auto& alternative) { return makeMemoryPool(alternative); };
    return std::visit(toHalPool, memoryPool);
}
590 
// Converts a canonical OptionalTimePoint into the HAL representation.
//
// The canonical value is a point on nn::Clock; the HAL value is nanoseconds since the
// std::chrono::steady_clock epoch. The deadline is re-based by measuring the time remaining on
// nn::Clock and adding it to steady_clock::now(), saturating at steady_clock::time_point::max() on
// overflow. makeUint64FromNanos then maps any negative result to 0.
//
// Returns a default-constructed OptionalTimePoint when the input is empty (presumably the `none`
// alternative), and an error when the input lies before nn::TimePoint{}.
//
// NOTE(review): both clocks are sampled even when the input is empty, although the samples are
// only used inside the has_value() branch.
nn::GeneralResult<OptionalTimePoint> unvalidatedConvert(
        const nn::OptionalTimePoint& optionalTimePoint) {
    const auto currentSteadyTime = std::chrono::steady_clock::now();
    const auto currentBootTime = nn::Clock::now();

    OptionalTimePoint ret;
    if (optionalTimePoint.has_value()) {
        const auto bootTimePoint = optionalTimePoint.value();

        // Time points before the Clock epoch cannot be represented; reject them.
        if (bootTimePoint < nn::TimePoint{}) {
            return NN_ERROR() << "Trying to cast invalid time point";
        }

        // Both bootTimePoint and currentBootTime are guaranteed to be non-negative, so this
        // subtraction will never overflow or underflow.
        const auto timeRemaining = bootTimePoint - currentBootTime;

        // currentSteadyTime is guaranteed to be non-negative, so this code only protects against an
        // overflow.
        std::chrono::steady_clock::time_point steadyTimePoint;
        constexpr auto kZeroNano = std::chrono::nanoseconds::zero();
        constexpr auto kMaxTime = std::chrono::steady_clock::time_point::max();
        if (timeRemaining > kZeroNano && currentSteadyTime > kMaxTime - timeRemaining) {
            steadyTimePoint = kMaxTime;
        } else {
            steadyTimePoint = currentSteadyTime + timeRemaining;
        }

        // A deadline already in the past can give a negative count; makeUint64FromNanos clamps it
        // to 0.
        const uint64_t count = makeUint64FromNanos(steadyTimePoint.time_since_epoch());
        ret.nanosecondsSinceEpoch(count);
    }
    return ret;
}
624 
// Converts a canonical OptionalDuration into a HAL OptionalTimeoutDuration.
//
// An empty input yields a default-constructed OptionalTimeoutDuration; otherwise the duration's
// tick count (assumed to be nanoseconds, as in the HAL field name) is stored in the HAL's
// unsigned `nanoseconds` field, with negative durations clamped to zero.
nn::GeneralResult<OptionalTimeoutDuration> unvalidatedConvert(
        const nn::OptionalDuration& optionalTimeoutDuration) {
    OptionalTimeoutDuration ret;
    if (optionalTimeoutDuration.has_value()) {
        // The HAL field is unsigned. If the canonical representation is signed (as
        // std::chrono::nanoseconds is), a negative duration would otherwise wrap around to an
        // enormous timeout; saturate it to zero instead, as makeUint64FromNanos does for time
        // points. For an unsigned representation this clamp is a no-op.
        using Rep = nn::OptionalDuration::value_type::rep;
        const Rep count = std::max<Rep>(optionalTimeoutDuration.value().count(), Rep{0});
        ret.nanoseconds(static_cast<uint64_t>(count));
    }
    return ret;
}
634 
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus) {
    switch (errorStatus) {
        case nn::ErrorStatus::NONE:
        case nn::ErrorStatus::DEVICE_UNAVAILABLE:
        case nn::ErrorStatus::GENERAL_FAILURE:
        case nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case nn::ErrorStatus::INVALID_ARGUMENT:
        case nn::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
        case nn::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
            // Listed statuses are reinterpreted numerically below.
            break;
        default:
            // Any other canonical status is reported to the HAL as GENERAL_FAILURE rather than
            // failing the conversion.
            return ErrorStatus::GENERAL_FAILURE;
    }
    return static_cast<ErrorStatus>(errorStatus);
}
651 
// Public canonical -> HAL 1.3 entry points. Each one forwards to the validatedConvert helper in
// this namespace's anonymous namespace, which first checks the canonical object with
// compliantVersion and only then performs the unvalidated conversion; an error from either step
// is returned unchanged.
nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
    return validatedConvert(priority);
}

nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
    return validatedConvert(capabilities);
}

nn::GeneralResult<Model> convert(const nn::Model& model) {
    return validatedConvert(model);
}

nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
    return validatedConvert(bufferDesc);
}

nn::GeneralResult<Request> convert(const nn::Request& request) {
    return validatedConvert(request);
}

nn::GeneralResult<OptionalTimePoint> convert(const nn::OptionalTimePoint& optionalTimePoint) {
    return validatedConvert(optionalTimePoint);
}

nn::GeneralResult<OptionalTimeoutDuration> convert(
        const nn::OptionalDuration& optionalTimeoutDuration) {
    return validatedConvert(optionalTimeoutDuration);
}

nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus) {
    return validatedConvert(errorStatus);
}

nn::GeneralResult<hidl_handle> convert(const nn::SharedHandle& handle) {
    return validatedConvert(handle);
}

nn::GeneralResult<hidl_memory> convert(const nn::SharedMemory& memory) {
    return validatedConvert(memory);
}

// Element-wise: every buffer role is validated and converted individually.
nn::GeneralResult<hidl_vec<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles) {
    return validatedConvert(bufferRoles);
}
696 
// These canonical types map onto HAL types from earlier versions (V1_0 / V1_1 / V1_2, per the
// return types), so conversion is delegated entirely to V1_2::utils::convert.
nn::GeneralResult<V1_0::DeviceStatus> convert(const nn::DeviceStatus& deviceStatus) {
    return V1_2::utils::convert(deviceStatus);
}

nn::GeneralResult<V1_1::ExecutionPreference> convert(
        const nn::ExecutionPreference& executionPreference) {
    return V1_2::utils::convert(executionPreference);
}

nn::GeneralResult<hidl_vec<V1_2::Extension>> convert(const std::vector<nn::Extension>& extensions) {
    return V1_2::utils::convert(extensions);
}

nn::GeneralResult<hidl_vec<hidl_handle>> convert(const std::vector<nn::SharedHandle>& handles) {
    return V1_2::utils::convert(handles);
}

nn::GeneralResult<hidl_vec<V1_2::OutputShape>> convert(
        const std::vector<nn::OutputShape>& outputShapes) {
    return V1_2::utils::convert(outputShapes);
}

nn::GeneralResult<V1_2::DeviceType> convert(const nn::DeviceType& deviceType) {
    return V1_2::utils::convert(deviceType);
}

nn::GeneralResult<V1_2::MeasureTiming> convert(const nn::MeasureTiming& measureTiming) {
    return V1_2::utils::convert(measureTiming);
}

nn::GeneralResult<V1_2::Timing> convert(const nn::Timing& timing) {
    return V1_2::utils::convert(timing);
}
730 
731 }  // namespace android::hardware::neuralnetworks::V1_3::utils
732