1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Conversions.h"
18 
19 #include <android-base/logging.h>
20 #include <android/hardware/neuralnetworks/1.2/types.h>
21 #include <nnapi/OperandTypes.h>
22 #include <nnapi/OperationTypes.h>
23 #include <nnapi/Result.h>
24 #include <nnapi/SharedMemory.h>
25 #include <nnapi/TypeUtils.h>
26 #include <nnapi/Types.h>
27 #include <nnapi/Validation.h>
28 #include <nnapi/hal/1.0/Conversions.h>
29 #include <nnapi/hal/1.1/Conversions.h>
30 #include <nnapi/hal/CommonUtils.h>
31 #include <nnapi/hal/HandleError.h>
32 
33 #include <algorithm>
34 #include <functional>
35 #include <iterator>
36 #include <memory>
37 #include <type_traits>
38 #include <utility>
39 
40 #include "Utils.h"
41 
namespace {

// Returns the value of an enumeration as its underlying integral type.
template <typename Enum>
constexpr auto underlyingType(Enum value) -> std::underlying_type_t<Enum> {
    using Underlying = std::underlying_type_t<Enum>;
    return static_cast<Underlying>(value);
}

// Unsigned microsecond duration; used below to interpret the uint64_t fields of the HAL Timing.
using HalDuration = std::chrono::duration<uint64_t, std::micro>;

}  // namespace
52 
53 namespace android::nn {
54 namespace {
55 
56 using hardware::hidl_handle;
57 using hardware::hidl_vec;
58 
59 template <typename Input>
60 using UnvalidatedConvertOutput =
61         std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
62 
63 template <typename Type>
unvalidatedConvert(const hidl_vec<Type> & arguments)64 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
65         const hidl_vec<Type>& arguments) {
66     std::vector<UnvalidatedConvertOutput<Type>> canonical;
67     canonical.reserve(arguments.size());
68     for (const auto& argument : arguments) {
69         canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
70     }
71     return canonical;
72 }
73 
74 template <typename Type>
validatedConvert(const Type & halObject)75 GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
76     auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
77     NN_TRY(hal::V1_2::utils::compliantVersion(canonical));
78     return canonical;
79 }
80 
81 template <typename Type>
validatedConvert(const hidl_vec<Type> & arguments)82 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
83         const hidl_vec<Type>& arguments) {
84     std::vector<UnvalidatedConvertOutput<Type>> canonical;
85     canonical.reserve(arguments.size());
86     for (const auto& argument : arguments) {
87         canonical.push_back(NN_TRY(validatedConvert(argument)));
88     }
89     return canonical;
90 }
91 
92 }  // anonymous namespace
93 
unvalidatedConvert(const hal::V1_2::OperandType & operandType)94 GeneralResult<OperandType> unvalidatedConvert(const hal::V1_2::OperandType& operandType) {
95     return static_cast<OperandType>(operandType);
96 }
97 
unvalidatedConvert(const hal::V1_2::OperationType & operationType)98 GeneralResult<OperationType> unvalidatedConvert(const hal::V1_2::OperationType& operationType) {
99     return static_cast<OperationType>(operationType);
100 }
101 
unvalidatedConvert(const hal::V1_2::DeviceType & deviceType)102 GeneralResult<DeviceType> unvalidatedConvert(const hal::V1_2::DeviceType& deviceType) {
103     return static_cast<DeviceType>(deviceType);
104 }
105 
unvalidatedConvert(const hal::V1_2::Capabilities & capabilities)106 GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_2::Capabilities& capabilities) {
107     const bool validOperandTypes = std::all_of(
108             capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
109             [](const hal::V1_2::Capabilities::OperandPerformance& operandPerformance) {
110                 return validatedConvert(operandPerformance.type).has_value();
111             });
112     if (!validOperandTypes) {
113         return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
114                << "Invalid OperandType when converting OperandPerformance in Capabilities";
115     }
116 
117     const auto relaxedFloat32toFloat16PerformanceScalar =
118             NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
119     const auto relaxedFloat32toFloat16PerformanceTensor =
120             NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
121     auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
122 
123     auto table = NN_TRY(hal::utils::makeGeneralFailure(
124             Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)),
125             nn::ErrorStatus::GENERAL_FAILURE));
126 
127     return Capabilities{
128             .relaxedFloat32toFloat16PerformanceScalar = relaxedFloat32toFloat16PerformanceScalar,
129             .relaxedFloat32toFloat16PerformanceTensor = relaxedFloat32toFloat16PerformanceTensor,
130             .operandPerformance = std::move(table),
131     };
132 }
133 
unvalidatedConvert(const hal::V1_2::Capabilities::OperandPerformance & operandPerformance)134 GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
135         const hal::V1_2::Capabilities::OperandPerformance& operandPerformance) {
136     return Capabilities::OperandPerformance{
137             .type = NN_TRY(unvalidatedConvert(operandPerformance.type)),
138             .info = NN_TRY(unvalidatedConvert(operandPerformance.info)),
139     };
140 }
141 
unvalidatedConvert(const hal::V1_2::Operation & operation)142 GeneralResult<Operation> unvalidatedConvert(const hal::V1_2::Operation& operation) {
143     return Operation{
144             .type = NN_TRY(unvalidatedConvert(operation.type)),
145             .inputs = operation.inputs,
146             .outputs = operation.outputs,
147     };
148 }
149 
unvalidatedConvert(const hal::V1_2::SymmPerChannelQuantParams & symmPerChannelQuantParams)150 GeneralResult<Operand::SymmPerChannelQuantParams> unvalidatedConvert(
151         const hal::V1_2::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
152     return Operand::SymmPerChannelQuantParams{
153             .scales = symmPerChannelQuantParams.scales,
154             .channelDim = symmPerChannelQuantParams.channelDim,
155     };
156 }
157 
unvalidatedConvert(const hal::V1_2::Operand & operand)158 GeneralResult<Operand> unvalidatedConvert(const hal::V1_2::Operand& operand) {
159     return Operand{
160             .type = NN_TRY(unvalidatedConvert(operand.type)),
161             .dimensions = operand.dimensions,
162             .scale = operand.scale,
163             .zeroPoint = operand.zeroPoint,
164             .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
165             .location = NN_TRY(unvalidatedConvert(operand.location)),
166             .extraParams = NN_TRY(unvalidatedConvert(operand.extraParams)),
167     };
168 }
169 
unvalidatedConvert(const hal::V1_2::Operand::ExtraParams & extraParams)170 GeneralResult<Operand::ExtraParams> unvalidatedConvert(
171         const hal::V1_2::Operand::ExtraParams& extraParams) {
172     using Discriminator = hal::V1_2::Operand::ExtraParams::hidl_discriminator;
173     switch (extraParams.getDiscriminator()) {
174         case Discriminator::none:
175             return Operand::NoParams{};
176         case Discriminator::channelQuant:
177             return unvalidatedConvert(extraParams.channelQuant());
178         case Discriminator::extension:
179             return extraParams.extension();
180     }
181     return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
182            << "Unrecognized Operand::ExtraParams discriminator: "
183            << underlyingType(extraParams.getDiscriminator());
184 }
185 
unvalidatedConvert(const hal::V1_2::Model & model)186 GeneralResult<Model> unvalidatedConvert(const hal::V1_2::Model& model) {
187     auto operations = NN_TRY(unvalidatedConvert(model.operations));
188 
189     // Verify number of consumers.
190     const auto numberOfConsumers =
191             NN_TRY(hal::utils::countNumberOfConsumers(model.operands.size(), operations));
192     CHECK(model.operands.size() == numberOfConsumers.size());
193     for (size_t i = 0; i < model.operands.size(); ++i) {
194         if (model.operands[i].numberOfConsumers != numberOfConsumers[i]) {
195             return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
196                    << "Invalid numberOfConsumers for operand " << i << ", expected "
197                    << numberOfConsumers[i] << " but found " << model.operands[i].numberOfConsumers;
198         }
199     }
200 
201     auto main = Model::Subgraph{
202             .operands = NN_TRY(unvalidatedConvert(model.operands)),
203             .operations = std::move(operations),
204             .inputIndexes = model.inputIndexes,
205             .outputIndexes = model.outputIndexes,
206     };
207 
208     return Model{
209             .main = std::move(main),
210             .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
211             .pools = NN_TRY(unvalidatedConvert(model.pools)),
212             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
213             .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
214     };
215 }
216 
unvalidatedConvert(const hal::V1_2::Model::ExtensionNameAndPrefix & extensionNameAndPrefix)217 GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
218         const hal::V1_2::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
219     return Model::ExtensionNameAndPrefix{
220             .name = extensionNameAndPrefix.name,
221             .prefix = extensionNameAndPrefix.prefix,
222     };
223 }
224 
unvalidatedConvert(const hal::V1_2::OutputShape & outputShape)225 GeneralResult<OutputShape> unvalidatedConvert(const hal::V1_2::OutputShape& outputShape) {
226     return OutputShape{
227             .dimensions = outputShape.dimensions,
228             .isSufficient = outputShape.isSufficient,
229     };
230 }
231 
unvalidatedConvert(const hal::V1_2::MeasureTiming & measureTiming)232 GeneralResult<MeasureTiming> unvalidatedConvert(const hal::V1_2::MeasureTiming& measureTiming) {
233     return static_cast<MeasureTiming>(measureTiming);
234 }
235 
unvalidatedConvert(const hal::V1_2::Timing & timing)236 GeneralResult<Timing> unvalidatedConvert(const hal::V1_2::Timing& timing) {
237     constexpr uint64_t kMaxTiming = std::chrono::floor<HalDuration>(Duration::max()).count();
238     constexpr auto convertTiming = [](uint64_t halTiming) -> OptionalDuration {
239         constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
240         if (halTiming == kNoTiming) {
241             return {};
242         }
243         if (halTiming > kMaxTiming) {
244             return Duration::max();
245         }
246         return HalDuration{halTiming};
247     };
248     return Timing{.timeOnDevice = convertTiming(timing.timeOnDevice),
249                   .timeInDriver = convertTiming(timing.timeInDriver)};
250 }
251 
unvalidatedConvert(const hal::V1_2::Extension & extension)252 GeneralResult<Extension> unvalidatedConvert(const hal::V1_2::Extension& extension) {
253     return Extension{
254             .name = extension.name,
255             .operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes)),
256     };
257 }
258 
unvalidatedConvert(const hal::V1_2::Extension::OperandTypeInformation & operandTypeInformation)259 GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
260         const hal::V1_2::Extension::OperandTypeInformation& operandTypeInformation) {
261     return Extension::OperandTypeInformation{
262             .type = operandTypeInformation.type,
263             .isTensor = operandTypeInformation.isTensor,
264             .byteSize = operandTypeInformation.byteSize,
265     };
266 }
267 
unvalidatedConvert(const hidl_handle & hidlHandle)268 GeneralResult<SharedHandle> unvalidatedConvert(const hidl_handle& hidlHandle) {
269     if (hidlHandle.getNativeHandle() == nullptr) {
270         return nullptr;
271     }
272     auto handle = NN_TRY(hal::utils::sharedHandleFromNativeHandle(hidlHandle.getNativeHandle()));
273     return std::make_shared<const Handle>(std::move(handle));
274 }
275 
convert(const hal::V1_2::DeviceType & deviceType)276 GeneralResult<DeviceType> convert(const hal::V1_2::DeviceType& deviceType) {
277     return validatedConvert(deviceType);
278 }
279 
convert(const hal::V1_2::Capabilities & capabilities)280 GeneralResult<Capabilities> convert(const hal::V1_2::Capabilities& capabilities) {
281     return validatedConvert(capabilities);
282 }
283 
convert(const hal::V1_2::Model & model)284 GeneralResult<Model> convert(const hal::V1_2::Model& model) {
285     return validatedConvert(model);
286 }
287 
convert(const hal::V1_2::MeasureTiming & measureTiming)288 GeneralResult<MeasureTiming> convert(const hal::V1_2::MeasureTiming& measureTiming) {
289     return validatedConvert(measureTiming);
290 }
291 
convert(const hal::V1_2::Timing & timing)292 GeneralResult<Timing> convert(const hal::V1_2::Timing& timing) {
293     return validatedConvert(timing);
294 }
295 
convert(const hardware::hidl_memory & memory)296 GeneralResult<SharedMemory> convert(const hardware::hidl_memory& memory) {
297     return validatedConvert(memory);
298 }
299 
convert(const hidl_vec<hal::V1_2::Extension> & extensions)300 GeneralResult<std::vector<Extension>> convert(const hidl_vec<hal::V1_2::Extension>& extensions) {
301     return validatedConvert(extensions);
302 }
303 
convert(const hidl_vec<hidl_handle> & handles)304 GeneralResult<std::vector<SharedHandle>> convert(const hidl_vec<hidl_handle>& handles) {
305     return validatedConvert(handles);
306 }
307 
convert(const hidl_vec<hal::V1_2::OutputShape> & outputShapes)308 GeneralResult<std::vector<OutputShape>> convert(
309         const hidl_vec<hal::V1_2::OutputShape>& outputShapes) {
310     return validatedConvert(outputShapes);
311 }
312 
313 }  // namespace android::nn
314 
315 namespace android::hardware::neuralnetworks::V1_2::utils {
316 namespace {
317 
318 using utils::unvalidatedConvert;
319 
unvalidatedConvert(const nn::Operand::LifeTime & lifetime)320 nn::GeneralResult<V1_0::OperandLifeTime> unvalidatedConvert(const nn::Operand::LifeTime& lifetime) {
321     return V1_0::utils::unvalidatedConvert(lifetime);
322 }
323 
unvalidatedConvert(const nn::Capabilities::PerformanceInfo & performanceInfo)324 nn::GeneralResult<V1_0::PerformanceInfo> unvalidatedConvert(
325         const nn::Capabilities::PerformanceInfo& performanceInfo) {
326     return V1_0::utils::unvalidatedConvert(performanceInfo);
327 }
328 
unvalidatedConvert(const nn::DataLocation & location)329 nn::GeneralResult<V1_0::DataLocation> unvalidatedConvert(const nn::DataLocation& location) {
330     return V1_0::utils::unvalidatedConvert(location);
331 }
332 
unvalidatedConvert(const nn::Model::OperandValues & operandValues)333 nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
334         const nn::Model::OperandValues& operandValues) {
335     return V1_0::utils::unvalidatedConvert(operandValues);
336 }
337 
unvalidatedConvert(const nn::SharedMemory & memory)338 nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
339     return V1_0::utils::unvalidatedConvert(memory);
340 }
341 
342 template <typename Input>
343 using UnvalidatedConvertOutput =
344         std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
345 
346 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)347 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
348         const std::vector<Type>& arguments) {
349     hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
350     for (size_t i = 0; i < arguments.size(); ++i) {
351         halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
352     }
353     return halObject;
354 }
355 
makeExtraParams(nn::Operand::NoParams)356 nn::GeneralResult<Operand::ExtraParams> makeExtraParams(nn::Operand::NoParams /*noParams*/) {
357     return Operand::ExtraParams{};
358 }
359 
makeExtraParams(const nn::Operand::SymmPerChannelQuantParams & channelQuant)360 nn::GeneralResult<Operand::ExtraParams> makeExtraParams(
361         const nn::Operand::SymmPerChannelQuantParams& channelQuant) {
362     Operand::ExtraParams ret;
363     ret.channelQuant(NN_TRY(unvalidatedConvert(channelQuant)));
364     return ret;
365 }
366 
makeExtraParams(const nn::Operand::ExtensionParams & extension)367 nn::GeneralResult<Operand::ExtraParams> makeExtraParams(
368         const nn::Operand::ExtensionParams& extension) {
369     Operand::ExtraParams ret;
370     ret.extension(extension);
371     return ret;
372 }
373 
374 template <typename Type>
validatedConvert(const Type & canonical)375 nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
376     NN_TRY(compliantVersion(canonical));
377     return unvalidatedConvert(canonical);
378 }
379 
380 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)381 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
382         const std::vector<Type>& arguments) {
383     hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
384     for (size_t i = 0; i < arguments.size(); ++i) {
385         halObject[i] = NN_TRY(validatedConvert(arguments[i]));
386     }
387     return halObject;
388 }
389 
390 }  // anonymous namespace
391 
unvalidatedConvert(const nn::OperandType & operandType)392 nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
393     return static_cast<OperandType>(operandType);
394 }
395 
unvalidatedConvert(const nn::OperationType & operationType)396 nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
397     return static_cast<OperationType>(operationType);
398 }
399 
nn::GeneralResult<DeviceType> unvalidatedConvert(const nn::DeviceType& deviceType) {
    // Known device kinds map by static_cast. UNKNOWN is rejected, and any value outside
    // the enumeration falls through to the generic error below.
    switch (deviceType) {
        case nn::DeviceType::OTHER:
        case nn::DeviceType::CPU:
        case nn::DeviceType::GPU:
        case nn::DeviceType::ACCELERATOR:
            return static_cast<DeviceType>(deviceType);
        case nn::DeviceType::UNKNOWN:
            return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "Invalid DeviceType UNKNOWN";
    }
    return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
           << "Invalid DeviceType " << underlyingType(deviceType);
}
413 
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities) {
    // Keep only entries whose operand type passes `compliantVersion`. Other entries are
    // dropped silently rather than reported as an error.
    const auto& allEntries = capabilities.operandPerformance.asVector();
    std::vector<nn::Capabilities::OperandPerformance> compliantEntries;
    compliantEntries.reserve(allEntries.size());
    for (const auto& entry : allEntries) {
        if (compliantVersion(entry.type).has_value()) {
            compliantEntries.push_back(entry);
        }
    }

    const auto scalarPerformance =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
    const auto tensorPerformance =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
    auto halOperandPerformance = NN_TRY(unvalidatedConvert(compliantEntries));

    return Capabilities{
            .relaxedFloat32toFloat16PerformanceScalar = scalarPerformance,
            .relaxedFloat32toFloat16PerformanceTensor = tensorPerformance,
            .operandPerformance = std::move(halOperandPerformance),
    };
}
432 
nn::GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
        const nn::Capabilities::OperandPerformance& operandPerformance) {
    // Convert the operand type first, then its performance figures.
    const auto type = NN_TRY(unvalidatedConvert(operandPerformance.type));
    const auto info = NN_TRY(unvalidatedConvert(operandPerformance.info));
    return Capabilities::OperandPerformance{.type = type, .info = info};
}
440 
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
    // Only the operation type needs conversion; operand index lists are copied as-is.
    const auto type = NN_TRY(unvalidatedConvert(operation.type));
    return Operation{.type = type, .inputs = operation.inputs, .outputs = operation.outputs};
}
448 
nn::GeneralResult<SymmPerChannelQuantParams> unvalidatedConvert(
        const nn::Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
    // Field-for-field copy of the per-channel scales and channel dimension.
    SymmPerChannelQuantParams halParams{
            .scales = symmPerChannelQuantParams.scales,
            .channelDim = symmPerChannelQuantParams.channelDim,
    };
    return halParams;
}
456 
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    // Fallible sub-conversions, in the same order the fields are declared.
    const auto type = NN_TRY(unvalidatedConvert(operand.type));
    const auto lifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
    const auto location = NN_TRY(unvalidatedConvert(operand.location));
    auto extraParams = NN_TRY(unvalidatedConvert(operand.extraParams));

    // numberOfConsumers is left at 0 here; the Model conversion recomputes it from the
    // operations.
    return Operand{
            .type = type,
            .dimensions = operand.dimensions,
            .numberOfConsumers = 0,
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = lifetime,
            .location = location,
            .extraParams = std::move(extraParams),
    };
}
469 
nn::GeneralResult<Operand::ExtraParams> unvalidatedConvert(
        const nn::Operand::ExtraParams& extraParams) {
    // Dispatch on the active alternative; each alternative has its own makeExtraParams overload.
    const auto toHal = [](const auto& alternative) { return makeExtraParams(alternative); };
    return std::visit(toHal, extraParams);
}
474 
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    // Models that reference pointer-based memory are rejected up front.
    if (!hal::utils::hasNoPointerData(model)) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    const auto& subgraph = model.main;
    auto halOperands = NN_TRY(unvalidatedConvert(subgraph.operands));

    // The per-operand conversion leaves numberOfConsumers at 0; fill it in from the operations.
    const auto consumerCounts = NN_TRY(
            hal::utils::countNumberOfConsumers(halOperands.size(), subgraph.operations));
    CHECK(halOperands.size() == consumerCounts.size());
    for (size_t index = 0; index < halOperands.size(); ++index) {
        halOperands[index].numberOfConsumers = consumerCounts[index];
    }

    auto halOperations = NN_TRY(unvalidatedConvert(subgraph.operations));
    auto operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
    auto pools = NN_TRY(unvalidatedConvert(model.pools));
    auto extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));

    return Model{
            .operands = std::move(halOperands),
            .operations = std::move(halOperations),
            .inputIndexes = subgraph.inputIndexes,
            .outputIndexes = subgraph.outputIndexes,
            .operandValues = std::move(operandValues),
            .pools = std::move(pools),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = std::move(extensionNameToPrefix),
    };
}
502 
nn::GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
        const nn::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
    // Plain field copy of the extension name and its prefix.
    Model::ExtensionNameAndPrefix halEntry{
            .name = extensionNameAndPrefix.name,
            .prefix = extensionNameAndPrefix.prefix,
    };
    return halEntry;
}
510 
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape) {
    // Dimensions and the sufficiency flag carry over unchanged.
    OutputShape halShape{
            .dimensions = outputShape.dimensions,
            .isSufficient = outputShape.isSufficient,
    };
    return halShape;
}
515 
unvalidatedConvert(const nn::MeasureTiming & measureTiming)516 nn::GeneralResult<MeasureTiming> unvalidatedConvert(const nn::MeasureTiming& measureTiming) {
517     return static_cast<MeasureTiming>(measureTiming);
518 }
519 
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
    const auto toHal = [](nn::OptionalDuration canonicalTiming) -> uint64_t {
        // An absent duration is encoded as UINT64_MAX.
        constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
        if (canonicalTiming.has_value()) {
            // Round up to whole microseconds so a duration is never reported shorter than it was.
            return std::chrono::ceil<HalDuration>(*canonicalTiming).count();
        }
        return kNoTiming;
    };

    const auto onDevice = toHal(timing.timeOnDevice);
    const auto inDriver = toHal(timing.timeInDriver);
    return Timing{.timeOnDevice = onDevice, .timeInDriver = inDriver};
}
531 
nn::GeneralResult<Extension> unvalidatedConvert(const nn::Extension& extension) {
    // The operand type list is the only fallible part; the name is copied directly.
    auto operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes));
    return Extension{.name = extension.name, .operandTypes = std::move(operandTypes)};
}
538 
nn::GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
        const nn::Extension::OperandTypeInformation& operandTypeInformation) {
    // Field-for-field copy of the extension operand type description.
    Extension::OperandTypeInformation halInfo{
            .type = operandTypeInformation.type,
            .isTensor = operandTypeInformation.isTensor,
            .byteSize = operandTypeInformation.byteSize,
    };
    return halInfo;
}
547 
nn::GeneralResult<hidl_handle> unvalidatedConvert(const nn::SharedHandle& handle) {
    // A null shared handle maps to a default-constructed (empty) hidl_handle.
    if (handle != nullptr) {
        return hal::utils::hidlHandleFromSharedHandle(*handle);
    }
    return {};
}
554 
convert(const nn::DeviceType & deviceType)555 nn::GeneralResult<DeviceType> convert(const nn::DeviceType& deviceType) {
556     return validatedConvert(deviceType);
557 }
558 
convert(const nn::Capabilities & capabilities)559 nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
560     return validatedConvert(capabilities);
561 }
562 
convert(const nn::Model & model)563 nn::GeneralResult<Model> convert(const nn::Model& model) {
564     return validatedConvert(model);
565 }
566 
convert(const nn::MeasureTiming & measureTiming)567 nn::GeneralResult<MeasureTiming> convert(const nn::MeasureTiming& measureTiming) {
568     return validatedConvert(measureTiming);
569 }
570 
convert(const nn::Timing & timing)571 nn::GeneralResult<Timing> convert(const nn::Timing& timing) {
572     return validatedConvert(timing);
573 }
574 
convert(const std::vector<nn::Extension> & extensions)575 nn::GeneralResult<hidl_vec<Extension>> convert(const std::vector<nn::Extension>& extensions) {
576     return validatedConvert(extensions);
577 }
578 
convert(const std::vector<nn::SharedHandle> & handles)579 nn::GeneralResult<hidl_vec<hidl_handle>> convert(const std::vector<nn::SharedHandle>& handles) {
580     return validatedConvert(handles);
581 }
582 
convert(const std::vector<nn::OutputShape> & outputShapes)583 nn::GeneralResult<hidl_vec<OutputShape>> convert(const std::vector<nn::OutputShape>& outputShapes) {
584     return validatedConvert(outputShapes);
585 }
586 
convert(const nn::DeviceStatus & deviceStatus)587 nn::GeneralResult<V1_0::DeviceStatus> convert(const nn::DeviceStatus& deviceStatus) {
588     return V1_1::utils::convert(deviceStatus);
589 }
590 
convert(const nn::Request & request)591 nn::GeneralResult<V1_0::Request> convert(const nn::Request& request) {
592     return V1_1::utils::convert(request);
593 }
594 
convert(const nn::ErrorStatus & status)595 nn::GeneralResult<V1_0::ErrorStatus> convert(const nn::ErrorStatus& status) {
596     return V1_1::utils::convert(status);
597 }
598 
convert(const nn::ExecutionPreference & executionPreference)599 nn::GeneralResult<V1_1::ExecutionPreference> convert(
600         const nn::ExecutionPreference& executionPreference) {
601     return V1_1::utils::convert(executionPreference);
602 }
603 
604 }  // namespace android::hardware::neuralnetworks::V1_2::utils
605