1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "Conversions.h"
18
19 #include <android-base/logging.h>
20 #include <android/hardware/neuralnetworks/1.2/types.h>
21 #include <nnapi/OperandTypes.h>
22 #include <nnapi/OperationTypes.h>
23 #include <nnapi/Result.h>
24 #include <nnapi/SharedMemory.h>
25 #include <nnapi/TypeUtils.h>
26 #include <nnapi/Types.h>
27 #include <nnapi/Validation.h>
28 #include <nnapi/hal/1.0/Conversions.h>
29 #include <nnapi/hal/1.1/Conversions.h>
30 #include <nnapi/hal/CommonUtils.h>
31 #include <nnapi/hal/HandleError.h>
32
33 #include <algorithm>
34 #include <functional>
35 #include <iterator>
36 #include <memory>
37 #include <type_traits>
38 #include <utility>
39
40 #include "Utils.h"
41
namespace {

// Returns `value` converted to the underlying integral type of its enumeration type.
template <typename Enum>
constexpr auto underlyingType(Enum value) -> std::underlying_type_t<Enum> {
    using Underlying = std::underlying_type_t<Enum>;
    return static_cast<Underlying>(value);
}

// Duration with an unsigned 64-bit tick count and a period of one microsecond.
using HalDuration = std::chrono::duration<uint64_t, std::micro>;

}  // namespace
52
53 namespace android::nn {
54 namespace {
55
56 using hardware::hidl_handle;
57 using hardware::hidl_vec;
58
59 template <typename Input>
60 using UnvalidatedConvertOutput =
61 std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
62
63 template <typename Type>
unvalidatedConvert(const hidl_vec<Type> & arguments)64 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
65 const hidl_vec<Type>& arguments) {
66 std::vector<UnvalidatedConvertOutput<Type>> canonical;
67 canonical.reserve(arguments.size());
68 for (const auto& argument : arguments) {
69 canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
70 }
71 return canonical;
72 }
73
74 template <typename Type>
validatedConvert(const Type & halObject)75 GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
76 auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
77 NN_TRY(hal::V1_2::utils::compliantVersion(canonical));
78 return canonical;
79 }
80
81 template <typename Type>
validatedConvert(const hidl_vec<Type> & arguments)82 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
83 const hidl_vec<Type>& arguments) {
84 std::vector<UnvalidatedConvertOutput<Type>> canonical;
85 canonical.reserve(arguments.size());
86 for (const auto& argument : arguments) {
87 canonical.push_back(NN_TRY(validatedConvert(argument)));
88 }
89 return canonical;
90 }
91
92 } // anonymous namespace
93
unvalidatedConvert(const hal::V1_2::OperandType & operandType)94 GeneralResult<OperandType> unvalidatedConvert(const hal::V1_2::OperandType& operandType) {
95 return static_cast<OperandType>(operandType);
96 }
97
unvalidatedConvert(const hal::V1_2::OperationType & operationType)98 GeneralResult<OperationType> unvalidatedConvert(const hal::V1_2::OperationType& operationType) {
99 return static_cast<OperationType>(operationType);
100 }
101
unvalidatedConvert(const hal::V1_2::DeviceType & deviceType)102 GeneralResult<DeviceType> unvalidatedConvert(const hal::V1_2::DeviceType& deviceType) {
103 return static_cast<DeviceType>(deviceType);
104 }
105
unvalidatedConvert(const hal::V1_2::Capabilities & capabilities)106 GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_2::Capabilities& capabilities) {
107 const bool validOperandTypes = std::all_of(
108 capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
109 [](const hal::V1_2::Capabilities::OperandPerformance& operandPerformance) {
110 return validatedConvert(operandPerformance.type).has_value();
111 });
112 if (!validOperandTypes) {
113 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
114 << "Invalid OperandType when converting OperandPerformance in Capabilities";
115 }
116
117 const auto relaxedFloat32toFloat16PerformanceScalar =
118 NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
119 const auto relaxedFloat32toFloat16PerformanceTensor =
120 NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
121 auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
122
123 auto table = NN_TRY(hal::utils::makeGeneralFailure(
124 Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)),
125 nn::ErrorStatus::GENERAL_FAILURE));
126
127 return Capabilities{
128 .relaxedFloat32toFloat16PerformanceScalar = relaxedFloat32toFloat16PerformanceScalar,
129 .relaxedFloat32toFloat16PerformanceTensor = relaxedFloat32toFloat16PerformanceTensor,
130 .operandPerformance = std::move(table),
131 };
132 }
133
unvalidatedConvert(const hal::V1_2::Capabilities::OperandPerformance & operandPerformance)134 GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
135 const hal::V1_2::Capabilities::OperandPerformance& operandPerformance) {
136 return Capabilities::OperandPerformance{
137 .type = NN_TRY(unvalidatedConvert(operandPerformance.type)),
138 .info = NN_TRY(unvalidatedConvert(operandPerformance.info)),
139 };
140 }
141
unvalidatedConvert(const hal::V1_2::Operation & operation)142 GeneralResult<Operation> unvalidatedConvert(const hal::V1_2::Operation& operation) {
143 return Operation{
144 .type = NN_TRY(unvalidatedConvert(operation.type)),
145 .inputs = operation.inputs,
146 .outputs = operation.outputs,
147 };
148 }
149
unvalidatedConvert(const hal::V1_2::SymmPerChannelQuantParams & symmPerChannelQuantParams)150 GeneralResult<Operand::SymmPerChannelQuantParams> unvalidatedConvert(
151 const hal::V1_2::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
152 return Operand::SymmPerChannelQuantParams{
153 .scales = symmPerChannelQuantParams.scales,
154 .channelDim = symmPerChannelQuantParams.channelDim,
155 };
156 }
157
unvalidatedConvert(const hal::V1_2::Operand & operand)158 GeneralResult<Operand> unvalidatedConvert(const hal::V1_2::Operand& operand) {
159 return Operand{
160 .type = NN_TRY(unvalidatedConvert(operand.type)),
161 .dimensions = operand.dimensions,
162 .scale = operand.scale,
163 .zeroPoint = operand.zeroPoint,
164 .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
165 .location = NN_TRY(unvalidatedConvert(operand.location)),
166 .extraParams = NN_TRY(unvalidatedConvert(operand.extraParams)),
167 };
168 }
169
unvalidatedConvert(const hal::V1_2::Operand::ExtraParams & extraParams)170 GeneralResult<Operand::ExtraParams> unvalidatedConvert(
171 const hal::V1_2::Operand::ExtraParams& extraParams) {
172 using Discriminator = hal::V1_2::Operand::ExtraParams::hidl_discriminator;
173 switch (extraParams.getDiscriminator()) {
174 case Discriminator::none:
175 return Operand::NoParams{};
176 case Discriminator::channelQuant:
177 return unvalidatedConvert(extraParams.channelQuant());
178 case Discriminator::extension:
179 return extraParams.extension();
180 }
181 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
182 << "Unrecognized Operand::ExtraParams discriminator: "
183 << underlyingType(extraParams.getDiscriminator());
184 }
185
unvalidatedConvert(const hal::V1_2::Model & model)186 GeneralResult<Model> unvalidatedConvert(const hal::V1_2::Model& model) {
187 auto operations = NN_TRY(unvalidatedConvert(model.operations));
188
189 // Verify number of consumers.
190 const auto numberOfConsumers =
191 NN_TRY(hal::utils::countNumberOfConsumers(model.operands.size(), operations));
192 CHECK(model.operands.size() == numberOfConsumers.size());
193 for (size_t i = 0; i < model.operands.size(); ++i) {
194 if (model.operands[i].numberOfConsumers != numberOfConsumers[i]) {
195 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
196 << "Invalid numberOfConsumers for operand " << i << ", expected "
197 << numberOfConsumers[i] << " but found " << model.operands[i].numberOfConsumers;
198 }
199 }
200
201 auto main = Model::Subgraph{
202 .operands = NN_TRY(unvalidatedConvert(model.operands)),
203 .operations = std::move(operations),
204 .inputIndexes = model.inputIndexes,
205 .outputIndexes = model.outputIndexes,
206 };
207
208 return Model{
209 .main = std::move(main),
210 .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
211 .pools = NN_TRY(unvalidatedConvert(model.pools)),
212 .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
213 .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
214 };
215 }
216
unvalidatedConvert(const hal::V1_2::Model::ExtensionNameAndPrefix & extensionNameAndPrefix)217 GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
218 const hal::V1_2::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
219 return Model::ExtensionNameAndPrefix{
220 .name = extensionNameAndPrefix.name,
221 .prefix = extensionNameAndPrefix.prefix,
222 };
223 }
224
unvalidatedConvert(const hal::V1_2::OutputShape & outputShape)225 GeneralResult<OutputShape> unvalidatedConvert(const hal::V1_2::OutputShape& outputShape) {
226 return OutputShape{
227 .dimensions = outputShape.dimensions,
228 .isSufficient = outputShape.isSufficient,
229 };
230 }
231
unvalidatedConvert(const hal::V1_2::MeasureTiming & measureTiming)232 GeneralResult<MeasureTiming> unvalidatedConvert(const hal::V1_2::MeasureTiming& measureTiming) {
233 return static_cast<MeasureTiming>(measureTiming);
234 }
235
unvalidatedConvert(const hal::V1_2::Timing & timing)236 GeneralResult<Timing> unvalidatedConvert(const hal::V1_2::Timing& timing) {
237 constexpr uint64_t kMaxTiming = std::chrono::floor<HalDuration>(Duration::max()).count();
238 constexpr auto convertTiming = [](uint64_t halTiming) -> OptionalDuration {
239 constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
240 if (halTiming == kNoTiming) {
241 return {};
242 }
243 if (halTiming > kMaxTiming) {
244 return Duration::max();
245 }
246 return HalDuration{halTiming};
247 };
248 return Timing{.timeOnDevice = convertTiming(timing.timeOnDevice),
249 .timeInDriver = convertTiming(timing.timeInDriver)};
250 }
251
unvalidatedConvert(const hal::V1_2::Extension & extension)252 GeneralResult<Extension> unvalidatedConvert(const hal::V1_2::Extension& extension) {
253 return Extension{
254 .name = extension.name,
255 .operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes)),
256 };
257 }
258
unvalidatedConvert(const hal::V1_2::Extension::OperandTypeInformation & operandTypeInformation)259 GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
260 const hal::V1_2::Extension::OperandTypeInformation& operandTypeInformation) {
261 return Extension::OperandTypeInformation{
262 .type = operandTypeInformation.type,
263 .isTensor = operandTypeInformation.isTensor,
264 .byteSize = operandTypeInformation.byteSize,
265 };
266 }
267
unvalidatedConvert(const hidl_handle & hidlHandle)268 GeneralResult<SharedHandle> unvalidatedConvert(const hidl_handle& hidlHandle) {
269 if (hidlHandle.getNativeHandle() == nullptr) {
270 return nullptr;
271 }
272 auto handle = NN_TRY(hal::utils::sharedHandleFromNativeHandle(hidlHandle.getNativeHandle()));
273 return std::make_shared<const Handle>(std::move(handle));
274 }
275
convert(const hal::V1_2::DeviceType & deviceType)276 GeneralResult<DeviceType> convert(const hal::V1_2::DeviceType& deviceType) {
277 return validatedConvert(deviceType);
278 }
279
convert(const hal::V1_2::Capabilities & capabilities)280 GeneralResult<Capabilities> convert(const hal::V1_2::Capabilities& capabilities) {
281 return validatedConvert(capabilities);
282 }
283
convert(const hal::V1_2::Model & model)284 GeneralResult<Model> convert(const hal::V1_2::Model& model) {
285 return validatedConvert(model);
286 }
287
convert(const hal::V1_2::MeasureTiming & measureTiming)288 GeneralResult<MeasureTiming> convert(const hal::V1_2::MeasureTiming& measureTiming) {
289 return validatedConvert(measureTiming);
290 }
291
convert(const hal::V1_2::Timing & timing)292 GeneralResult<Timing> convert(const hal::V1_2::Timing& timing) {
293 return validatedConvert(timing);
294 }
295
convert(const hardware::hidl_memory & memory)296 GeneralResult<SharedMemory> convert(const hardware::hidl_memory& memory) {
297 return validatedConvert(memory);
298 }
299
convert(const hidl_vec<hal::V1_2::Extension> & extensions)300 GeneralResult<std::vector<Extension>> convert(const hidl_vec<hal::V1_2::Extension>& extensions) {
301 return validatedConvert(extensions);
302 }
303
convert(const hidl_vec<hidl_handle> & handles)304 GeneralResult<std::vector<SharedHandle>> convert(const hidl_vec<hidl_handle>& handles) {
305 return validatedConvert(handles);
306 }
307
convert(const hidl_vec<hal::V1_2::OutputShape> & outputShapes)308 GeneralResult<std::vector<OutputShape>> convert(
309 const hidl_vec<hal::V1_2::OutputShape>& outputShapes) {
310 return validatedConvert(outputShapes);
311 }
312
313 } // namespace android::nn
314
315 namespace android::hardware::neuralnetworks::V1_2::utils {
316 namespace {
317
318 using utils::unvalidatedConvert;
319
// Forwards the canonical operand lifetime to the V1_0 conversion.
nn::GeneralResult<V1_0::OperandLifeTime> unvalidatedConvert(const nn::Operand::LifeTime& lifetime) {
    auto halLifetime = V1_0::utils::unvalidatedConvert(lifetime);
    return halLifetime;
}
323
// Forwards the canonical PerformanceInfo to the V1_0 conversion.
nn::GeneralResult<V1_0::PerformanceInfo> unvalidatedConvert(
        const nn::Capabilities::PerformanceInfo& performanceInfo) {
    auto halPerformanceInfo = V1_0::utils::unvalidatedConvert(performanceInfo);
    return halPerformanceInfo;
}
328
// Forwards the canonical DataLocation to the V1_0 conversion.
nn::GeneralResult<V1_0::DataLocation> unvalidatedConvert(const nn::DataLocation& location) {
    auto halLocation = V1_0::utils::unvalidatedConvert(location);
    return halLocation;
}
332
// Forwards the canonical operand values to the V1_0 conversion, which yields a byte vector.
nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
        const nn::Model::OperandValues& operandValues) {
    auto halOperandValues = V1_0::utils::unvalidatedConvert(operandValues);
    return halOperandValues;
}
337
// Forwards the canonical SharedMemory to the V1_0 conversion, which yields a hidl_memory.
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
    auto halMemory = V1_0::utils::unvalidatedConvert(memory);
    return halMemory;
}
341
342 template <typename Input>
343 using UnvalidatedConvertOutput =
344 std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
345
346 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)347 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
348 const std::vector<Type>& arguments) {
349 hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
350 for (size_t i = 0; i < arguments.size(); ++i) {
351 halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
352 }
353 return halObject;
354 }
355
makeExtraParams(nn::Operand::NoParams)356 nn::GeneralResult<Operand::ExtraParams> makeExtraParams(nn::Operand::NoParams /*noParams*/) {
357 return Operand::ExtraParams{};
358 }
359
// Produces a HAL ExtraParams whose channelQuant alternative holds the converted `channelQuant`.
nn::GeneralResult<Operand::ExtraParams> makeExtraParams(
        const nn::Operand::SymmPerChannelQuantParams& channelQuant) {
    auto halChannelQuant = NN_TRY(unvalidatedConvert(channelQuant));
    Operand::ExtraParams halExtraParams;
    halExtraParams.channelQuant(std::move(halChannelQuant));
    return halExtraParams;
}
366
// Produces a HAL ExtraParams whose extension alternative is set from `extension`.
nn::GeneralResult<Operand::ExtraParams> makeExtraParams(
        const nn::Operand::ExtensionParams& extension) {
    Operand::ExtraParams halExtraParams;
    halExtraParams.extension(extension);
    return halExtraParams;
}
373
374 template <typename Type>
validatedConvert(const Type & canonical)375 nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
376 NN_TRY(compliantVersion(canonical));
377 return unvalidatedConvert(canonical);
378 }
379
380 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)381 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
382 const std::vector<Type>& arguments) {
383 hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
384 for (size_t i = 0; i < arguments.size(); ++i) {
385 halObject[i] = NN_TRY(validatedConvert(arguments[i]));
386 }
387 return halObject;
388 }
389
390 } // anonymous namespace
391
// Reinterprets a canonical OperandType as the HAL V1_2 OperandType with a plain static_cast; no
// range check is made here.
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    const auto halType = static_cast<OperandType>(operandType);
    return halType;
}
395
// Reinterprets a canonical OperationType as the HAL V1_2 OperationType with a plain static_cast;
// no range check is made here.
nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    const auto halType = static_cast<OperationType>(operationType);
    return halType;
}
399
// Converts a canonical DeviceType to the HAL V1_2 DeviceType.
//
// OTHER, CPU, GPU and ACCELERATOR are cast directly. UNKNOWN, and any value matching none of the
// listed enumerators, yields GENERAL_FAILURE.
nn::GeneralResult<DeviceType> unvalidatedConvert(const nn::DeviceType& deviceType) {
    switch (deviceType) {
        case nn::DeviceType::OTHER:
        case nn::DeviceType::CPU:
        case nn::DeviceType::GPU:
        case nn::DeviceType::ACCELERATOR: {
            return static_cast<DeviceType>(deviceType);
        }
        case nn::DeviceType::UNKNOWN: {
            return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "Invalid DeviceType UNKNOWN";
        }
    }
    return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
           << "Invalid DeviceType " << underlyingType(deviceType);
}
413
// Converts canonical Capabilities to HAL V1_2 Capabilities.
//
// Entries of capabilities.operandPerformance whose type does not pass compliantVersion are left
// out of the HAL result instead of producing an error.
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities) {
    const auto& allEntries = capabilities.operandPerformance.asVector();
    const auto isCompliant = [](const nn::Capabilities::OperandPerformance& entry) {
        return compliantVersion(entry.type).has_value();
    };

    std::vector<nn::Capabilities::OperandPerformance> operandPerformance;
    operandPerformance.reserve(allEntries.size());
    std::copy_if(allEntries.begin(), allEntries.end(), std::back_inserter(operandPerformance),
                 isCompliant);

    const auto scalarPerformance =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
    const auto tensorPerformance =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
    auto halOperandPerformance = NN_TRY(unvalidatedConvert(operandPerformance));
    return Capabilities{
            .relaxedFloat32toFloat16PerformanceScalar = scalarPerformance,
            .relaxedFloat32toFloat16PerformanceTensor = tensorPerformance,
            .operandPerformance = std::move(halOperandPerformance),
    };
}
432
// Converts one canonical OperandPerformance entry by converting its `type` and then its `info`.
nn::GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
        const nn::Capabilities::OperandPerformance& operandPerformance) {
    const auto type = NN_TRY(unvalidatedConvert(operandPerformance.type));
    const auto info = NN_TRY(unvalidatedConvert(operandPerformance.info));
    return Capabilities::OperandPerformance{
            .type = type,
            .info = info,
    };
}
440
// Converts a canonical Operation: the type goes through unvalidatedConvert, while the input and
// output index lists are copied as-is.
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
    const auto type = NN_TRY(unvalidatedConvert(operation.type));
    return Operation{
            .type = type,
            .inputs = operation.inputs,
            .outputs = operation.outputs,
    };
}
448
// Copies the scales and channel dimension of canonical SymmPerChannelQuantParams into the HAL
// V1_2 structure.
nn::GeneralResult<SymmPerChannelQuantParams> unvalidatedConvert(
        const nn::Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
    SymmPerChannelQuantParams halParams{
            .scales = symmPerChannelQuantParams.scales,
            .channelDim = symmPerChannelQuantParams.channelDim,
    };
    return halParams;
}
456
// Converts a canonical Operand to the HAL V1_2 Operand. The type, lifetime, location and
// extraParams go through their own unvalidatedConvert overloads (in that order); dimensions,
// scale and zeroPoint are copied directly. numberOfConsumers is always written as 0 here.
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    const auto type = NN_TRY(unvalidatedConvert(operand.type));
    const auto lifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
    auto location = NN_TRY(unvalidatedConvert(operand.location));
    auto extraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
    return Operand{
            .type = type,
            .dimensions = operand.dimensions,
            .numberOfConsumers = 0,
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = lifetime,
            .location = std::move(location),
            .extraParams = std::move(extraParams),
    };
}
469
// Converts canonical ExtraParams by visiting the variant and dispatching the held alternative to
// the matching makeExtraParams overload.
nn::GeneralResult<Operand::ExtraParams> unvalidatedConvert(
        const nn::Operand::ExtraParams& extraParams) {
    const auto toHal = [](const auto& params) { return makeExtraParams(params); };
    return std::visit(toHal, extraParams);
}
474
// Converts a canonical nn::Model to a HAL V1_2 Model.
//
// Returns INVALID_ARGUMENT when hal::utils::hasNoPointerData(model) is false. Only the `main`
// subgraph's operands, operations and input/output indexes are written to the HAL model. Each HAL
// operand's numberOfConsumers is filled in from hal::utils::countNumberOfConsumers, computed over
// the canonical operations of the main subgraph.
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    if (!hal::utils::hasNoPointerData(model)) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Model cannot be converted because it contains pointer-based memory";
    }

    auto operands = NN_TRY(unvalidatedConvert(model.main.operands));

    // Update number of consumers. The per-operand conversion writes 0 for numberOfConsumers, so
    // the real counts are patched in here.
    const auto numberOfConsumers =
            NN_TRY(hal::utils::countNumberOfConsumers(operands.size(), model.main.operations));
    CHECK(operands.size() == numberOfConsumers.size());
    for (size_t i = 0; i < operands.size(); ++i) {
        operands[i].numberOfConsumers = numberOfConsumers[i];
    }

    return Model{
            .operands = std::move(operands),
            .operations = NN_TRY(unvalidatedConvert(model.main.operations)),
            .inputIndexes = model.main.inputIndexes,
            .outputIndexes = model.main.outputIndexes,
            .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
            .pools = NN_TRY(unvalidatedConvert(model.pools)),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
    };
}
502
// Copies the name and prefix of a canonical ExtensionNameAndPrefix into the HAL V1_2 structure.
nn::GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
        const nn::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
    Model::ExtensionNameAndPrefix halNameAndPrefix{
            .name = extensionNameAndPrefix.name,
            .prefix = extensionNameAndPrefix.prefix,
    };
    return halNameAndPrefix;
}
510
// Copies the dimensions and isSufficient flag of a canonical OutputShape into the HAL V1_2
// structure.
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape) {
    OutputShape halOutputShape{
            .dimensions = outputShape.dimensions,
            .isSufficient = outputShape.isSufficient,
    };
    return halOutputShape;
}
515
// Reinterprets a canonical MeasureTiming as the HAL V1_2 MeasureTiming with a plain static_cast.
nn::GeneralResult<MeasureTiming> unvalidatedConvert(const nn::MeasureTiming& measureTiming) {
    const auto halMeasureTiming = static_cast<MeasureTiming>(measureTiming);
    return halMeasureTiming;
}
519
// Converts canonical Timing to HAL V1_2 Timing.
//
// For each field: an empty OptionalDuration maps to the maximum uint64_t value; otherwise the
// duration is rounded up to a whole number of HalDuration ticks and that tick count is used.
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
    constexpr auto toHal = [](nn::OptionalDuration canonicalTiming) -> uint64_t {
        constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
        if (canonicalTiming.has_value()) {
            return std::chrono::ceil<HalDuration>(*canonicalTiming).count();
        }
        return kNoTiming;
    };
    const uint64_t timeOnDevice = toHal(timing.timeOnDevice);
    const uint64_t timeInDriver = toHal(timing.timeInDriver);
    return Timing{.timeOnDevice = timeOnDevice, .timeInDriver = timeInDriver};
}
531
// Converts a canonical Extension: the name is copied and each entry of operandTypes is converted
// with unvalidatedConvert.
nn::GeneralResult<Extension> unvalidatedConvert(const nn::Extension& extension) {
    auto operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes));
    return Extension{
            .name = extension.name,
            .operandTypes = std::move(operandTypes),
    };
}
538
// Copies the type, isTensor and byteSize fields of a canonical OperandTypeInformation into the
// HAL V1_2 structure.
nn::GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
        const nn::Extension::OperandTypeInformation& operandTypeInformation) {
    Extension::OperandTypeInformation halInformation{
            .type = operandTypeInformation.type,
            .isTensor = operandTypeInformation.isTensor,
            .byteSize = operandTypeInformation.byteSize,
    };
    return halInformation;
}
547
// Converts a canonical SharedHandle to a hidl_handle. A null SharedHandle yields a
// default-constructed result; otherwise the pointee is passed to
// hal::utils::hidlHandleFromSharedHandle.
nn::GeneralResult<hidl_handle> unvalidatedConvert(const nn::SharedHandle& handle) {
    if (handle != nullptr) {
        return hal::utils::hidlHandleFromSharedHandle(*handle);
    }
    return {};
}
554
// Converts a canonical DeviceType to HAL V1_2 form by delegating to validatedConvert.
nn::GeneralResult<DeviceType> convert(const nn::DeviceType& deviceType) {
    auto result = validatedConvert(deviceType);
    return result;
}
558
// Converts canonical Capabilities to HAL V1_2 form by delegating to validatedConvert.
nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
    auto result = validatedConvert(capabilities);
    return result;
}
562
// Converts a canonical Model to HAL V1_2 form by delegating to validatedConvert.
nn::GeneralResult<Model> convert(const nn::Model& model) {
    auto result = validatedConvert(model);
    return result;
}
566
// Converts a canonical MeasureTiming to HAL V1_2 form by delegating to validatedConvert.
nn::GeneralResult<MeasureTiming> convert(const nn::MeasureTiming& measureTiming) {
    auto result = validatedConvert(measureTiming);
    return result;
}
570
// Converts canonical Timing to HAL V1_2 form by delegating to validatedConvert.
nn::GeneralResult<Timing> convert(const nn::Timing& timing) {
    auto result = validatedConvert(timing);
    return result;
}
574
convert(const std::vector<nn::Extension> & extensions)575 nn::GeneralResult<hidl_vec<Extension>> convert(const std::vector<nn::Extension>& extensions) {
576 return validatedConvert(extensions);
577 }
578
convert(const std::vector<nn::SharedHandle> & handles)579 nn::GeneralResult<hidl_vec<hidl_handle>> convert(const std::vector<nn::SharedHandle>& handles) {
580 return validatedConvert(handles);
581 }
582
convert(const std::vector<nn::OutputShape> & outputShapes)583 nn::GeneralResult<hidl_vec<OutputShape>> convert(const std::vector<nn::OutputShape>& outputShapes) {
584 return validatedConvert(outputShapes);
585 }
586
// Forwards the canonical DeviceStatus to the V1_1 conversion.
nn::GeneralResult<V1_0::DeviceStatus> convert(const nn::DeviceStatus& deviceStatus) {
    auto result = V1_1::utils::convert(deviceStatus);
    return result;
}
590
// Forwards the canonical Request to the V1_1 conversion.
nn::GeneralResult<V1_0::Request> convert(const nn::Request& request) {
    auto result = V1_1::utils::convert(request);
    return result;
}
594
// Forwards the canonical ErrorStatus to the V1_1 conversion.
nn::GeneralResult<V1_0::ErrorStatus> convert(const nn::ErrorStatus& status) {
    auto result = V1_1::utils::convert(status);
    return result;
}
598
// Forwards the canonical ExecutionPreference to the V1_1 conversion.
nn::GeneralResult<V1_1::ExecutionPreference> convert(
        const nn::ExecutionPreference& executionPreference) {
    auto result = V1_1::utils::convert(executionPreference);
    return result;
}
603
604 } // namespace android::hardware::neuralnetworks::V1_2::utils
605