1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "platform/impl/stream_socket_posix.h"
6 
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <poll.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
14 
15 namespace openscreen {
16 
namespace {
// Backlog passed to listen() by the no-argument Listen() overload.
constexpr int kDefaultMaxBacklogSize = 64;

// Returns true if |fd| is ready to be read right now. For a listening socket
// this means accept() has a connection (or a pending error) to hand back.
// Never blocks: poll() is called with a zero timeout.
//
// poll() is used rather than select() because FD_SET() on a descriptor that
// is >= FD_SETSIZE writes past the end of the fd_set (undefined behavior),
// which can happen in processes holding many descriptors.
//
// Error/hang-up conditions also count as "pending" so that the following
// accept() call surfaces the real error. A failed poll() (e.g. EINTR) or an
// invalid descriptor (POLLNVAL) reports false.
bool IsConnectionPending(int fd) {
  struct pollfd poll_fd {};
  poll_fd.fd = fd;
  poll_fd.events = POLLIN;
  if (poll(&poll_fd, 1, /*timeout=*/0) <= 0) {
    return false;
  }
  return (poll_fd.revents & (POLLIN | POLLERR | POLLHUP)) != 0;
}
}  // namespace
32 
StreamSocketPosix(IPAddress::Version version)33 StreamSocketPosix::StreamSocketPosix(IPAddress::Version version)
34     : version_(version) {}
35 
StreamSocketPosix(const IPEndpoint & local_endpoint)36 StreamSocketPosix::StreamSocketPosix(const IPEndpoint& local_endpoint)
37     : version_(local_endpoint.address.version()),
38       local_address_(local_endpoint) {}
39 
// Wraps a descriptor that is already connected; Accept() uses this for each
// new connection. The socket starts in the kConnected state, and Initialize()
// runs on the existing descriptor to switch it to non-blocking mode.
// NOTE(review): the Error returned by Initialize() is discarded. If fcntl()
// fails, Initialize() closes the socket through CloseOnError(), and the owner
// only finds out from later calls. Consider surfacing this.
StreamSocketPosix::StreamSocketPosix(SocketAddressPosix local_address,
                                     IPEndpoint remote_address,
                                     int file_descriptor)
    : handle_(file_descriptor),
      version_(local_address.version()),
      local_address_(local_address),
      remote_address_(remote_address),
      state_(TcpSocketState::kConnected) {
  Initialize();
}
50 
// Move operations are the compiler-generated member-wise defaults.
// NOTE(review): whether these are safe depends on the members' own move
// semantics, which are not visible here. If moving |handle_| just copies the
// int fd, the moved-from object's destructor would close the same descriptor
// a second time. Likewise, move assignment would not release a descriptor the
// target already holds unless SocketHandle's move assignment does so. Confirm
// this against the SocketHandle definition.
// The move constructor is noexcept but move assignment is not. Making them
// consistent requires changing the header declaration too.
StreamSocketPosix::StreamSocketPosix(StreamSocketPosix&& other) noexcept =
    default;
StreamSocketPosix& StreamSocketPosix::operator=(StreamSocketPosix&& other) =
    default;
55 
~StreamSocketPosix()56 StreamSocketPosix::~StreamSocketPosix() {
57   if (handle_.fd != kUnsetHandleFd) {
58     OSP_DCHECK(state_ != TcpSocketState::kClosed);
59     Close();
60   }
61 }
62 
// Returns a weak reference to this socket, produced by the member
// |weak_factory_|.
WeakPtr<StreamSocketPosix> StreamSocketPosix::GetWeakPtr() const {
  return weak_factory_.GetWeakPtr();
}
66 
Accept()67 ErrorOr<std::unique_ptr<StreamSocket>> StreamSocketPosix::Accept() {
68   if (!EnsureInitializedAndOpen()) {
69     return ReportSocketClosedError();
70   }
71 
72   if (!is_bound_ || state_ != TcpSocketState::kListening) {
73     return CloseOnError(Error::Code::kSocketInvalidState);
74   }
75 
76   // Check if any connection is pending, and return a special error code if not.
77   if (!IsConnectionPending(handle_.fd)) {
78     return Error::Code::kAgain;
79   }
80 
81   // We copy our address to new_remote_address since it should be in the same
82   // family. The accept call will overwrite it.
83   SocketAddressPosix new_remote_address = local_address_.value();
84   socklen_t remote_address_size = new_remote_address.size();
85   const int new_file_descriptor =
86       accept(handle_.fd, new_remote_address.address(), &remote_address_size);
87   if (new_file_descriptor == kUnsetHandleFd) {
88     return CloseOnError(
89         Error(Error::Code::kSocketAcceptFailure, strerror(errno)));
90   }
91   new_remote_address.RecomputeEndpoint();
92 
93   return ErrorOr<std::unique_ptr<StreamSocket>>(
94       std::make_unique<StreamSocketPosix>(local_address_.value(),
95                                           new_remote_address.endpoint(),
96                                           new_file_descriptor));
97 }
98 
Bind()99 Error StreamSocketPosix::Bind() {
100   if (!local_address_.has_value()) {
101     return CloseOnError(Error::Code::kSocketInvalidState);
102   }
103 
104   if (!EnsureInitializedAndOpen()) {
105     return ReportSocketClosedError();
106   }
107 
108   if (is_bound_) {
109     return CloseOnError(Error::Code::kSocketInvalidState);
110   }
111 
112   if (bind(handle_.fd, local_address_.value().address(),
113            local_address_.value().size()) != 0) {
114     return CloseOnError(
115         Error(Error::Code::kSocketBindFailure, strerror(errno)));
116   }
117 
118   is_bound_ = true;
119   return Error::None();
120 }
121 
Close()122 Error StreamSocketPosix::Close() {
123   if (handle_.fd == kUnsetHandleFd) {
124     return ReportSocketClosedError();
125   }
126 
127   OSP_DCHECK(state_ != TcpSocketState::kClosed);
128   state_ = TcpSocketState::kClosed;
129 
130   const int file_descriptor_to_close = handle_.fd;
131   handle_.fd = kUnsetHandleFd;
132   if (close(file_descriptor_to_close) != 0) {
133     return last_error_code_ = Error::Code::kSocketInvalidState;
134   }
135 
136   return Error::None();
137 }
138 
Connect(const IPEndpoint & remote_endpoint)139 Error StreamSocketPosix::Connect(const IPEndpoint& remote_endpoint) {
140   if (!EnsureInitializedAndOpen()) {
141     return ReportSocketClosedError();
142   }
143 
144   SocketAddressPosix address(remote_endpoint);
145   int ret = connect(handle_.fd, address.address(), address.size());
146   if (ret != 0 && errno != EINPROGRESS) {
147     return CloseOnError(
148         Error(Error::Code::kSocketConnectFailure, strerror(errno)));
149   }
150 
151   if (!is_bound_) {
152     if (local_address_.has_value()) {
153       return CloseOnError(Error::Code::kSocketInvalidState);
154     }
155 
156     struct sockaddr_in6 address;
157     socklen_t size = sizeof(address);
158     if (getsockname(handle_.fd, reinterpret_cast<struct sockaddr*>(&address),
159                     &size) != 0) {
160       return CloseOnError(Error::Code::kSocketConnectFailure);
161     }
162 
163     local_address_.emplace(reinterpret_cast<struct sockaddr&>(address));
164     is_bound_ = true;
165   }
166 
167   remote_address_ = remote_endpoint;
168   state_ = TcpSocketState::kConnected;
169   return Error::None();
170 }
171 
Listen()172 Error StreamSocketPosix::Listen() {
173   return Listen(kDefaultMaxBacklogSize);
174 }
175 
Listen(int max_backlog_size)176 Error StreamSocketPosix::Listen(int max_backlog_size) {
177   OSP_DCHECK(state_ == TcpSocketState::kNotConnected);
178   if (!EnsureInitializedAndOpen()) {
179     return ReportSocketClosedError();
180   }
181 
182   if (listen(handle_.fd, max_backlog_size) != 0) {
183     return CloseOnError(
184         Error(Error::Code::kSocketListenFailure, strerror(errno)));
185   }
186 
187   state_ = TcpSocketState::kListening;
188   return Error::None();
189 }
190 
remote_address() const191 absl::optional<IPEndpoint> StreamSocketPosix::remote_address() const {
192   if ((state_ != TcpSocketState::kConnected) || !remote_address_) {
193     return absl::nullopt;
194   }
195   return remote_address_.value();
196 }
197 
local_address() const198 absl::optional<IPEndpoint> StreamSocketPosix::local_address() const {
199   if (!local_address_.has_value()) {
200     return absl::nullopt;
201   }
202   return local_address_.value().endpoint();
203 }
204 
state() const205 TcpSocketState StreamSocketPosix::state() const {
206   return state_;
207 }
208 
version() const209 IPAddress::Version StreamSocketPosix::version() const {
210   return version_;
211 }
212 
EnsureInitializedAndOpen()213 bool StreamSocketPosix::EnsureInitializedAndOpen() {
214   if (state_ == TcpSocketState::kNotConnected &&
215       (handle_.fd == kUnsetHandleFd) &&
216       (last_error_code_ == Error::Code::kNone)) {
217     return Initialize() == Error::None();
218   }
219 
220   return handle_.fd != kUnsetHandleFd;
221 }
222 
Initialize()223 Error StreamSocketPosix::Initialize() {
224   if (handle_.fd == kUnsetHandleFd) {
225     int domain;
226     switch (version_) {
227       case IPAddress::Version::kV4:
228         domain = AF_INET;
229         break;
230       case IPAddress::Version::kV6:
231         domain = AF_INET6;
232         break;
233     }
234 
235     handle_.fd = socket(domain, SOCK_STREAM, 0);
236     if (handle_.fd == kUnsetHandleFd) {
237       return last_error_code_ = Error::Code::kSocketInvalidState;
238     }
239   }
240 
241   const int current_flags = fcntl(handle_.fd, F_GETFL, 0);
242   if (fcntl(handle_.fd, F_SETFL, current_flags | O_NONBLOCK) == -1) {
243     return CloseOnError(Error::Code::kSocketInvalidState);
244   }
245 
246   OSP_DCHECK_EQ(last_error_code_, Error::Code::kNone);
247   return Error::None();
248 }
249 
CloseOnError(Error error)250 Error StreamSocketPosix::CloseOnError(Error error) {
251   last_error_code_ = error.code();
252   Close();
253   return error;
254 }
255 
256 // If is_open is false, the socket has either not been initialized
257 // or has been closed, either on purpose or due to error.
ReportSocketClosedError()258 Error StreamSocketPosix::ReportSocketClosedError() {
259   return last_error_code_ = Error::Code::kSocketClosedFailure;
260 }
261 }  // namespace openscreen
262