1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "common/libs/utils/vsock_connection.h"
18
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>

#include <cstddef>
#include <functional>
#include <future>
#include <limits>
#include <memory>
#include <mutex>
#include <new>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
33
34 #include <android-base/logging.h>
35 #include <json/json.h>
36
37 #include "common/libs/fs/shared_buf.h"
38 #include "common/libs/fs/shared_select.h"
39
40 namespace cuttlefish {
41
~VsockConnection()42 VsockConnection::~VsockConnection() { Disconnect(); }
43
ConnectAsync(unsigned int port,unsigned int cid,std::optional<int> vhost_user_vsock_cid_)44 std::future<bool> VsockConnection::ConnectAsync(
45 unsigned int port, unsigned int cid,
46 std::optional<int> vhost_user_vsock_cid_) {
47 return std::async(std::launch::async,
48 [this, port, cid, vhost_user_vsock_cid_]() {
49 return Connect(port, cid, vhost_user_vsock_cid_);
50 });
51 }
52
Disconnect()53 void VsockConnection::Disconnect() {
54 // We need to serialize all accesses to the SharedFD.
55 std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
56 std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
57
58 LOG(INFO) << "Disconnecting with fd status:" << fd_->StrError();
59 fd_->Shutdown(SHUT_RDWR);
60 if (disconnect_callback_) {
61 disconnect_callback_();
62 }
63 fd_->Close();
64 }
65
SetDisconnectCallback(std::function<void ()> callback)66 void VsockConnection::SetDisconnectCallback(std::function<void()> callback) {
67 disconnect_callback_ = callback;
68 }
69
IsConnected()70 bool VsockConnection::IsConnected() {
71 // We need to serialize all accesses to the SharedFD.
72 std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
73 std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
74
75 return fd_->IsOpen();
76 }
77
DataAvailable()78 bool VsockConnection::DataAvailable() {
79 SharedFDSet read_set;
80
81 // We need to serialize all accesses to the SharedFD.
82 std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
83 std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
84
85 read_set.Set(fd_);
86 struct timeval timeout = {0, 0};
87 return Select(&read_set, nullptr, nullptr, &timeout) > 0;
88 }
89
Read()90 int32_t VsockConnection::Read() {
91 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
92 int32_t result;
93 if (ReadExactBinary(fd_, &result) != sizeof(result)) {
94 Disconnect();
95 return 0;
96 }
97 return result;
98 }
99
Read(std::vector<char> & data)100 bool VsockConnection::Read(std::vector<char>& data) {
101 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
102 return ReadExact(fd_, &data) == data.size();
103 }
104
Read(size_t size)105 std::vector<char> VsockConnection::Read(size_t size) {
106 if (size == 0) {
107 return {};
108 }
109 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
110 std::vector<char> result(size);
111 if (ReadExact(fd_, &result) != size) {
112 Disconnect();
113 return {};
114 }
115 return result;
116 }
117
ReadAsync(size_t size)118 std::future<std::vector<char>> VsockConnection::ReadAsync(size_t size) {
119 return std::async(std::launch::async, [this, size]() { return Read(size); });
120 }
121
122 // Message format is buffer size followed by buffer data
ReadMessage()123 std::vector<char> VsockConnection::ReadMessage() {
124 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
125 auto size = Read();
126 if (size < 0) {
127 Disconnect();
128 return {};
129 }
130 return Read(size);
131 }
132
ReadMessage(std::vector<char> & data)133 bool VsockConnection::ReadMessage(std::vector<char>& data) {
134 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
135 auto size = Read();
136 if (size < 0) {
137 Disconnect();
138 return false;
139 }
140 data.resize(size);
141 return Read(data);
142 }
143
ReadMessageAsync()144 std::future<std::vector<char>> VsockConnection::ReadMessageAsync() {
145 return std::async(std::launch::async, [this]() { return ReadMessage(); });
146 }
147
ReadJsonMessage()148 Json::Value VsockConnection::ReadJsonMessage() {
149 auto msg = ReadMessage();
150 Json::CharReaderBuilder builder;
151 std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
152 Json::Value json_msg;
153 std::string errors;
154 if (!reader->parse(msg.data(), msg.data() + msg.size(), &json_msg, &errors)) {
155 return {};
156 }
157 return json_msg;
158 }
159
ReadJsonMessageAsync()160 std::future<Json::Value> VsockConnection::ReadJsonMessageAsync() {
161 return std::async(std::launch::async, [this]() { return ReadJsonMessage(); });
162 }
163
Write(int32_t data)164 bool VsockConnection::Write(int32_t data) {
165 std::lock_guard<std::recursive_mutex> lock(write_mutex_);
166 if (WriteAllBinary(fd_, &data) != sizeof(data)) {
167 Disconnect();
168 return false;
169 }
170 return true;
171 }
172
Write(const char * data,unsigned int size)173 bool VsockConnection::Write(const char* data, unsigned int size) {
174 std::lock_guard<std::recursive_mutex> lock(write_mutex_);
175 if (WriteAll(fd_, data, size) != size) {
176 Disconnect();
177 return false;
178 }
179 return true;
180 }
181
Write(const std::vector<char> & data)182 bool VsockConnection::Write(const std::vector<char>& data) {
183 return Write(data.data(), data.size());
184 }
185
186 // Message format is buffer size followed by buffer data
WriteMessage(const std::string & data)187 bool VsockConnection::WriteMessage(const std::string& data) {
188 return Write(data.size()) && Write(data.c_str(), data.length());
189 }
190
WriteMessage(const std::vector<char> & data)191 bool VsockConnection::WriteMessage(const std::vector<char>& data) {
192 std::lock_guard<std::recursive_mutex> lock(write_mutex_);
193 return Write(data.size()) && Write(data);
194 }
195
WriteMessage(const Json::Value & data)196 bool VsockConnection::WriteMessage(const Json::Value& data) {
197 Json::StreamWriterBuilder factory;
198 std::string message_str = Json::writeString(factory, data);
199 return WriteMessage(message_str);
200 }
201
WriteStrides(const char * data,unsigned int size,unsigned int num_strides,int stride_size)202 bool VsockConnection::WriteStrides(const char* data, unsigned int size,
203 unsigned int num_strides, int stride_size) {
204 const char* src = data;
205 for (unsigned int i = 0; i < num_strides; ++i, src += stride_size) {
206 if (!Write(src, size)) {
207 return false;
208 }
209 }
210 return true;
211 }
212
Connect(unsigned int port,unsigned int cid,std::optional<int> vhost_user)213 bool VsockClientConnection::Connect(unsigned int port, unsigned int cid,
214 std::optional<int> vhost_user) {
215 fd_ =
216 SharedFD::VsockClient(cid, port, SOCK_STREAM, vhost_user ? true : false);
217 if (!fd_->IsOpen()) {
218 LOG(ERROR) << "Failed to connect:" << fd_->StrError();
219 }
220 return fd_->IsOpen();
221 }
222
~VsockServerConnection()223 VsockServerConnection::~VsockServerConnection() { ServerShutdown(); }
224
ServerShutdown()225 void VsockServerConnection::ServerShutdown() {
226 if (server_fd_->IsOpen()) {
227 LOG(INFO) << __FUNCTION__
228 << ": server fd status:" << server_fd_->StrError();
229 server_fd_->Shutdown(SHUT_RDWR);
230 server_fd_->Close();
231 }
232 }
233
Connect(unsigned int port,unsigned int cid,std::optional<int> vhost_user_vsock_cid)234 bool VsockServerConnection::Connect(unsigned int port, unsigned int cid,
235 std::optional<int> vhost_user_vsock_cid) {
236 if (!server_fd_->IsOpen()) {
237 server_fd_ = cuttlefish::SharedFD::VsockServer(port, SOCK_STREAM,
238 vhost_user_vsock_cid, cid);
239 }
240 if (server_fd_->IsOpen()) {
241 fd_ = SharedFD::Accept(*server_fd_);
242 return fd_->IsOpen();
243 } else {
244 return false;
245 }
246 }
247
248 } // namespace cuttlefish
249