1 /*
2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <algorithm>
12 #include <string>
13 
14 #include "webrtc/base/gunit.h"
15 #include "webrtc/base/logging.h"
16 #include "webrtc/base/natserver.h"
17 #include "webrtc/base/natsocketfactory.h"
18 #include "webrtc/base/nethelpers.h"
19 #include "webrtc/base/network.h"
20 #include "webrtc/base/physicalsocketserver.h"
21 #include "webrtc/base/testclient.h"
22 #include "webrtc/base/asynctcpsocket.h"
23 #include "webrtc/base/virtualsocketserver.h"
24 
25 using namespace rtc;
26 
CheckReceive(TestClient * client,bool should_receive,const char * buf,size_t size)27 bool CheckReceive(
28     TestClient* client, bool should_receive, const char* buf, size_t size) {
29   return (should_receive) ?
30       client->CheckNextPacket(buf, size, 0) :
31       client->CheckNoPacket();
32 }
33 
CreateTestClient(SocketFactory * factory,const SocketAddress & local_addr)34 TestClient* CreateTestClient(
35       SocketFactory* factory, const SocketAddress& local_addr) {
36   AsyncUDPSocket* socket = AsyncUDPSocket::Create(factory, local_addr);
37   return new TestClient(socket);
38 }
39 
CreateTCPTestClient(AsyncSocket * socket)40 TestClient* CreateTCPTestClient(AsyncSocket* socket) {
41   AsyncTCPSocket* packet_socket = new AsyncTCPSocket(socket, false);
42   return new TestClient(packet_socket);
43 }
44 
45 // Tests that when sending from internal_addr to external_addrs through the
46 // NAT type specified by nat_type, all external addrs receive the sent packet
47 // and, if exp_same is true, all use the same mapped-address on the NAT.
TestSend(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4],NATType nat_type,bool exp_same)48 void TestSend(
49       SocketServer* internal, const SocketAddress& internal_addr,
50       SocketServer* external, const SocketAddress external_addrs[4],
51       NATType nat_type, bool exp_same) {
52   Thread th_int(internal);
53   Thread th_ext(external);
54 
55   SocketAddress server_addr = internal_addr;
56   server_addr.SetPort(0);  // Auto-select a port
57   NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
58                                  external, external_addrs[0]);
59   NATSocketFactory* natsf = new NATSocketFactory(internal,
60                                                  nat->internal_udp_address(),
61                                                  nat->internal_tcp_address());
62 
63   TestClient* in = CreateTestClient(natsf, internal_addr);
64   TestClient* out[4];
65   for (int i = 0; i < 4; i++)
66     out[i] = CreateTestClient(external, external_addrs[i]);
67 
68   th_int.Start();
69   th_ext.Start();
70 
71   const char* buf = "filter_test";
72   size_t len = strlen(buf);
73 
74   in->SendTo(buf, len, out[0]->address());
75   SocketAddress trans_addr;
76   EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
77 
78   for (int i = 1; i < 4; i++) {
79     in->SendTo(buf, len, out[i]->address());
80     SocketAddress trans_addr2;
81     EXPECT_TRUE(out[i]->CheckNextPacket(buf, len, &trans_addr2));
82     bool are_same = (trans_addr == trans_addr2);
83     ASSERT_EQ(are_same, exp_same) << "same translated address";
84     ASSERT_NE(AF_UNSPEC, trans_addr.family());
85     ASSERT_NE(AF_UNSPEC, trans_addr2.family());
86   }
87 
88   th_int.Stop();
89   th_ext.Stop();
90 
91   delete nat;
92   delete natsf;
93   delete in;
94   for (int i = 0; i < 4; i++)
95     delete out[i];
96 }
97 
98 // Tests that when sending from external_addrs to internal_addr, the packet
99 // is delivered according to the specified filter_ip and filter_port rules.
TestRecv(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4],NATType nat_type,bool filter_ip,bool filter_port)100 void TestRecv(
101       SocketServer* internal, const SocketAddress& internal_addr,
102       SocketServer* external, const SocketAddress external_addrs[4],
103       NATType nat_type, bool filter_ip, bool filter_port) {
104   Thread th_int(internal);
105   Thread th_ext(external);
106 
107   SocketAddress server_addr = internal_addr;
108   server_addr.SetPort(0);  // Auto-select a port
109   NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
110                                  external, external_addrs[0]);
111   NATSocketFactory* natsf = new NATSocketFactory(internal,
112                                                  nat->internal_udp_address(),
113                                                  nat->internal_tcp_address());
114 
115   TestClient* in = CreateTestClient(natsf, internal_addr);
116   TestClient* out[4];
117   for (int i = 0; i < 4; i++)
118     out[i] = CreateTestClient(external, external_addrs[i]);
119 
120   th_int.Start();
121   th_ext.Start();
122 
123   const char* buf = "filter_test";
124   size_t len = strlen(buf);
125 
126   in->SendTo(buf, len, out[0]->address());
127   SocketAddress trans_addr;
128   EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
129 
130   out[1]->SendTo(buf, len, trans_addr);
131   EXPECT_TRUE(CheckReceive(in, !filter_ip, buf, len));
132 
133   out[2]->SendTo(buf, len, trans_addr);
134   EXPECT_TRUE(CheckReceive(in, !filter_port, buf, len));
135 
136   out[3]->SendTo(buf, len, trans_addr);
137   EXPECT_TRUE(CheckReceive(in, !filter_ip && !filter_port, buf, len));
138 
139   th_int.Stop();
140   th_ext.Stop();
141 
142   delete nat;
143   delete natsf;
144   delete in;
145   for (int i = 0; i < 4; i++)
146     delete out[i];
147 }
148 
149 // Tests that NATServer allocates bindings properly.
TestBindings(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4])150 void TestBindings(
151     SocketServer* internal, const SocketAddress& internal_addr,
152     SocketServer* external, const SocketAddress external_addrs[4]) {
153   TestSend(internal, internal_addr, external, external_addrs,
154            NAT_OPEN_CONE, true);
155   TestSend(internal, internal_addr, external, external_addrs,
156            NAT_ADDR_RESTRICTED, true);
157   TestSend(internal, internal_addr, external, external_addrs,
158            NAT_PORT_RESTRICTED, true);
159   TestSend(internal, internal_addr, external, external_addrs,
160            NAT_SYMMETRIC, false);
161 }
162 
163 // Tests that NATServer filters packets properly.
TestFilters(SocketServer * internal,const SocketAddress & internal_addr,SocketServer * external,const SocketAddress external_addrs[4])164 void TestFilters(
165     SocketServer* internal, const SocketAddress& internal_addr,
166     SocketServer* external, const SocketAddress external_addrs[4]) {
167   TestRecv(internal, internal_addr, external, external_addrs,
168            NAT_OPEN_CONE, false, false);
169   TestRecv(internal, internal_addr, external, external_addrs,
170            NAT_ADDR_RESTRICTED, true, false);
171   TestRecv(internal, internal_addr, external, external_addrs,
172            NAT_PORT_RESTRICTED, true, true);
173   TestRecv(internal, internal_addr, external, external_addrs,
174            NAT_SYMMETRIC, true, true);
175 }
176 
TestConnectivity(const SocketAddress & src,const IPAddress & dst)177 bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) {
178   // The physical NAT tests require connectivity to the selected ip from the
179   // internal address used for the NAT. Things like firewalls can break that, so
180   // check to see if it's worth even trying with this ip.
181   scoped_ptr<PhysicalSocketServer> pss(new PhysicalSocketServer());
182   scoped_ptr<AsyncSocket> client(pss->CreateAsyncSocket(src.family(),
183                                                         SOCK_DGRAM));
184   scoped_ptr<AsyncSocket> server(pss->CreateAsyncSocket(src.family(),
185                                                         SOCK_DGRAM));
186   if (client->Bind(SocketAddress(src.ipaddr(), 0)) != 0 ||
187       server->Bind(SocketAddress(dst, 0)) != 0) {
188     return false;
189   }
190   const char* buf = "hello other socket";
191   size_t len = strlen(buf);
192   int sent = client->SendTo(buf, len, server->GetLocalAddress());
193   SocketAddress addr;
194   const size_t kRecvBufSize = 64;
195   char recvbuf[kRecvBufSize];
196   Thread::Current()->SleepMs(100);
197   int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr);
198   return received == sent && ::memcmp(buf, recvbuf, len) == 0;
199 }
200 
TestPhysicalInternal(const SocketAddress & int_addr)201 void TestPhysicalInternal(const SocketAddress& int_addr) {
202   BasicNetworkManager network_manager;
203   network_manager.set_ipv6_enabled(true);
204   network_manager.StartUpdating();
205   // Process pending messages so the network list is updated.
206   Thread::Current()->ProcessMessages(0);
207 
208   std::vector<Network*> networks;
209   network_manager.GetNetworks(&networks);
210   networks.erase(std::remove_if(networks.begin(), networks.end(),
211                                 [](rtc::Network* network) {
212                                   return rtc::kDefaultNetworkIgnoreMask &
213                                          network->type();
214                                 }),
215                  networks.end());
216   if (networks.empty()) {
217     LOG(LS_WARNING) << "Not enough network adapters for test.";
218     return;
219   }
220 
221   SocketAddress ext_addr1(int_addr);
222   SocketAddress ext_addr2;
223   // Find an available IP with matching family. The test breaks if int_addr
224   // can't talk to ip, so check for connectivity as well.
225   for (std::vector<Network*>::iterator it = networks.begin();
226       it != networks.end(); ++it) {
227     const IPAddress& ip = (*it)->GetBestIP();
228     if (ip.family() == int_addr.family() && TestConnectivity(int_addr, ip)) {
229       ext_addr2.SetIP(ip);
230       break;
231     }
232   }
233   if (ext_addr2.IsNil()) {
234     LOG(LS_WARNING) << "No available IP of same family as " << int_addr;
235     return;
236   }
237 
238   LOG(LS_INFO) << "selected ip " << ext_addr2.ipaddr();
239 
240   SocketAddress ext_addrs[4] = {
241       SocketAddress(ext_addr1),
242       SocketAddress(ext_addr2),
243       SocketAddress(ext_addr1),
244       SocketAddress(ext_addr2)
245   };
246 
247   scoped_ptr<PhysicalSocketServer> int_pss(new PhysicalSocketServer());
248   scoped_ptr<PhysicalSocketServer> ext_pss(new PhysicalSocketServer());
249 
250   TestBindings(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
251   TestFilters(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
252 }
253 
// Physical-socket NAT tests with the IPv4 loopback address (port 0) as the
// internal address.
TEST(NatTest, TestPhysicalIPv4) {
  const SocketAddress internal_loopback("127.0.0.1", 0);
  TestPhysicalInternal(internal_loopback);
}
257 
// Physical-socket NAT tests with the IPv6 loopback address (port 0) as the
// internal address. Logs a warning and does nothing when HasIPv6Enabled()
// reports false.
TEST(NatTest, TestPhysicalIPv6) {
  if (!HasIPv6Enabled()) {
    LOG(LS_WARNING) << "No IPv6, skipping";
    return;
  }
  TestPhysicalInternal(SocketAddress("::1", 0));
}
265 
266 namespace {
267 
268 class TestVirtualSocketServer : public VirtualSocketServer {
269  public:
TestVirtualSocketServer(SocketServer * ss)270   explicit TestVirtualSocketServer(SocketServer* ss)
271       : VirtualSocketServer(ss),
272         ss_(ss) {}
273   // Expose this publicly
GetNextIP(int af)274   IPAddress GetNextIP(int af) { return VirtualSocketServer::GetNextIP(af); }
275 
276  private:
277   scoped_ptr<SocketServer> ss_;
278 };
279 
280 }  // namespace
281 
TestVirtualInternal(int family)282 void TestVirtualInternal(int family) {
283   scoped_ptr<TestVirtualSocketServer> int_vss(new TestVirtualSocketServer(
284       new PhysicalSocketServer()));
285   scoped_ptr<TestVirtualSocketServer> ext_vss(new TestVirtualSocketServer(
286       new PhysicalSocketServer()));
287 
288   SocketAddress int_addr;
289   SocketAddress ext_addrs[4];
290   int_addr.SetIP(int_vss->GetNextIP(family));
291   ext_addrs[0].SetIP(ext_vss->GetNextIP(int_addr.family()));
292   ext_addrs[1].SetIP(ext_vss->GetNextIP(int_addr.family()));
293   ext_addrs[2].SetIP(ext_addrs[0].ipaddr());
294   ext_addrs[3].SetIP(ext_addrs[1].ipaddr());
295 
296   TestBindings(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
297   TestFilters(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
298 }
299 
// Virtual-socket NAT tests using IPv4 addresses.
TEST(NatTest, TestVirtualIPv4) {
  const int kFamily = AF_INET;
  TestVirtualInternal(kFamily);
}
303 
// Virtual-socket NAT tests using IPv6 addresses. Logs a warning and does
// nothing when HasIPv6Enabled() reports false.
TEST(NatTest, TestVirtualIPv6) {
  if (!HasIPv6Enabled()) {
    LOG(LS_WARNING) << "No IPv6, skipping";
    return;
  }
  TestVirtualInternal(AF_INET6);
}
311 
312 class NatTcpTest : public testing::Test, public sigslot::has_slots<> {
313  public:
NatTcpTest()314   NatTcpTest()
315       : int_addr_("192.168.0.1", 0),
316         ext_addr_("10.0.0.1", 0),
317         connected_(false),
318         int_pss_(new PhysicalSocketServer()),
319         ext_pss_(new PhysicalSocketServer()),
320         int_vss_(new TestVirtualSocketServer(int_pss_)),
321         ext_vss_(new TestVirtualSocketServer(ext_pss_)),
322         int_thread_(new Thread(int_vss_.get())),
323         ext_thread_(new Thread(ext_vss_.get())),
324         nat_(new NATServer(NAT_OPEN_CONE, int_vss_.get(), int_addr_, int_addr_,
325                            ext_vss_.get(), ext_addr_)),
326         natsf_(new NATSocketFactory(int_vss_.get(),
327                                     nat_->internal_udp_address(),
328                                     nat_->internal_tcp_address())) {
329     int_thread_->Start();
330     ext_thread_->Start();
331   }
332 
OnConnectEvent(AsyncSocket * socket)333   void OnConnectEvent(AsyncSocket* socket) {
334     connected_ = true;
335   }
336 
OnAcceptEvent(AsyncSocket * socket)337   void OnAcceptEvent(AsyncSocket* socket) {
338     accepted_.reset(server_->Accept(NULL));
339   }
340 
OnCloseEvent(AsyncSocket * socket,int error)341   void OnCloseEvent(AsyncSocket* socket, int error) {
342   }
343 
ConnectEvents()344   void ConnectEvents() {
345     server_->SignalReadEvent.connect(this, &NatTcpTest::OnAcceptEvent);
346     client_->SignalConnectEvent.connect(this, &NatTcpTest::OnConnectEvent);
347   }
348 
349   SocketAddress int_addr_;
350   SocketAddress ext_addr_;
351   bool connected_;
352   PhysicalSocketServer* int_pss_;
353   PhysicalSocketServer* ext_pss_;
354   rtc::scoped_ptr<TestVirtualSocketServer> int_vss_;
355   rtc::scoped_ptr<TestVirtualSocketServer> ext_vss_;
356   rtc::scoped_ptr<Thread> int_thread_;
357   rtc::scoped_ptr<Thread> ext_thread_;
358   rtc::scoped_ptr<NATServer> nat_;
359   rtc::scoped_ptr<NATSocketFactory> natsf_;
360   rtc::scoped_ptr<AsyncSocket> client_;
361   rtc::scoped_ptr<AsyncSocket> server_;
362   rtc::scoped_ptr<AsyncSocket> accepted_;
363 };
364 
// Connects a TCP client created through the NAT socket factory to a listening
// socket on the external side, then exchanges one packet in each direction.
// (The test is currently disabled via the DISABLED_ prefix.)
TEST_F(NatTcpTest, DISABLED_TestConnectOut) {
  server_.reset(ext_vss_->CreateAsyncSocket(SOCK_STREAM));
  ASSERT_TRUE(server_.get() != NULL);
  server_->Bind(ext_addr_);
  server_->Listen(5);

  client_.reset(natsf_->CreateAsyncSocket(SOCK_STREAM));
  ASSERT_TRUE(client_.get() != NULL);
  // A negative return value signals failure. (The previous EXPECT_GE(0, x)
  // form asserted 0 >= x and therefore accepted -1.)
  EXPECT_LE(0, client_->Bind(int_addr_));

  // Attach the signal handlers before starting the connect so that an early
  // connect/accept notification cannot be missed.
  ConnectEvents();
  EXPECT_LE(0, client_->Connect(server_->GetLocalAddress()));

  EXPECT_TRUE_WAIT(connected_, 1000);
  EXPECT_EQ(client_->GetRemoteAddress(), server_->GetLocalAddress());

  // The accept is signalled separately from the client-side connect, so wait
  // for it and stop here if it never happened instead of dereferencing NULL.
  EXPECT_TRUE_WAIT(accepted_.get() != NULL, 1000);
  ASSERT_TRUE(accepted_.get() != NULL);
  EXPECT_EQ(accepted_->GetRemoteAddress().ipaddr(), ext_addr_.ipaddr());

  // The TestClients take over the sockets released from the fixture.
  rtc::scoped_ptr<rtc::TestClient> in(CreateTCPTestClient(client_.release()));
  rtc::scoped_ptr<rtc::TestClient> out(
      CreateTCPTestClient(accepted_.release()));

  const char* buf = "test_packet";
  size_t len = strlen(buf);

  in->Send(buf, len);
  SocketAddress trans_addr;
  EXPECT_TRUE(out->CheckNextPacket(buf, len, &trans_addr));

  out->Send(buf, len);
  EXPECT_TRUE(in->CheckNextPacket(buf, len, &trans_addr));
}
394 // #endif
395