1 /*
2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "webrtc/base/natsocketfactory.h"
12 #include "webrtc/base/natserver.h"
13 #include "webrtc/base/logging.h"
14 #include "webrtc/base/socketadapters.h"
15 
16 namespace rtc {
17 
RouteCmp(NAT * nat)18 RouteCmp::RouteCmp(NAT* nat) : symmetric(nat->IsSymmetric()) {
19 }
20 
operator ()(const SocketAddressPair & r) const21 size_t RouteCmp::operator()(const SocketAddressPair& r) const {
22   size_t h = r.source().Hash();
23   if (symmetric)
24     h ^= r.destination().Hash();
25   return h;
26 }
27 
operator ()(const SocketAddressPair & r1,const SocketAddressPair & r2) const28 bool RouteCmp::operator()(
29       const SocketAddressPair& r1, const SocketAddressPair& r2) const {
30   if (r1.source() < r2.source())
31     return true;
32   if (r2.source() < r1.source())
33     return false;
34   if (symmetric && (r1.destination() < r2.destination()))
35     return true;
36   if (symmetric && (r2.destination() < r1.destination()))
37     return false;
38   return false;
39 }
40 
AddrCmp(NAT * nat)41 AddrCmp::AddrCmp(NAT* nat)
42     : use_ip(nat->FiltersIP()), use_port(nat->FiltersPort()) {
43 }
44 
operator ()(const SocketAddress & a) const45 size_t AddrCmp::operator()(const SocketAddress& a) const {
46   size_t h = 0;
47   if (use_ip)
48     h ^= HashIP(a.ipaddr());
49   if (use_port)
50     h ^= a.port() | (a.port() << 16);
51   return h;
52 }
53 
operator ()(const SocketAddress & a1,const SocketAddress & a2) const54 bool AddrCmp::operator()(
55       const SocketAddress& a1, const SocketAddress& a2) const {
56   if (use_ip && (a1.ipaddr() < a2.ipaddr()))
57     return true;
58   if (use_ip && (a2.ipaddr() < a1.ipaddr()))
59     return false;
60   if (use_port && (a1.port() < a2.port()))
61     return true;
62   if (use_port && (a2.port() < a1.port()))
63     return false;
64   return false;
65 }
66 
67 // Proxy socket that will capture the external destination address intended for
68 // a TCP connection to the NAT server.
69 class NATProxyServerSocket : public AsyncProxyServerSocket {
70  public:
NATProxyServerSocket(AsyncSocket * socket)71   NATProxyServerSocket(AsyncSocket* socket)
72       : AsyncProxyServerSocket(socket, kNATEncodedIPv6AddressSize) {
73     BufferInput(true);
74   }
75 
SendConnectResult(int err,const SocketAddress & addr)76   void SendConnectResult(int err, const SocketAddress& addr) override {
77     char code = err ? 1 : 0;
78     BufferedReadAdapter::DirectSend(&code, sizeof(char));
79   }
80 
81  protected:
ProcessInput(char * data,size_t * len)82   void ProcessInput(char* data, size_t* len) override {
83     if (*len < 2) {
84       return;
85     }
86 
87     int family = data[1];
88     ASSERT(family == AF_INET || family == AF_INET6);
89     if ((family == AF_INET && *len < kNATEncodedIPv4AddressSize) ||
90         (family == AF_INET6 && *len < kNATEncodedIPv6AddressSize)) {
91       return;
92     }
93 
94     SocketAddress dest_addr;
95     size_t address_length = UnpackAddressFromNAT(data, *len, &dest_addr);
96 
97     *len -= address_length;
98     if (*len > 0) {
99       memmove(data, data + address_length, *len);
100     }
101 
102     bool remainder = (*len > 0);
103     BufferInput(false);
104     SignalConnectRequest(this, dest_addr);
105     if (remainder) {
106       SignalReadEvent(this);
107     }
108   }
109 
110 };
111 
112 class NATProxyServer : public ProxyServer {
113  public:
NATProxyServer(SocketFactory * int_factory,const SocketAddress & int_addr,SocketFactory * ext_factory,const SocketAddress & ext_ip)114   NATProxyServer(SocketFactory* int_factory, const SocketAddress& int_addr,
115                  SocketFactory* ext_factory, const SocketAddress& ext_ip)
116       : ProxyServer(int_factory, int_addr, ext_factory, ext_ip) {
117   }
118 
119  protected:
WrapSocket(AsyncSocket * socket)120   AsyncProxyServerSocket* WrapSocket(AsyncSocket* socket) override {
121     return new NATProxyServerSocket(socket);
122   }
123 };
124 
NATServer(NATType type,SocketFactory * internal,const SocketAddress & internal_udp_addr,const SocketAddress & internal_tcp_addr,SocketFactory * external,const SocketAddress & external_ip)125 NATServer::NATServer(
126     NATType type, SocketFactory* internal,
127     const SocketAddress& internal_udp_addr,
128     const SocketAddress& internal_tcp_addr,
129     SocketFactory* external, const SocketAddress& external_ip)
130     : external_(external), external_ip_(external_ip.ipaddr(), 0) {
131   nat_ = NAT::Create(type);
132 
133   udp_server_socket_ = AsyncUDPSocket::Create(internal, internal_udp_addr);
134   udp_server_socket_->SignalReadPacket.connect(this,
135                                                &NATServer::OnInternalUDPPacket);
136   tcp_proxy_server_ = new NATProxyServer(internal, internal_tcp_addr, external,
137                                          external_ip);
138 
139   int_map_ = new InternalMap(RouteCmp(nat_));
140   ext_map_ = new ExternalMap();
141 }
142 
~NATServer()143 NATServer::~NATServer() {
144   for (InternalMap::iterator iter = int_map_->begin();
145        iter != int_map_->end();
146        iter++)
147     delete iter->second;
148 
149   delete nat_;
150   delete udp_server_socket_;
151   delete tcp_proxy_server_;
152   delete int_map_;
153   delete ext_map_;
154 }
155 
OnInternalUDPPacket(AsyncPacketSocket * socket,const char * buf,size_t size,const SocketAddress & addr,const PacketTime & packet_time)156 void NATServer::OnInternalUDPPacket(
157     AsyncPacketSocket* socket, const char* buf, size_t size,
158     const SocketAddress& addr, const PacketTime& packet_time) {
159   // Read the intended destination from the wire.
160   SocketAddress dest_addr;
161   size_t length = UnpackAddressFromNAT(buf, size, &dest_addr);
162 
163   // Find the translation for these addresses (allocating one if necessary).
164   SocketAddressPair route(addr, dest_addr);
165   InternalMap::iterator iter = int_map_->find(route);
166   if (iter == int_map_->end()) {
167     Translate(route);
168     iter = int_map_->find(route);
169   }
170   ASSERT(iter != int_map_->end());
171 
172   // Allow the destination to send packets back to the source.
173   iter->second->WhitelistInsert(dest_addr);
174 
175   // Send the packet to its intended destination.
176   rtc::PacketOptions options;
177   iter->second->socket->SendTo(buf + length, size - length, dest_addr, options);
178 }
179 
OnExternalUDPPacket(AsyncPacketSocket * socket,const char * buf,size_t size,const SocketAddress & remote_addr,const PacketTime & packet_time)180 void NATServer::OnExternalUDPPacket(
181     AsyncPacketSocket* socket, const char* buf, size_t size,
182     const SocketAddress& remote_addr, const PacketTime& packet_time) {
183   SocketAddress local_addr = socket->GetLocalAddress();
184 
185   // Find the translation for this addresses.
186   ExternalMap::iterator iter = ext_map_->find(local_addr);
187   ASSERT(iter != ext_map_->end());
188 
189   // Allow the NAT to reject this packet.
190   if (ShouldFilterOut(iter->second, remote_addr)) {
191     LOG(LS_INFO) << "Packet from " << remote_addr.ToSensitiveString()
192                  << " was filtered out by the NAT.";
193     return;
194   }
195 
196   // Forward this packet to the internal address.
197   // First prepend the address in a quasi-STUN format.
198   scoped_ptr<char[]> real_buf(new char[size + kNATEncodedIPv6AddressSize]);
199   size_t addrlength = PackAddressForNAT(real_buf.get(),
200                                         size + kNATEncodedIPv6AddressSize,
201                                         remote_addr);
202   // Copy the data part after the address.
203   rtc::PacketOptions options;
204   memcpy(real_buf.get() + addrlength, buf, size);
205   udp_server_socket_->SendTo(real_buf.get(), size + addrlength,
206                              iter->second->route.source(), options);
207 }
208 
Translate(const SocketAddressPair & route)209 void NATServer::Translate(const SocketAddressPair& route) {
210   AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_);
211 
212   if (!socket) {
213     LOG(LS_ERROR) << "Couldn't find a free port!";
214     return;
215   }
216 
217   TransEntry* entry = new TransEntry(route, socket, nat_);
218   (*int_map_)[route] = entry;
219   (*ext_map_)[socket->GetLocalAddress()] = entry;
220   socket->SignalReadPacket.connect(this, &NATServer::OnExternalUDPPacket);
221 }
222 
ShouldFilterOut(TransEntry * entry,const SocketAddress & ext_addr)223 bool NATServer::ShouldFilterOut(TransEntry* entry,
224                                 const SocketAddress& ext_addr) {
225   return entry->WhitelistContains(ext_addr);
226 }
227 
TransEntry(const SocketAddressPair & r,AsyncUDPSocket * s,NAT * nat)228 NATServer::TransEntry::TransEntry(
229     const SocketAddressPair& r, AsyncUDPSocket* s, NAT* nat)
230     : route(r), socket(s) {
231   whitelist = new AddressSet(AddrCmp(nat));
232 }
233 
~TransEntry()234 NATServer::TransEntry::~TransEntry() {
235   delete whitelist;
236   delete socket;
237 }
238 
WhitelistInsert(const SocketAddress & addr)239 void NATServer::TransEntry::WhitelistInsert(const SocketAddress& addr) {
240   CritScope cs(&crit_);
241   whitelist->insert(addr);
242 }
243 
WhitelistContains(const SocketAddress & ext_addr)244 bool NATServer::TransEntry::WhitelistContains(const SocketAddress& ext_addr) {
245   CritScope cs(&crit_);
246   return whitelist->find(ext_addr) == whitelist->end();
247 }
248 
249 }  // namespace rtc
250