1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "UnixSocket.h"
18 
19 #include <string.h>
20 #include <sys/socket.h>
21 #include <sys/un.h>
22 #include <unistd.h>
23 
24 #include <algorithm>
25 
26 #include <android-base/logging.h>
27 
28 #include "IOEventLoop.h"
29 
CreateUnixSocketAddress(const std::string & server_path,bool is_abstract,sockaddr_un & serv_addr)30 static bool CreateUnixSocketAddress(const std::string& server_path,
31                                     bool is_abstract, sockaddr_un& serv_addr) {
32   memset(&serv_addr, 0, sizeof(serv_addr));
33   serv_addr.sun_family = AF_UNIX;
34   size_t sun_path_len = sizeof(serv_addr.sun_path);
35   char* p = serv_addr.sun_path;
36   if (is_abstract) {
37     sun_path_len--;
38     p++;
39   }
40   if (server_path.size() + 1 > sun_path_len) {
41     LOG(ERROR) << "can't create unix domain socket as server path is too long: "
42                << server_path;
43     return false;
44   }
45   strcpy(p, server_path.c_str());
46   return true;
47 }
48 
Create(const std::string & server_path,bool is_abstract)49 std::unique_ptr<UnixSocketServer> UnixSocketServer::Create(
50     const std::string& server_path, bool is_abstract) {
51   int sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
52   if (sockfd < 0) {
53     PLOG(ERROR) << "socket() failed";
54     return nullptr;
55   }
56   sockaddr_un serv_addr;
57   if (!CreateUnixSocketAddress(server_path, is_abstract, serv_addr)) {
58     return nullptr;
59   }
60   if (bind(sockfd, reinterpret_cast<sockaddr*>(&serv_addr), sizeof(serv_addr)) <
61       0) {
62     PLOG(ERROR) << "bind() failed for " << server_path;
63     return nullptr;
64   }
65   if (listen(sockfd, 1) < 0) {
66     PLOG(ERROR) << "listen() failed";
67     return nullptr;
68   }
69   return std::unique_ptr<UnixSocketServer>(
70       new UnixSocketServer(sockfd, server_path));
71 }
72 
~UnixSocketServer()73 UnixSocketServer::~UnixSocketServer() { close(server_fd_); }
74 
AcceptConnection()75 std::unique_ptr<UnixSocketConnection> UnixSocketServer::AcceptConnection() {
76   int sockfd = accept(server_fd_, nullptr, nullptr);
77   if (sockfd < 0) {
78     PLOG(ERROR) << "accept() failed";
79     return nullptr;
80   }
81   return std::unique_ptr<UnixSocketConnection>(
82       new UnixSocketConnection(sockfd));
83 }
84 
Connect(const std::string & server_path,bool is_abstract)85 std::unique_ptr<UnixSocketConnection> UnixSocketConnection::Connect(
86     const std::string& server_path, bool is_abstract) {
87   int sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
88   if (sockfd < 0) {
89     PLOG(DEBUG) << "socket() failed";
90     return nullptr;
91   }
92   sockaddr_un serv_addr;
93   if (!CreateUnixSocketAddress(server_path, is_abstract, serv_addr)) {
94     return nullptr;
95   }
96   if (connect(sockfd, reinterpret_cast<sockaddr*>(&serv_addr),
97               sizeof(serv_addr)) < 0) {
98     PLOG(DEBUG) << "connect() failed, server_path = " << server_path;
99     return nullptr;
100   }
101   return std::unique_ptr<UnixSocketConnection>(
102       new UnixSocketConnection(sockfd));
103 }
104 
PrepareForIO(IOEventLoop & loop,const std::function<bool (const UnixSocketMessage &)> & receive_message_callback,const std::function<bool ()> & close_connection_callback)105 bool UnixSocketConnection::PrepareForIO(
106     IOEventLoop& loop, const std::function<bool(const UnixSocketMessage&)>&
107                            receive_message_callback,
108     const std::function<bool()>& close_connection_callback) {
109   read_callback_ = receive_message_callback;
110   close_callback_ = close_connection_callback;
111   read_event_ = loop.AddReadEvent(fd_, [&]() { return ReadData(); });
112   if (read_event_ == nullptr) {
113     return false;
114   }
115   std::lock_guard<std::mutex> lock(send_buffer_and_write_event_mtx_);
116   write_event_ = loop.AddWriteEvent(fd_, [&]() { return WriteData(); });
117   if (write_event_ == nullptr) {
118     return false;
119   }
120   return DisableWriteEventWithLock();
121 }
122 
WriteData()123 bool UnixSocketConnection::WriteData() {
124   const char* write_data;
125   size_t write_data_size;
126   if (!GetDataFromSendBuffer(&write_data, &write_data_size)) {
127     return false;
128   }
129   if (write_data_size == 0u) {
130     return true;
131   }
132   // Use MSG_NOSIGNAL to prevent receiving SIGPIPE.
133   ssize_t result =
134       TEMP_FAILURE_RETRY(send(fd_, write_data, write_data_size, MSG_NOSIGNAL));
135   if (result >= 0) {
136     std::lock_guard<std::mutex> lock(send_buffer_and_write_event_mtx_);
137     send_buffer_.CommitData(result);
138   } else if (errno != EAGAIN) {
139     PLOG(ERROR) << "send() failed";
140     return false;
141   }
142   return true;
143 }
144 
GetDataFromSendBuffer(const char ** pdata,size_t * pdata_size)145 bool UnixSocketConnection::GetDataFromSendBuffer(const char** pdata,
146                                                  size_t* pdata_size) {
147   {
148     std::lock_guard<std::mutex> lock(send_buffer_and_write_event_mtx_);
149     *pdata_size = send_buffer_.PeekData(pdata);
150     if (*pdata_size != 0u) {
151       return true;
152     }
153     // The send buffer is empty. If we can receive more messages, just disable
154     // the write event temporarily, otherwise close the connection.
155     if (!no_more_message_) {
156       return DisableWriteEventWithLock();
157     }
158   }
159   return CloseConnection();
160 }
161 
ReadData()162 bool UnixSocketConnection::ReadData() {
163   ssize_t result =
164       TEMP_FAILURE_RETRY(read(fd_, &read_buffer_[read_buffer_size_],
165                               read_buffer_.size() - read_buffer_size_));
166   if (result < 0) {
167     if (errno == EAGAIN) {
168       return true;
169     }
170     PLOG(ERROR) << "read() failed";
171     return false;
172   } else if (result == 0) {
173     // The connection is closed, and no need to write pending messages.
174     return CloseConnection();
175   }
176   read_buffer_size_ += result;
177   return ConsumeDataInReadBuffer();
178 }
179 
ConsumeDataInReadBuffer()180 bool UnixSocketConnection::ConsumeDataInReadBuffer() {
181   char* p = read_buffer_.data();
182   size_t left_size = read_buffer_size_;
183   uint32_t aligned_len = 0;
184   while (left_size >= sizeof(UnixSocketMessage)) {
185     UnixSocketMessage* msg = reinterpret_cast<UnixSocketMessage*>(p);
186     aligned_len = Align(msg->len, UnixSocketMessageAlignment);
187     if (left_size < aligned_len) {
188       break;
189     }
190     if (!read_callback_(*msg)) {
191       return false;
192     }
193     p += aligned_len;
194     left_size -= aligned_len;
195   }
196   if (left_size > 0u) {
197     // Move the unfinished message to the start of read_buffer_.
198     memmove(read_buffer_.data(), p, left_size);
199     // Extend the buffer to store this big message.
200     if (aligned_len > read_buffer_.size()) {
201       read_buffer_.resize(aligned_len);
202     }
203   }
204   read_buffer_size_ = left_size;
205   return true;
206 }
207 
CloseConnection()208 bool UnixSocketConnection::CloseConnection() {
209   // Disable read_event and write_event here, so ReadData() and WriteData()
210   // won't be called in the future.
211   if (!IOEventLoop::DisableEvent(read_event_)) {
212     return false;
213   }
214   {
215     std::lock_guard<std::mutex> lock(send_buffer_and_write_event_mtx_);
216     no_more_message_ = true;
217     if (!DisableWriteEventWithLock()) {
218       return false;
219     }
220   }
221   close(fd_);
222   fd_ = -1;
223   return close_callback_();
224 }
225 
~UnixSocketConnection()226 UnixSocketConnection::~UnixSocketConnection() {
227   if (fd_ != -1) {
228     // It only happens when IO operations are not finished properly by
229     // CloseConnection(). Don't call CloseConnection() here as the
230     // IOEventLoop used to register read/write events may have been destroyed.
231     close(fd_);
232   }
233 }
234