1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Utils.h"
18 
19 #include <aidl/android/hardware/common/Ashmem.h>
20 #include <aidl/android/hardware/common/MappableFile.h>
21 #include <aidl/android/hardware/graphics/common/HardwareBuffer.h>
22 #include <android/binder_auto_utils.h>
23 #include <android/binder_status.h>
24 #include <nnapi/Result.h>
25 #include <nnapi/SharedMemory.h>
26 
27 namespace aidl::android::hardware::neuralnetworks::utils {
28 namespace {
29 
30 nn::GeneralResult<ndk::ScopedFileDescriptor> clone(const ndk::ScopedFileDescriptor& fd);
31 using utils::clone;
32 
33 template <typename Type>
cloneVec(const std::vector<Type> & arguments)34 nn::GeneralResult<std::vector<Type>> cloneVec(const std::vector<Type>& arguments) {
35     std::vector<Type> clonedObjects;
36     clonedObjects.reserve(arguments.size());
37     for (const auto& argument : arguments) {
38         clonedObjects.push_back(NN_TRY(clone(argument)));
39     }
40     return clonedObjects;
41 }
42 
43 template <typename Type>
clone(const std::vector<Type> & arguments)44 nn::GeneralResult<std::vector<Type>> clone(const std::vector<Type>& arguments) {
45     return cloneVec(arguments);
46 }
47 
clone(const ndk::ScopedFileDescriptor & fd)48 nn::GeneralResult<ndk::ScopedFileDescriptor> clone(const ndk::ScopedFileDescriptor& fd) {
49     auto duplicatedFd = NN_TRY(nn::dupFd(fd.get()));
50     return ndk::ScopedFileDescriptor(duplicatedFd.release());
51 }
52 
clone(const common::NativeHandle & handle)53 nn::GeneralResult<common::NativeHandle> clone(const common::NativeHandle& handle) {
54     return common::NativeHandle{
55             .fds = NN_TRY(cloneVec(handle.fds)),
56             .ints = handle.ints,
57     };
58 }
59 
60 }  // namespace
61 
clone(const Memory & memory)62 nn::GeneralResult<Memory> clone(const Memory& memory) {
63     switch (memory.getTag()) {
64         case Memory::Tag::ashmem: {
65             const auto& ashmem = memory.get<Memory::Tag::ashmem>();
66             auto handle = common::Ashmem{
67                     .fd = NN_TRY(clone(ashmem.fd)),
68                     .size = ashmem.size,
69             };
70             return Memory::make<Memory::Tag::ashmem>(std::move(handle));
71         }
72         case Memory::Tag::mappableFile: {
73             const auto& memFd = memory.get<Memory::Tag::mappableFile>();
74             auto handle = common::MappableFile{
75                     .length = memFd.length,
76                     .prot = memFd.prot,
77                     .fd = NN_TRY(clone(memFd.fd)),
78                     .offset = memFd.offset,
79             };
80             return Memory::make<Memory::Tag::mappableFile>(std::move(handle));
81         }
82         case Memory::Tag::hardwareBuffer: {
83             const auto& hardwareBuffer = memory.get<Memory::Tag::hardwareBuffer>();
84             auto handle = graphics::common::HardwareBuffer{
85                     .description = hardwareBuffer.description,
86                     .handle = NN_TRY(clone(hardwareBuffer.handle)),
87             };
88             return Memory::make<Memory::Tag::hardwareBuffer>(std::move(handle));
89         }
90     }
91     return (NN_ERROR() << "Unrecognized Memory::Tag: " << memory.getTag())
92             .
93             operator nn::GeneralResult<Memory>();
94 }
95 
clone(const RequestMemoryPool & requestPool)96 nn::GeneralResult<RequestMemoryPool> clone(const RequestMemoryPool& requestPool) {
97     using Tag = RequestMemoryPool::Tag;
98     switch (requestPool.getTag()) {
99         case Tag::pool:
100             return RequestMemoryPool::make<Tag::pool>(NN_TRY(clone(requestPool.get<Tag::pool>())));
101         case Tag::token:
102             return RequestMemoryPool::make<Tag::token>(requestPool.get<Tag::token>());
103     }
104     // Using explicit type conversion because std::variant inside the RequestMemoryPool confuses the
105     // compiler.
106     return (NN_ERROR() << "Unrecognized request pool tag: " << requestPool.getTag())
107             .
108             operator nn::GeneralResult<RequestMemoryPool>();
109 }
110 
clone(const Request & request)111 nn::GeneralResult<Request> clone(const Request& request) {
112     return Request{
113             .inputs = request.inputs,
114             .outputs = request.outputs,
115             .pools = NN_TRY(clone(request.pools)),
116     };
117 }
118 
clone(const Model & model)119 nn::GeneralResult<Model> clone(const Model& model) {
120     return Model{
121             .main = model.main,
122             .referenced = model.referenced,
123             .operandValues = model.operandValues,
124             .pools = NN_TRY(clone(model.pools)),
125             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
126             .extensionNameToPrefix = model.extensionNameToPrefix,
127     };
128 }
129 
handleTransportError(const ndk::ScopedAStatus & ret)130 nn::GeneralResult<void> handleTransportError(const ndk::ScopedAStatus& ret) {
131     if (ret.getStatus() == STATUS_DEAD_OBJECT) {
132         return nn::error(nn::ErrorStatus::DEAD_OBJECT)
133                << "Binder transaction returned STATUS_DEAD_OBJECT: " << ret.getDescription();
134     }
135     if (ret.isOk()) {
136         return {};
137     }
138     if (ret.getExceptionCode() != EX_SERVICE_SPECIFIC) {
139         return nn::error(nn::ErrorStatus::GENERAL_FAILURE)
140                << "Binder transaction returned exception: " << ret.getDescription();
141     }
142     return nn::error(static_cast<nn::ErrorStatus>(ret.getServiceSpecificError()))
143            << ret.getMessage();
144 }
145 
146 }  // namespace aidl::android::hardware::neuralnetworks::utils
147