/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Conversions.h"

// NOTE(review): the header names in this group were lost from the source text and have
// been reconstructed from the names this file uses; confirm against the build.
#include <android-base/logging.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
#include <nnapi/OperandTypes.h>
#include <nnapi/OperationTypes.h>
#include <nnapi/Result.h>
#include <nnapi/SharedMemory.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>
#include <nnapi/hal/CommonUtils.h>

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include "Utils.h"

namespace {

// Returns the underlying integral value of an enumeration. Used below to print an
// unrecognized enum value in an error message.
template <typename Type>
constexpr std::underlying_type_t<Type> underlyingType(Type value) {
    return static_cast<std::underlying_type_t<Type>>(value);
}

}  // namespace

namespace android::nn {
namespace {

using hardware::hidl_memory;
using hardware::hidl_vec;

// The canonical type that nn::unvalidatedConvert produces for a given HAL type.
template <typename Type>
using UnvalidatedConvertOutput =
        std::decay_t<decltype(unvalidatedConvert(std::declval<Type>()).value())>;

// Converts each element of a HAL vector to its canonical type, in order. Returns the
// error of the first element that fails to convert.
template <typename Type>
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
        const hidl_vec<Type>& arguments) {
    std::vector<UnvalidatedConvertOutput<Type>> canonical;
    canonical.reserve(arguments.size());
    for (const auto& argument : arguments) {
        canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
    }
    return canonical;
}

// Converts a HAL object to its canonical type. The result is returned only if
// hal::V1_0::utils::compliantVersion accepts it; otherwise that error is returned.
template <typename Type>
GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
    auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
    NN_TRY(hal::V1_0::utils::compliantVersion(canonical));
    return canonical;
}

}  // anonymous namespace

// The enum conversions below are direct value casts; no range check is done here
// (validatedConvert is the checked entry point).
GeneralResult<OperandType> unvalidatedConvert(const hal::V1_0::OperandType& operandType) {
    return static_cast<OperandType>(operandType);
}

GeneralResult<OperationType> unvalidatedConvert(const hal::V1_0::OperationType& operationType) {
    return static_cast<OperationType>(operationType);
}

GeneralResult<Operand::LifeTime> unvalidatedConvert(const hal::V1_0::OperandLifeTime& lifetime) {
    return static_cast<Operand::LifeTime>(lifetime);
}

GeneralResult<DeviceStatus> unvalidatedConvert(const hal::V1_0::DeviceStatus& deviceStatus) {
    return static_cast<DeviceStatus>(deviceStatus);
}

// Copies the execution-time and power-usage figures field by field.
GeneralResult<Capabilities::PerformanceInfo> unvalidatedConvert(
        const hal::V1_0::PerformanceInfo& performanceInfo) {
    return Capabilities::PerformanceInfo{
            .execTime = performanceInfo.execTime,
            .powerUsage = performanceInfo.powerUsage,
    };
}

// V1_0 capabilities carry only float32 and quantized8 performance. The relaxed-precision
// entries reuse the float32 figures, and the per-operand table is built by
// hal::utils::makeQuantized8PerformanceConsistentWithP from the two figures.
GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_0::Capabilities& capabilities) {
    const auto quantized8Performance =
            NN_TRY(unvalidatedConvert(capabilities.quantized8Performance));
    const auto float32Performance = NN_TRY(unvalidatedConvert(capabilities.float32Performance));

    auto table = hal::utils::makeQuantized8PerformanceConsistentWithP(float32Performance,
                                                                      quantized8Performance);

    return Capabilities{
            .relaxedFloat32toFloat16PerformanceScalar = float32Performance,
            .relaxedFloat32toFloat16PerformanceTensor = float32Performance,
            .operandPerformance = std::move(table),
    };
}

// Copies pool index, offset and length unchanged.
GeneralResult<DataLocation> unvalidatedConvert(const hal::V1_0::DataLocation& location) {
    return DataLocation{
            .poolIndex = location.poolIndex,
            .offset = location.offset,
            .length = location.length,
    };
}

// Converts an operand. The HAL numberOfConsumers field is not copied here; it is
// checked against the operations by the Model conversion instead.
GeneralResult<Operand> unvalidatedConvert(const hal::V1_0::Operand& operand) {
    return Operand{
            .type = NN_TRY(unvalidatedConvert(operand.type)),
            .dimensions = operand.dimensions,
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
            .location = NN_TRY(unvalidatedConvert(operand.location)),
    };
}

GeneralResult<Operation> unvalidatedConvert(const hal::V1_0::Operation& operation) {
    return Operation{
            .type = NN_TRY(unvalidatedConvert(operation.type)),
            .inputs = operation.inputs,
            .outputs = operation.outputs,
    };
}

// Copies the raw constant-operand bytes into a Model::OperandValues.
GeneralResult<Model::OperandValues> unvalidatedConvert(const hidl_vec<uint8_t>& operandValues) {
    return Model::OperandValues(operandValues.data(), operandValues.size());
}

// Wraps a HIDL memory object via hal::utils::createSharedMemoryFromHidlMemory.
GeneralResult<SharedMemory> unvalidatedConvert(const hidl_memory& memory) {
    return hal::utils::createSharedMemoryFromHidlMemory(memory);
}

// Converts a V1_0 model into a canonical model whose main subgraph holds all operands
// and operations. Fails with GENERAL_FAILURE if any operand's declared
// numberOfConsumers differs from the count recomputed from the operations.
GeneralResult<Model> unvalidatedConvert(const hal::V1_0::Model& model) {
    auto operations = NN_TRY(unvalidatedConvert(model.operations));

    // Verify number of consumers.
    const auto numberOfConsumers =
            NN_TRY(hal::utils::countNumberOfConsumers(model.operands.size(), operations));
    CHECK(model.operands.size() == numberOfConsumers.size());
    for (size_t i = 0; i < model.operands.size(); ++i) {
        if (model.operands[i].numberOfConsumers != numberOfConsumers[i]) {
            return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
                   << "Invalid numberOfConsumers for operand " << i << ", expected "
                   << numberOfConsumers[i] << " but found " << model.operands[i].numberOfConsumers;
        }
    }

    auto main = Model::Subgraph{
            .operands = NN_TRY(unvalidatedConvert(model.operands)),
            .operations = std::move(operations),
            .inputIndexes = model.inputIndexes,
            .outputIndexes = model.outputIndexes,
    };

    return Model{
            .main = std::move(main),
            .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
            .pools = NN_TRY(unvalidatedConvert(model.pools)),
    };
}

// Maps the V1_0 hasNoValue flag onto the canonical NO_VALUE / POOL lifetime.
GeneralResult<Request::Argument> unvalidatedConvert(const hal::V1_0::RequestArgument& argument) {
    const auto lifetime = argument.hasNoValue ? Request::Argument::LifeTime::NO_VALUE
                                              : Request::Argument::LifeTime::POOL;
    return Request::Argument{
            .lifetime = lifetime,
            .location = NN_TRY(unvalidatedConvert(argument.location)),
            .dimensions = argument.dimensions,
    };
}

// Converts a request. Every V1_0 pool is a hidl_memory, so each converted SharedMemory
// is moved into the canonical memory-pool list.
GeneralResult<Request> unvalidatedConvert(const hal::V1_0::Request& request) {
    auto memories = NN_TRY(unvalidatedConvert(request.pools));
    std::vector<Request::MemoryPool> pools;
    pools.reserve(memories.size());
    std::move(memories.begin(), memories.end(), std::back_inserter(pools));

    return Request{
            .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
            .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
            .pools = std::move(pools),
    };
}

// Accepts only the five V1_0 status values. Any other value yields GENERAL_FAILURE and
// the raw numeric value in the message.
GeneralResult<ErrorStatus> unvalidatedConvert(const hal::V1_0::ErrorStatus& status) {
    switch (status) {
        case hal::V1_0::ErrorStatus::NONE:
        case hal::V1_0::ErrorStatus::DEVICE_UNAVAILABLE:
        case hal::V1_0::ErrorStatus::GENERAL_FAILURE:
        case hal::V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case hal::V1_0::ErrorStatus::INVALID_ARGUMENT:
            return static_cast<ErrorStatus>(status);
    }
    return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
           << "Invalid ErrorStatus " << underlyingType(status);
}

// Checked conversions: unvalidatedConvert followed by the V1_0 compliance check.
GeneralResult<DeviceStatus> convert(const hal::V1_0::DeviceStatus& deviceStatus) {
    return validatedConvert(deviceStatus);
}

GeneralResult<Capabilities> convert(const hal::V1_0::Capabilities& capabilities) {
    return validatedConvert(capabilities);
}

GeneralResult<Model> convert(const hal::V1_0::Model& model) {
    return validatedConvert(model);
}

GeneralResult<Request> convert(const hal::V1_0::Request& request) {
    return validatedConvert(request);
}

GeneralResult<ErrorStatus> convert(const hal::V1_0::ErrorStatus& status) {
    return validatedConvert(status);
}

}  // namespace android::nn

namespace android::hardware::neuralnetworks::V1_0::utils {
namespace {

// The HAL type that utils::unvalidatedConvert produces for a given canonical type.
template <typename Type>
using UnvalidatedConvertOutput =
        std::decay_t<decltype(unvalidatedConvert(std::declval<Type>()).value())>;

// Converts each element of a canonical vector to its HAL type, in order. Returns the
// error of the first element that fails to convert.
template <typename Type>
nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
        const std::vector<Type>& arguments) {
    hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
    for (size_t i = 0; i < arguments.size(); ++i) {
        halObject[i] = NN_TRY(utils::unvalidatedConvert(arguments[i]));
    }
    return halObject;
}

// Runs compliantVersion on the canonical object first, and converts it only if that
// check succeeds.
template <typename Type>
nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
    NN_TRY(compliantVersion(canonical));
    return utils::unvalidatedConvert(canonical);
}

}  // anonymous namespace

// Direct value casts; no range check is done here.
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    return static_cast<OperandType>(operandType);
}

nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    return static_cast<OperationType>(operationType);
}

// POINTER lifetimes cannot be represented in V1_0 and are rejected with INVALID_ARGUMENT.
nn::GeneralResult<OperandLifeTime> unvalidatedConvert(const nn::Operand::LifeTime& lifetime) {
    if (lifetime == nn::Operand::LifeTime::POINTER) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
    }
    return static_cast<OperandLifeTime>(lifetime);
}

nn::GeneralResult<DeviceStatus> unvalidatedConvert(const nn::DeviceStatus& deviceStatus) {
    return static_cast<DeviceStatus>(deviceStatus);
}

// Copies the execution-time and power-usage figures field by field.
nn::GeneralResult<PerformanceInfo> unvalidatedConvert(
        const nn::Capabilities::PerformanceInfo& performanceInfo) {
    return PerformanceInfo{
            .execTime = performanceInfo.execTime,
            .powerUsage = performanceInfo.powerUsage,
    };
}
nn::GeneralResult unvalidatedConvert(const nn::Capabilities& capabilities) { return Capabilities{ .float32Performance = NN_TRY(unvalidatedConvert( capabilities.operandPerformance.lookup(nn::OperandType::TENSOR_FLOAT32))), .quantized8Performance = NN_TRY(unvalidatedConvert( capabilities.operandPerformance.lookup(nn::OperandType::TENSOR_QUANT8_ASYMM))), }; } nn::GeneralResult unvalidatedConvert(const nn::DataLocation& location) { return DataLocation{ .poolIndex = location.poolIndex, .offset = location.offset, .length = location.length, }; } nn::GeneralResult unvalidatedConvert(const nn::Operand& operand) { return Operand{ .type = NN_TRY(unvalidatedConvert(operand.type)), .dimensions = operand.dimensions, .numberOfConsumers = 0, .scale = operand.scale, .zeroPoint = operand.zeroPoint, .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)), .location = NN_TRY(unvalidatedConvert(operand.location)), }; } nn::GeneralResult unvalidatedConvert(const nn::Operation& operation) { return Operation{ .type = NN_TRY(unvalidatedConvert(operation.type)), .inputs = operation.inputs, .outputs = operation.outputs, }; } nn::GeneralResult> unvalidatedConvert( const nn::Model::OperandValues& operandValues) { return hidl_vec(operandValues.data(), operandValues.data() + operandValues.size()); } nn::GeneralResult unvalidatedConvert(const nn::SharedMemory& memory) { return hal::utils::createHidlMemoryFromSharedMemory(memory); } nn::GeneralResult unvalidatedConvert(const nn::Model& model) { if (!hal::utils::hasNoPointerData(model)) { return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Mdoel cannot be unvalidatedConverted because it contains pointer-based memory"; } auto operands = NN_TRY(unvalidatedConvert(model.main.operands)); // Update number of consumers. 
const auto numberOfConsumers = NN_TRY(hal::utils::countNumberOfConsumers(operands.size(), model.main.operations)); CHECK(operands.size() == numberOfConsumers.size()); for (size_t i = 0; i < operands.size(); ++i) { operands[i].numberOfConsumers = numberOfConsumers[i]; } return Model{ .operands = std::move(operands), .operations = NN_TRY(unvalidatedConvert(model.main.operations)), .inputIndexes = model.main.inputIndexes, .outputIndexes = model.main.outputIndexes, .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)), .pools = NN_TRY(unvalidatedConvert(model.pools)), }; } nn::GeneralResult unvalidatedConvert( const nn::Request::Argument& requestArgument) { if (requestArgument.lifetime == nn::Request::Argument::LifeTime::POINTER) { return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Request cannot be unvalidatedConverted because it contains pointer-based memory"; } const bool hasNoValue = requestArgument.lifetime == nn::Request::Argument::LifeTime::NO_VALUE; return RequestArgument{ .hasNoValue = hasNoValue, .location = NN_TRY(unvalidatedConvert(requestArgument.location)), .dimensions = requestArgument.dimensions, }; } nn::GeneralResult unvalidatedConvert(const nn::Request::MemoryPool& memoryPool) { return unvalidatedConvert(std::get(memoryPool)); } nn::GeneralResult unvalidatedConvert(const nn::Request& request) { if (!hal::utils::hasNoPointerData(request)) { return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Request cannot be unvalidatedConverted because it contains pointer-based memory"; } return Request{ .inputs = NN_TRY(unvalidatedConvert(request.inputs)), .outputs = NN_TRY(unvalidatedConvert(request.outputs)), .pools = NN_TRY(unvalidatedConvert(request.pools)), }; } nn::GeneralResult unvalidatedConvert(const nn::ErrorStatus& status) { switch (status) { case nn::ErrorStatus::NONE: case nn::ErrorStatus::DEVICE_UNAVAILABLE: case nn::ErrorStatus::GENERAL_FAILURE: case nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE: case nn::ErrorStatus::INVALID_ARGUMENT: 
return static_cast(status); default: return ErrorStatus::GENERAL_FAILURE; } } nn::GeneralResult convert(const nn::DeviceStatus& deviceStatus) { return validatedConvert(deviceStatus); } nn::GeneralResult convert(const nn::Capabilities& capabilities) { return validatedConvert(capabilities); } nn::GeneralResult convert(const nn::Model& model) { return validatedConvert(model); } nn::GeneralResult convert(const nn::Request& request) { return validatedConvert(request); } nn::GeneralResult convert(const nn::ErrorStatus& status) { return validatedConvert(status); } } // namespace android::hardware::neuralnetworks::V1_0::utils