1 /*
2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "webrtc/base/natsocketfactory.h"
12 
13 #include "webrtc/base/arraysize.h"
14 #include "webrtc/base/logging.h"
15 #include "webrtc/base/natserver.h"
16 #include "webrtc/base/virtualsocketserver.h"
17 
18 namespace rtc {
19 
20 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN
21 // format that the natserver uses.
22 // Returns 0 if an invalid address is passed.
PackAddressForNAT(char * buf,size_t buf_size,const SocketAddress & remote_addr)23 size_t PackAddressForNAT(char* buf, size_t buf_size,
24                          const SocketAddress& remote_addr) {
25   const IPAddress& ip = remote_addr.ipaddr();
26   int family = ip.family();
27   buf[0] = 0;
28   buf[1] = family;
29   // Writes the port.
30   *(reinterpret_cast<uint16_t*>(&buf[2])) = HostToNetwork16(remote_addr.port());
31   if (family == AF_INET) {
32     ASSERT(buf_size >= kNATEncodedIPv4AddressSize);
33     in_addr v4addr = ip.ipv4_address();
34     memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
35     return kNATEncodedIPv4AddressSize;
36   } else if (family == AF_INET6) {
37     ASSERT(buf_size >= kNATEncodedIPv6AddressSize);
38     in6_addr v6addr = ip.ipv6_address();
39     memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
40     return kNATEncodedIPv6AddressSize;
41   }
42   return 0U;
43 }
44 
45 // Decodes the remote address from a packet that has been encoded with the nat's
46 // quasi-STUN format. Returns the length of the address (i.e., the offset into
47 // data where the original packet starts).
UnpackAddressFromNAT(const char * buf,size_t buf_size,SocketAddress * remote_addr)48 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
49                             SocketAddress* remote_addr) {
50   ASSERT(buf_size >= 8);
51   ASSERT(buf[0] == 0);
52   int family = buf[1];
53   uint16_t port =
54       NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2])));
55   if (family == AF_INET) {
56     const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
57     *remote_addr = SocketAddress(IPAddress(*v4addr), port);
58     return kNATEncodedIPv4AddressSize;
59   } else if (family == AF_INET6) {
60     ASSERT(buf_size >= 20);
61     const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
62     *remote_addr = SocketAddress(IPAddress(*v6addr), port);
63     return kNATEncodedIPv6AddressSize;
64   }
65   return 0U;
66 }
67 
68 
69 // NATSocket
70 class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
71  public:
NATSocket(NATInternalSocketFactory * sf,int family,int type)72   explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
73       : sf_(sf), family_(family), type_(type), connected_(false),
74         socket_(NULL), buf_(NULL), size_(0) {
75   }
76 
~NATSocket()77   ~NATSocket() override {
78     delete socket_;
79     delete[] buf_;
80   }
81 
GetLocalAddress() const82   SocketAddress GetLocalAddress() const override {
83     return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
84   }
85 
GetRemoteAddress() const86   SocketAddress GetRemoteAddress() const override {
87     return remote_addr_;  // will be NIL if not connected
88   }
89 
Bind(const SocketAddress & addr)90   int Bind(const SocketAddress& addr) override {
91     if (socket_) {  // already bound, bubble up error
92       return -1;
93     }
94 
95     int result;
96     socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
97     result = (socket_) ? socket_->Bind(addr) : -1;
98     if (result >= 0) {
99       socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
100       socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
101       socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
102       socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
103     } else {
104       server_addr_.Clear();
105       delete socket_;
106       socket_ = NULL;
107     }
108 
109     return result;
110   }
111 
Connect(const SocketAddress & addr)112   int Connect(const SocketAddress& addr) override {
113     if (!socket_) {  // socket must be bound, for now
114       return -1;
115     }
116 
117     int result = 0;
118     if (type_ == SOCK_STREAM) {
119       result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
120     } else {
121       connected_ = true;
122     }
123 
124     if (result >= 0) {
125       remote_addr_ = addr;
126     }
127 
128     return result;
129   }
130 
Send(const void * data,size_t size)131   int Send(const void* data, size_t size) override {
132     ASSERT(connected_);
133     return SendTo(data, size, remote_addr_);
134   }
135 
SendTo(const void * data,size_t size,const SocketAddress & addr)136   int SendTo(const void* data,
137              size_t size,
138              const SocketAddress& addr) override {
139     ASSERT(!connected_ || addr == remote_addr_);
140     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
141       return socket_->SendTo(data, size, addr);
142     }
143     // This array will be too large for IPv4 packets, but only by 12 bytes.
144     scoped_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
145     size_t addrlength = PackAddressForNAT(buf.get(),
146                                           size + kNATEncodedIPv6AddressSize,
147                                           addr);
148     size_t encoded_size = size + addrlength;
149     memcpy(buf.get() + addrlength, data, size);
150     int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
151     if (result >= 0) {
152       ASSERT(result == static_cast<int>(encoded_size));
153       result = result - static_cast<int>(addrlength);
154     }
155     return result;
156   }
157 
Recv(void * data,size_t size)158   int Recv(void* data, size_t size) override {
159     SocketAddress addr;
160     return RecvFrom(data, size, &addr);
161   }
162 
RecvFrom(void * data,size_t size,SocketAddress * out_addr)163   int RecvFrom(void* data, size_t size, SocketAddress* out_addr) override {
164     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
165       return socket_->RecvFrom(data, size, out_addr);
166     }
167     // Make sure we have enough room to read the requested amount plus the
168     // largest possible header address.
169     SocketAddress remote_addr;
170     Grow(size + kNATEncodedIPv6AddressSize);
171 
172     // Read the packet from the socket.
173     int result = socket_->RecvFrom(buf_, size_, &remote_addr);
174     if (result >= 0) {
175       ASSERT(remote_addr == server_addr_);
176 
177       // TODO: we need better framing so we know how many bytes we can
178       // return before we need to read the next address. For UDP, this will be
179       // fine as long as the reader always reads everything in the packet.
180       ASSERT((size_t)result < size_);
181 
182       // Decode the wire packet into the actual results.
183       SocketAddress real_remote_addr;
184       size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr);
185       memcpy(data, buf_ + addrlength, result - addrlength);
186 
187       // Make sure this packet should be delivered before returning it.
188       if (!connected_ || (real_remote_addr == remote_addr_)) {
189         if (out_addr)
190           *out_addr = real_remote_addr;
191         result = result - static_cast<int>(addrlength);
192       } else {
193         LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
194                       << real_remote_addr.ToString();
195         result = 0;  // Tell the caller we didn't read anything
196       }
197     }
198 
199     return result;
200   }
201 
Close()202   int Close() override {
203     int result = 0;
204     if (socket_) {
205       result = socket_->Close();
206       if (result >= 0) {
207         connected_ = false;
208         remote_addr_ = SocketAddress();
209         delete socket_;
210         socket_ = NULL;
211       }
212     }
213     return result;
214   }
215 
Listen(int backlog)216   int Listen(int backlog) override { return socket_->Listen(backlog); }
Accept(SocketAddress * paddr)217   AsyncSocket* Accept(SocketAddress* paddr) override {
218     return socket_->Accept(paddr);
219   }
GetError() const220   int GetError() const override { return socket_->GetError(); }
SetError(int error)221   void SetError(int error) override { socket_->SetError(error); }
GetState() const222   ConnState GetState() const override {
223     return connected_ ? CS_CONNECTED : CS_CLOSED;
224   }
EstimateMTU(uint16_t * mtu)225   int EstimateMTU(uint16_t* mtu) override { return socket_->EstimateMTU(mtu); }
GetOption(Option opt,int * value)226   int GetOption(Option opt, int* value) override {
227     return socket_->GetOption(opt, value);
228   }
SetOption(Option opt,int value)229   int SetOption(Option opt, int value) override {
230     return socket_->SetOption(opt, value);
231   }
232 
OnConnectEvent(AsyncSocket * socket)233   void OnConnectEvent(AsyncSocket* socket) {
234     // If we're NATed, we need to send a message with the real addr to use.
235     ASSERT(socket == socket_);
236     if (server_addr_.IsNil()) {
237       connected_ = true;
238       SignalConnectEvent(this);
239     } else {
240       SendConnectRequest();
241     }
242   }
OnReadEvent(AsyncSocket * socket)243   void OnReadEvent(AsyncSocket* socket) {
244     // If we're NATed, we need to process the connect reply.
245     ASSERT(socket == socket_);
246     if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
247       HandleConnectReply();
248     } else {
249       SignalReadEvent(this);
250     }
251   }
OnWriteEvent(AsyncSocket * socket)252   void OnWriteEvent(AsyncSocket* socket) {
253     ASSERT(socket == socket_);
254     SignalWriteEvent(this);
255   }
OnCloseEvent(AsyncSocket * socket,int error)256   void OnCloseEvent(AsyncSocket* socket, int error) {
257     ASSERT(socket == socket_);
258     SignalCloseEvent(this, error);
259   }
260 
261  private:
262   // Makes sure the buffer is at least the given size.
Grow(size_t new_size)263   void Grow(size_t new_size) {
264     if (size_ < new_size) {
265       delete[] buf_;
266       size_ = new_size;
267       buf_ = new char[size_];
268     }
269   }
270 
271   // Sends the destination address to the server to tell it to connect.
SendConnectRequest()272   void SendConnectRequest() {
273     char buf[kNATEncodedIPv6AddressSize];
274     size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_);
275     socket_->Send(buf, length);
276   }
277 
278   // Handles the byte sent back from the server and fires the appropriate event.
HandleConnectReply()279   void HandleConnectReply() {
280     char code;
281     socket_->Recv(&code, sizeof(code));
282     if (code == 0) {
283       connected_ = true;
284       SignalConnectEvent(this);
285     } else {
286       Close();
287       SignalCloseEvent(this, code);
288     }
289   }
290 
291   NATInternalSocketFactory* sf_;
292   int family_;
293   int type_;
294   bool connected_;
295   SocketAddress remote_addr_;
296   SocketAddress server_addr_;  // address of the NAT server
297   AsyncSocket* socket_;
298   char* buf_;
299   size_t size_;
300 };
301 
302 // NATSocketFactory
NATSocketFactory(SocketFactory * factory,const SocketAddress & nat_udp_addr,const SocketAddress & nat_tcp_addr)303 NATSocketFactory::NATSocketFactory(SocketFactory* factory,
304                                    const SocketAddress& nat_udp_addr,
305                                    const SocketAddress& nat_tcp_addr)
306     : factory_(factory), nat_udp_addr_(nat_udp_addr),
307       nat_tcp_addr_(nat_tcp_addr) {
308 }
309 
CreateSocket(int type)310 Socket* NATSocketFactory::CreateSocket(int type) {
311   return CreateSocket(AF_INET, type);
312 }
313 
CreateSocket(int family,int type)314 Socket* NATSocketFactory::CreateSocket(int family, int type) {
315   return new NATSocket(this, family, type);
316 }
317 
CreateAsyncSocket(int type)318 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
319   return CreateAsyncSocket(AF_INET, type);
320 }
321 
CreateAsyncSocket(int family,int type)322 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
323   return new NATSocket(this, family, type);
324 }
325 
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)326 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
327     const SocketAddress& local_addr, SocketAddress* nat_addr) {
328   if (type == SOCK_STREAM) {
329     *nat_addr = nat_tcp_addr_;
330   } else {
331     *nat_addr = nat_udp_addr_;
332   }
333   return factory_->CreateAsyncSocket(family, type);
334 }
335 
336 // NATSocketServer
NATSocketServer(SocketServer * server)337 NATSocketServer::NATSocketServer(SocketServer* server)
338     : server_(server), msg_queue_(NULL) {
339 }
340 
GetTranslator(const SocketAddress & ext_ip)341 NATSocketServer::Translator* NATSocketServer::GetTranslator(
342     const SocketAddress& ext_ip) {
343   return nats_.Get(ext_ip);
344 }
345 
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)346 NATSocketServer::Translator* NATSocketServer::AddTranslator(
347     const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
348   // Fail if a translator already exists with this extternal address.
349   if (nats_.Get(ext_ip))
350     return NULL;
351 
352   return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
353 }
354 
RemoveTranslator(const SocketAddress & ext_ip)355 void NATSocketServer::RemoveTranslator(
356     const SocketAddress& ext_ip) {
357   nats_.Remove(ext_ip);
358 }
359 
CreateSocket(int type)360 Socket* NATSocketServer::CreateSocket(int type) {
361   return CreateSocket(AF_INET, type);
362 }
363 
CreateSocket(int family,int type)364 Socket* NATSocketServer::CreateSocket(int family, int type) {
365   return new NATSocket(this, family, type);
366 }
367 
CreateAsyncSocket(int type)368 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
369   return CreateAsyncSocket(AF_INET, type);
370 }
371 
CreateAsyncSocket(int family,int type)372 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
373   return new NATSocket(this, family, type);
374 }
375 
SetMessageQueue(MessageQueue * queue)376 void NATSocketServer::SetMessageQueue(MessageQueue* queue) {
377   msg_queue_ = queue;
378   server_->SetMessageQueue(queue);
379 }
380 
Wait(int cms,bool process_io)381 bool NATSocketServer::Wait(int cms, bool process_io) {
382   return server_->Wait(cms, process_io);
383 }
384 
WakeUp()385 void NATSocketServer::WakeUp() {
386   server_->WakeUp();
387 }
388 
CreateInternalSocket(int family,int type,const SocketAddress & local_addr,SocketAddress * nat_addr)389 AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type,
390     const SocketAddress& local_addr, SocketAddress* nat_addr) {
391   AsyncSocket* socket = NULL;
392   Translator* nat = nats_.FindClient(local_addr);
393   if (nat) {
394     socket = nat->internal_factory()->CreateAsyncSocket(family, type);
395     *nat_addr = (type == SOCK_STREAM) ?
396         nat->internal_tcp_address() : nat->internal_udp_address();
397   } else {
398     socket = server_->CreateAsyncSocket(family, type);
399   }
400   return socket;
401 }
402 
403 // NATSocketServer::Translator
Translator(NATSocketServer * server,NATType type,const SocketAddress & int_ip,SocketFactory * ext_factory,const SocketAddress & ext_ip)404 NATSocketServer::Translator::Translator(
405     NATSocketServer* server, NATType type, const SocketAddress& int_ip,
406     SocketFactory* ext_factory, const SocketAddress& ext_ip)
407     : server_(server) {
408   // Create a new private network, and a NATServer running on the private
409   // network that bridges to the external network. Also tell the private
410   // network to use the same message queue as us.
411   VirtualSocketServer* internal_server = new VirtualSocketServer(server_);
412   internal_server->SetMessageQueue(server_->queue());
413   internal_factory_.reset(internal_server);
414   nat_server_.reset(new NATServer(type, internal_server, int_ip, int_ip,
415                                   ext_factory, ext_ip));
416 }
417 
418 NATSocketServer::Translator::~Translator() = default;
419 
GetTranslator(const SocketAddress & ext_ip)420 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
421     const SocketAddress& ext_ip) {
422   return nats_.Get(ext_ip);
423 }
424 
AddTranslator(const SocketAddress & ext_ip,const SocketAddress & int_ip,NATType type)425 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
426     const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
427   // Fail if a translator already exists with this extternal address.
428   if (nats_.Get(ext_ip))
429     return NULL;
430 
431   AddClient(ext_ip);
432   return nats_.Add(ext_ip,
433                    new Translator(server_, type, int_ip, server_, ext_ip));
434 }
RemoveTranslator(const SocketAddress & ext_ip)435 void NATSocketServer::Translator::RemoveTranslator(
436     const SocketAddress& ext_ip) {
437   nats_.Remove(ext_ip);
438   RemoveClient(ext_ip);
439 }
440 
AddClient(const SocketAddress & int_ip)441 bool NATSocketServer::Translator::AddClient(
442     const SocketAddress& int_ip) {
443   // Fail if a client already exists with this internal address.
444   if (clients_.find(int_ip) != clients_.end())
445     return false;
446 
447   clients_.insert(int_ip);
448   return true;
449 }
450 
RemoveClient(const SocketAddress & int_ip)451 void NATSocketServer::Translator::RemoveClient(
452     const SocketAddress& int_ip) {
453   std::set<SocketAddress>::iterator it = clients_.find(int_ip);
454   if (it != clients_.end()) {
455     clients_.erase(it);
456   }
457 }
458 
FindClient(const SocketAddress & int_ip)459 NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
460     const SocketAddress& int_ip) {
461   // See if we have the requested IP, or any of our children do.
462   return (clients_.find(int_ip) != clients_.end()) ?
463       this : nats_.FindClient(int_ip);
464 }
465 
466 // NATSocketServer::TranslatorMap
~TranslatorMap()467 NATSocketServer::TranslatorMap::~TranslatorMap() {
468   for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
469     delete it->second;
470   }
471 }
472 
Get(const SocketAddress & ext_ip)473 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
474     const SocketAddress& ext_ip) {
475   TranslatorMap::iterator it = find(ext_ip);
476   return (it != end()) ? it->second : NULL;
477 }
478 
Add(const SocketAddress & ext_ip,Translator * nat)479 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
480     const SocketAddress& ext_ip, Translator* nat) {
481   (*this)[ext_ip] = nat;
482   return nat;
483 }
484 
Remove(const SocketAddress & ext_ip)485 void NATSocketServer::TranslatorMap::Remove(
486     const SocketAddress& ext_ip) {
487   TranslatorMap::iterator it = find(ext_ip);
488   if (it != end()) {
489     delete it->second;
490     erase(it);
491   }
492 }
493 
FindClient(const SocketAddress & int_ip)494 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
495     const SocketAddress& int_ip) {
496   Translator* nat = NULL;
497   for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
498     nat = it->second->FindClient(int_ip);
499   }
500   return nat;
501 }
502 
503 }  // namespace rtc
504