1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "common/libs/utils/vsock_connection.h"
18 
19 #include <sys/types.h>
20 #include <sys/socket.h>
21 #include <sys/time.h>
22 
23 #include <functional>
24 #include <future>
25 #include <memory>
26 #include <mutex>
27 #include <new>
28 #include <ostream>
29 #include <string>
30 #include <tuple>
31 #include <utility>
32 #include <vector>
33 
34 #include <android-base/logging.h>
35 #include <json/json.h>
36 
37 #include "common/libs/fs/shared_buf.h"
38 #include "common/libs/fs/shared_select.h"
39 
40 namespace cuttlefish {
41 
~VsockConnection()42 VsockConnection::~VsockConnection() { Disconnect(); }
43 
ConnectAsync(unsigned int port,unsigned int cid,std::optional<int> vhost_user_vsock_cid_)44 std::future<bool> VsockConnection::ConnectAsync(
45     unsigned int port, unsigned int cid,
46     std::optional<int> vhost_user_vsock_cid_) {
47   return std::async(std::launch::async,
48                     [this, port, cid, vhost_user_vsock_cid_]() {
49                       return Connect(port, cid, vhost_user_vsock_cid_);
50                     });
51 }
52 
Disconnect()53 void VsockConnection::Disconnect() {
54   // We need to serialize all accesses to the SharedFD.
55   std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
56   std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
57 
58   LOG(INFO) << "Disconnecting with fd status:" << fd_->StrError();
59   fd_->Shutdown(SHUT_RDWR);
60   if (disconnect_callback_) {
61     disconnect_callback_();
62   }
63   fd_->Close();
64 }
65 
SetDisconnectCallback(std::function<void ()> callback)66 void VsockConnection::SetDisconnectCallback(std::function<void()> callback) {
67   disconnect_callback_ = callback;
68 }
69 
IsConnected()70 bool VsockConnection::IsConnected() {
71   // We need to serialize all accesses to the SharedFD.
72   std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
73   std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
74 
75   return fd_->IsOpen();
76 }
77 
DataAvailable()78 bool VsockConnection::DataAvailable() {
79   SharedFDSet read_set;
80 
81   // We need to serialize all accesses to the SharedFD.
82   std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
83   std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
84 
85   read_set.Set(fd_);
86   struct timeval timeout = {0, 0};
87   return Select(&read_set, nullptr, nullptr, &timeout) > 0;
88 }
89 
Read()90 int32_t VsockConnection::Read() {
91   std::lock_guard<std::recursive_mutex> lock(read_mutex_);
92   int32_t result;
93   if (ReadExactBinary(fd_, &result) != sizeof(result)) {
94     Disconnect();
95     return 0;
96   }
97   return result;
98 }
99 
Read(std::vector<char> & data)100 bool VsockConnection::Read(std::vector<char>& data) {
101   std::lock_guard<std::recursive_mutex> lock(read_mutex_);
102   return ReadExact(fd_, &data) == data.size();
103 }
104 
Read(size_t size)105 std::vector<char> VsockConnection::Read(size_t size) {
106   if (size == 0) {
107     return {};
108   }
109   std::lock_guard<std::recursive_mutex> lock(read_mutex_);
110   std::vector<char> result(size);
111   if (ReadExact(fd_, &result) != size) {
112     Disconnect();
113     return {};
114   }
115   return result;
116 }
117 
ReadAsync(size_t size)118 std::future<std::vector<char>> VsockConnection::ReadAsync(size_t size) {
119   return std::async(std::launch::async, [this, size]() { return Read(size); });
120 }
121 
122 // Message format is buffer size followed by buffer data
ReadMessage()123 std::vector<char> VsockConnection::ReadMessage() {
124   std::lock_guard<std::recursive_mutex> lock(read_mutex_);
125   auto size = Read();
126   if (size < 0) {
127     Disconnect();
128     return {};
129   }
130   return Read(size);
131 }
132 
ReadMessage(std::vector<char> & data)133 bool VsockConnection::ReadMessage(std::vector<char>& data) {
134   std::lock_guard<std::recursive_mutex> lock(read_mutex_);
135   auto size = Read();
136   if (size < 0) {
137     Disconnect();
138     return false;
139   }
140   data.resize(size);
141   return Read(data);
142 }
143 
ReadMessageAsync()144 std::future<std::vector<char>> VsockConnection::ReadMessageAsync() {
145   return std::async(std::launch::async, [this]() { return ReadMessage(); });
146 }
147 
ReadJsonMessage()148 Json::Value VsockConnection::ReadJsonMessage() {
149   auto msg = ReadMessage();
150   Json::CharReaderBuilder builder;
151   std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
152   Json::Value json_msg;
153   std::string errors;
154   if (!reader->parse(msg.data(), msg.data() + msg.size(), &json_msg, &errors)) {
155     return {};
156   }
157   return json_msg;
158 }
159 
ReadJsonMessageAsync()160 std::future<Json::Value> VsockConnection::ReadJsonMessageAsync() {
161   return std::async(std::launch::async, [this]() { return ReadJsonMessage(); });
162 }
163 
Write(int32_t data)164 bool VsockConnection::Write(int32_t data) {
165   std::lock_guard<std::recursive_mutex> lock(write_mutex_);
166   if (WriteAllBinary(fd_, &data) != sizeof(data)) {
167     Disconnect();
168     return false;
169   }
170   return true;
171 }
172 
Write(const char * data,unsigned int size)173 bool VsockConnection::Write(const char* data, unsigned int size) {
174   std::lock_guard<std::recursive_mutex> lock(write_mutex_);
175   if (WriteAll(fd_, data, size) != size) {
176     Disconnect();
177     return false;
178   }
179   return true;
180 }
181 
Write(const std::vector<char> & data)182 bool VsockConnection::Write(const std::vector<char>& data) {
183   return Write(data.data(), data.size());
184 }
185 
186 // Message format is buffer size followed by buffer data
WriteMessage(const std::string & data)187 bool VsockConnection::WriteMessage(const std::string& data) {
188   return Write(data.size()) && Write(data.c_str(), data.length());
189 }
190 
WriteMessage(const std::vector<char> & data)191 bool VsockConnection::WriteMessage(const std::vector<char>& data) {
192   std::lock_guard<std::recursive_mutex> lock(write_mutex_);
193   return Write(data.size()) && Write(data);
194 }
195 
WriteMessage(const Json::Value & data)196 bool VsockConnection::WriteMessage(const Json::Value& data) {
197   Json::StreamWriterBuilder factory;
198   std::string message_str = Json::writeString(factory, data);
199   return WriteMessage(message_str);
200 }
201 
WriteStrides(const char * data,unsigned int size,unsigned int num_strides,int stride_size)202 bool VsockConnection::WriteStrides(const char* data, unsigned int size,
203                                    unsigned int num_strides, int stride_size) {
204   const char* src = data;
205   for (unsigned int i = 0; i < num_strides; ++i, src += stride_size) {
206     if (!Write(src, size)) {
207       return false;
208     }
209   }
210   return true;
211 }
212 
Connect(unsigned int port,unsigned int cid,std::optional<int> vhost_user)213 bool VsockClientConnection::Connect(unsigned int port, unsigned int cid,
214                                     std::optional<int> vhost_user) {
215   fd_ =
216       SharedFD::VsockClient(cid, port, SOCK_STREAM, vhost_user ? true : false);
217   if (!fd_->IsOpen()) {
218     LOG(ERROR) << "Failed to connect:" << fd_->StrError();
219   }
220   return fd_->IsOpen();
221 }
222 
~VsockServerConnection()223 VsockServerConnection::~VsockServerConnection() { ServerShutdown(); }
224 
ServerShutdown()225 void VsockServerConnection::ServerShutdown() {
226   if (server_fd_->IsOpen()) {
227     LOG(INFO) << __FUNCTION__
228               << ": server fd status:" << server_fd_->StrError();
229     server_fd_->Shutdown(SHUT_RDWR);
230     server_fd_->Close();
231   }
232 }
233 
Connect(unsigned int port,unsigned int cid,std::optional<int> vhost_user_vsock_cid)234 bool VsockServerConnection::Connect(unsigned int port, unsigned int cid,
235                                     std::optional<int> vhost_user_vsock_cid) {
236   if (!server_fd_->IsOpen()) {
237     server_fd_ = cuttlefish::SharedFD::VsockServer(port, SOCK_STREAM,
238                                                    vhost_user_vsock_cid, cid);
239   }
240   if (server_fd_->IsOpen()) {
241     fd_ = SharedFD::Accept(*server_fd_);
242     return fd_->IsOpen();
243   } else {
244     return false;
245   }
246 }
247 
248 }  // namespace cuttlefish
249