1 /*
2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
#include "rtc_base/async_tcp_socket.h"

#include <stdint.h>
#include <string.h>

#include <algorithm>
#include <limits>
#include <memory>

#include "api/array_view.h"
#include "rtc_base/byte_order.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/network/sent_packet.h"
#include "rtc_base/third_party/sigslot/sigslot.h"
#include "rtc_base/time_utils.h"  // for TimeMillis

#if defined(WEBRTC_POSIX)
#include <errno.h>
#endif  // WEBRTC_POSIX
30 
31 namespace rtc {
32 
// 64 KiB. Combined with kPacketLenSize to size the per-socket buffers (see
// kBufSize below).
static constexpr size_t kMaxPacketSize = 64 * 1024;

// Every framed packet is preceded by its length, stored as a PacketLength.
using PacketLength = uint16_t;
static constexpr size_t kPacketLenSize = sizeof(PacketLength);

// Buffer capacity for a kMaxPacketSize payload plus its length prefix.
static constexpr size_t kBufSize = kMaxPacketSize + kPacketLenSize;

// The input buffer will be resized so that at least kMinimumRecvSize bytes can
// be received (but it will not grow above the maximum size passed to the
// constructor).
static constexpr size_t kMinimumRecvSize = 128;

// Backlog argument handed to Listen() for listening sockets.
static constexpr int kListenBacklog = 5;
46 
47 // Binds and connects |socket|
ConnectSocket(rtc::AsyncSocket * socket,const rtc::SocketAddress & bind_address,const rtc::SocketAddress & remote_address)48 AsyncSocket* AsyncTCPSocketBase::ConnectSocket(
49     rtc::AsyncSocket* socket,
50     const rtc::SocketAddress& bind_address,
51     const rtc::SocketAddress& remote_address) {
52   std::unique_ptr<rtc::AsyncSocket> owned_socket(socket);
53   if (socket->Bind(bind_address) < 0) {
54     RTC_LOG(LS_ERROR) << "Bind() failed with error " << socket->GetError();
55     return nullptr;
56   }
57   if (socket->Connect(remote_address) < 0) {
58     RTC_LOG(LS_ERROR) << "Connect() failed with error " << socket->GetError();
59     return nullptr;
60   }
61   return owned_socket.release();
62 }
63 
AsyncTCPSocketBase(AsyncSocket * socket,bool listen,size_t max_packet_size)64 AsyncTCPSocketBase::AsyncTCPSocketBase(AsyncSocket* socket,
65                                        bool listen,
66                                        size_t max_packet_size)
67     : socket_(socket),
68       listen_(listen),
69       max_insize_(max_packet_size),
70       max_outsize_(max_packet_size) {
71   if (!listen_) {
72     // Listening sockets don't send/receive data, so they don't need buffers.
73     inbuf_.EnsureCapacity(kMinimumRecvSize);
74   }
75 
76   RTC_DCHECK(socket_.get() != nullptr);
77   socket_->SignalConnectEvent.connect(this,
78                                       &AsyncTCPSocketBase::OnConnectEvent);
79   socket_->SignalReadEvent.connect(this, &AsyncTCPSocketBase::OnReadEvent);
80   socket_->SignalWriteEvent.connect(this, &AsyncTCPSocketBase::OnWriteEvent);
81   socket_->SignalCloseEvent.connect(this, &AsyncTCPSocketBase::OnCloseEvent);
82 
83   if (listen_) {
84     if (socket_->Listen(kListenBacklog) < 0) {
85       RTC_LOG(LS_ERROR) << "Listen() failed with error " << socket_->GetError();
86     }
87   }
88 }
89 
~AsyncTCPSocketBase()90 AsyncTCPSocketBase::~AsyncTCPSocketBase() {}
91 
GetLocalAddress() const92 SocketAddress AsyncTCPSocketBase::GetLocalAddress() const {
93   return socket_->GetLocalAddress();
94 }
95 
GetRemoteAddress() const96 SocketAddress AsyncTCPSocketBase::GetRemoteAddress() const {
97   return socket_->GetRemoteAddress();
98 }
99 
Close()100 int AsyncTCPSocketBase::Close() {
101   return socket_->Close();
102 }
103 
GetState() const104 AsyncTCPSocket::State AsyncTCPSocketBase::GetState() const {
105   switch (socket_->GetState()) {
106     case Socket::CS_CLOSED:
107       return STATE_CLOSED;
108     case Socket::CS_CONNECTING:
109       if (listen_) {
110         return STATE_BOUND;
111       } else {
112         return STATE_CONNECTING;
113       }
114     case Socket::CS_CONNECTED:
115       return STATE_CONNECTED;
116     default:
117       RTC_NOTREACHED();
118       return STATE_CLOSED;
119   }
120 }
121 
GetOption(Socket::Option opt,int * value)122 int AsyncTCPSocketBase::GetOption(Socket::Option opt, int* value) {
123   return socket_->GetOption(opt, value);
124 }
125 
SetOption(Socket::Option opt,int value)126 int AsyncTCPSocketBase::SetOption(Socket::Option opt, int value) {
127   return socket_->SetOption(opt, value);
128 }
129 
GetError() const130 int AsyncTCPSocketBase::GetError() const {
131   return socket_->GetError();
132 }
133 
SetError(int error)134 void AsyncTCPSocketBase::SetError(int error) {
135   return socket_->SetError(error);
136 }
137 
SendTo(const void * pv,size_t cb,const SocketAddress & addr,const rtc::PacketOptions & options)138 int AsyncTCPSocketBase::SendTo(const void* pv,
139                                size_t cb,
140                                const SocketAddress& addr,
141                                const rtc::PacketOptions& options) {
142   const SocketAddress& remote_address = GetRemoteAddress();
143   if (addr == remote_address)
144     return Send(pv, cb, options);
145   // Remote address may be empty if there is a sudden network change.
146   RTC_DCHECK(remote_address.IsNil());
147   socket_->SetError(ENOTCONN);
148   return -1;
149 }
150 
FlushOutBuffer()151 int AsyncTCPSocketBase::FlushOutBuffer() {
152   RTC_DCHECK(!listen_);
153   RTC_DCHECK_GT(outbuf_.size(), 0);
154   rtc::ArrayView<uint8_t> view = outbuf_;
155   int res;
156   while (view.size() > 0) {
157     res = socket_->Send(view.data(), view.size());
158     if (res <= 0) {
159       break;
160     }
161     if (static_cast<size_t>(res) > view.size()) {
162       RTC_NOTREACHED();
163       res = -1;
164       break;
165     }
166     view = view.subview(res);
167   }
168   if (res > 0) {
169     // The output buffer may have been written out over multiple partial Send(),
170     // so reconstruct the total written length.
171     RTC_DCHECK_EQ(view.size(), 0);
172     res = outbuf_.size();
173     outbuf_.Clear();
174   } else {
175     // There was an error when calling Send(), so there will still be data left
176     // to send at a later point.
177     RTC_DCHECK_GT(view.size(), 0);
178     // In the special case of EWOULDBLOCK, signal that we had a partial write.
179     if (socket_->GetError() == EWOULDBLOCK) {
180       res = outbuf_.size() - view.size();
181     }
182     if (view.size() < outbuf_.size()) {
183       memmove(outbuf_.data(), view.data(), view.size());
184       outbuf_.SetSize(view.size());
185     }
186   }
187   return res;
188 }
189 
AppendToOutBuffer(const void * pv,size_t cb)190 void AsyncTCPSocketBase::AppendToOutBuffer(const void* pv, size_t cb) {
191   RTC_DCHECK(outbuf_.size() + cb <= max_outsize_);
192   RTC_DCHECK(!listen_);
193   outbuf_.AppendData(static_cast<const uint8_t*>(pv), cb);
194 }
195 
OnConnectEvent(AsyncSocket * socket)196 void AsyncTCPSocketBase::OnConnectEvent(AsyncSocket* socket) {
197   SignalConnect(this);
198 }
199 
OnReadEvent(AsyncSocket * socket)200 void AsyncTCPSocketBase::OnReadEvent(AsyncSocket* socket) {
201   RTC_DCHECK(socket_.get() == socket);
202 
203   if (listen_) {
204     rtc::SocketAddress address;
205     rtc::AsyncSocket* new_socket = socket->Accept(&address);
206     if (!new_socket) {
207       // TODO(stefan): Do something better like forwarding the error
208       // to the user.
209       RTC_LOG(LS_ERROR) << "TCP accept failed with error "
210                         << socket_->GetError();
211       return;
212     }
213 
214     HandleIncomingConnection(new_socket);
215 
216     // Prime a read event in case data is waiting.
217     new_socket->SignalReadEvent(new_socket);
218   } else {
219     size_t total_recv = 0;
220     while (true) {
221       size_t free_size = inbuf_.capacity() - inbuf_.size();
222       if (free_size < kMinimumRecvSize && inbuf_.capacity() < max_insize_) {
223         inbuf_.EnsureCapacity(std::min(max_insize_, inbuf_.capacity() * 2));
224         free_size = inbuf_.capacity() - inbuf_.size();
225       }
226 
227       int len =
228           socket_->Recv(inbuf_.data() + inbuf_.size(), free_size, nullptr);
229       if (len < 0) {
230         // TODO(stefan): Do something better like forwarding the error to the
231         // user.
232         if (!socket_->IsBlocking()) {
233           RTC_LOG(LS_ERROR) << "Recv() returned error: " << socket_->GetError();
234         }
235         break;
236       }
237 
238       total_recv += len;
239       inbuf_.SetSize(inbuf_.size() + len);
240       if (!len || static_cast<size_t>(len) < free_size) {
241         break;
242       }
243     }
244 
245     if (!total_recv) {
246       return;
247     }
248 
249     size_t size = inbuf_.size();
250     ProcessInput(inbuf_.data<char>(), &size);
251 
252     if (size > inbuf_.size()) {
253       RTC_LOG(LS_ERROR) << "input buffer overflow";
254       RTC_NOTREACHED();
255       inbuf_.Clear();
256     } else {
257       inbuf_.SetSize(size);
258     }
259   }
260 }
261 
OnWriteEvent(AsyncSocket * socket)262 void AsyncTCPSocketBase::OnWriteEvent(AsyncSocket* socket) {
263   RTC_DCHECK(socket_.get() == socket);
264 
265   if (outbuf_.size() > 0) {
266     FlushOutBuffer();
267   }
268 
269   if (outbuf_.size() == 0) {
270     SignalReadyToSend(this);
271   }
272 }
273 
OnCloseEvent(AsyncSocket * socket,int error)274 void AsyncTCPSocketBase::OnCloseEvent(AsyncSocket* socket, int error) {
275   SignalClose(this, error);
276 }
277 
278 // AsyncTCPSocket
279 // Binds and connects |socket| and creates AsyncTCPSocket for
280 // it. Takes ownership of |socket|. Returns null if bind() or
281 // connect() fail (|socket| is destroyed in that case).
Create(AsyncSocket * socket,const SocketAddress & bind_address,const SocketAddress & remote_address)282 AsyncTCPSocket* AsyncTCPSocket::Create(AsyncSocket* socket,
283                                        const SocketAddress& bind_address,
284                                        const SocketAddress& remote_address) {
285   return new AsyncTCPSocket(
286       AsyncTCPSocketBase::ConnectSocket(socket, bind_address, remote_address),
287       false);
288 }
289 
AsyncTCPSocket(AsyncSocket * socket,bool listen)290 AsyncTCPSocket::AsyncTCPSocket(AsyncSocket* socket, bool listen)
291     : AsyncTCPSocketBase(socket, listen, kBufSize) {}
292 
Send(const void * pv,size_t cb,const rtc::PacketOptions & options)293 int AsyncTCPSocket::Send(const void* pv,
294                          size_t cb,
295                          const rtc::PacketOptions& options) {
296   if (cb > kBufSize) {
297     SetError(EMSGSIZE);
298     return -1;
299   }
300 
301   // If we are blocking on send, then silently drop this packet
302   if (!IsOutBufferEmpty())
303     return static_cast<int>(cb);
304 
305   PacketLength pkt_len = HostToNetwork16(static_cast<PacketLength>(cb));
306   AppendToOutBuffer(&pkt_len, kPacketLenSize);
307   AppendToOutBuffer(pv, cb);
308 
309   int res = FlushOutBuffer();
310   if (res <= 0) {
311     // drop packet if we made no progress
312     ClearOutBuffer();
313     return res;
314   }
315 
316   rtc::SentPacket sent_packet(options.packet_id, rtc::TimeMillis(),
317                               options.info_signaled_after_sent);
318   CopySocketInformationToPacketInfo(cb, *this, false, &sent_packet.info);
319   SignalSentPacket(this, sent_packet);
320 
321   // We claim to have sent the whole thing, even if we only sent partial
322   return static_cast<int>(cb);
323 }
324 
ProcessInput(char * data,size_t * len)325 void AsyncTCPSocket::ProcessInput(char* data, size_t* len) {
326   SocketAddress remote_addr(GetRemoteAddress());
327 
328   while (true) {
329     if (*len < kPacketLenSize)
330       return;
331 
332     PacketLength pkt_len = rtc::GetBE16(data);
333     if (*len < kPacketLenSize + pkt_len)
334       return;
335 
336     SignalReadPacket(this, data + kPacketLenSize, pkt_len, remote_addr,
337                      TimeMicros());
338 
339     *len -= kPacketLenSize + pkt_len;
340     if (*len > 0) {
341       memmove(data, data + kPacketLenSize + pkt_len, *len);
342     }
343   }
344 }
345 
HandleIncomingConnection(AsyncSocket * socket)346 void AsyncTCPSocket::HandleIncomingConnection(AsyncSocket* socket) {
347   SignalNewConnection(this, new AsyncTCPSocket(socket, false));
348 }
349 
350 }  // namespace rtc
351