1 /*
2  *  Copyright 2015 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "p2p/stunprober/stun_prober.h"
12 
13 #include <map>
14 #include <memory>
15 #include <set>
16 #include <string>
17 #include <utility>
18 
19 #include "api/packet_socket_factory.h"
20 #include "api/transport/stun.h"
21 #include "rtc_base/async_packet_socket.h"
22 #include "rtc_base/async_resolver_interface.h"
23 #include "rtc_base/bind.h"
24 #include "rtc_base/checks.h"
25 #include "rtc_base/constructor_magic.h"
26 #include "rtc_base/helpers.h"
27 #include "rtc_base/logging.h"
28 #include "rtc_base/thread.h"
29 #include "rtc_base/time_utils.h"
30 
31 namespace stunprober {
32 
33 namespace {
34 
// Period, in milliseconds, at which MaybeScheduleStunRequests() re-posts
// itself when the configured request interval is at least this long. Shorter
// intervals are polled every 1 ms instead (see get_wake_up_interval_ms()).
// Half of this value is also the tolerance should_send_next_request() allows
// for sending a request slightly ahead of its scheduled time.
const int THREAD_WAKE_UP_INTERVAL_MS = 5;
36 
// Adds one to the tally stored under |ip| in |counter_per_ip|. A key that is
// seen for the first time starts from zero, so it ends up at 1.
template <typename T>
void IncrementCounterByAddress(std::map<T, int>* counter_per_ip, const T& ip) {
  // operator[] value-initializes a missing entry to 0 before handing it back.
  int& tally = (*counter_per_ip)[ip];
  ++tally;
}
41 
42 }  // namespace
43 
44 // A requester tracks the requests and responses from a single socket to many
45 // STUN servers
46 class StunProber::Requester : public sigslot::has_slots<> {
47  public:
48   // Each Request maps to a request and response.
49   struct Request {
50     // Actual time the STUN bind request was sent.
51     int64_t sent_time_ms = 0;
52     // Time the response was received.
53     int64_t received_time_ms = 0;
54 
55     // Server reflexive address from STUN response for this given request.
56     rtc::SocketAddress srflx_addr;
57 
58     rtc::IPAddress server_addr;
59 
rttstunprober::StunProber::Requester::Request60     int64_t rtt() { return received_time_ms - sent_time_ms; }
61     void ProcessResponse(const char* buf, size_t buf_len);
62   };
63 
64   // StunProber provides |server_ips| for Requester to probe. For shared
65   // socket mode, it'll be all the resolved IP addresses. For non-shared mode,
66   // it'll just be a single address.
67   Requester(StunProber* prober,
68             rtc::AsyncPacketSocket* socket,
69             const std::vector<rtc::SocketAddress>& server_ips);
70   ~Requester() override;
71 
72   // There is no callback for SendStunRequest as the underneath socket send is
73   // expected to be completed immediately. Otherwise, it'll skip this request
74   // and move to the next one.
75   void SendStunRequest();
76 
77   void OnStunResponseReceived(rtc::AsyncPacketSocket* socket,
78                               const char* buf,
79                               size_t size,
80                               const rtc::SocketAddress& addr,
81                               const int64_t& packet_time_us);
82 
requests()83   const std::vector<Request*>& requests() { return requests_; }
84 
85   // Whether this Requester has completed all requests.
Done()86   bool Done() {
87     return static_cast<size_t>(num_request_sent_) == server_ips_.size();
88   }
89 
90  private:
91   Request* GetRequestByAddress(const rtc::IPAddress& ip);
92 
93   StunProber* prober_;
94 
95   // The socket for this session.
96   std::unique_ptr<rtc::AsyncPacketSocket> socket_;
97 
98   // Temporary SocketAddress and buffer for RecvFrom.
99   rtc::SocketAddress addr_;
100   std::unique_ptr<rtc::ByteBufferWriter> response_packet_;
101 
102   std::vector<Request*> requests_;
103   std::vector<rtc::SocketAddress> server_ips_;
104   int16_t num_request_sent_ = 0;
105   int16_t num_response_received_ = 0;
106 
107   rtc::ThreadChecker& thread_checker_;
108 
109   RTC_DISALLOW_COPY_AND_ASSIGN(Requester);
110 };
111 
Requester(StunProber * prober,rtc::AsyncPacketSocket * socket,const std::vector<rtc::SocketAddress> & server_ips)112 StunProber::Requester::Requester(
113     StunProber* prober,
114     rtc::AsyncPacketSocket* socket,
115     const std::vector<rtc::SocketAddress>& server_ips)
116     : prober_(prober),
117       socket_(socket),
118       response_packet_(new rtc::ByteBufferWriter(nullptr, kMaxUdpBufferSize)),
119       server_ips_(server_ips),
120       thread_checker_(prober->thread_checker_) {
121   socket_->SignalReadPacket.connect(
122       this, &StunProber::Requester::OnStunResponseReceived);
123 }
124 
~Requester()125 StunProber::Requester::~Requester() {
126   if (socket_) {
127     socket_->Close();
128   }
129   for (auto* req : requests_) {
130     if (req) {
131       delete req;
132     }
133   }
134 }
135 
SendStunRequest()136 void StunProber::Requester::SendStunRequest() {
137   RTC_DCHECK(thread_checker_.IsCurrent());
138   requests_.push_back(new Request());
139   Request& request = *(requests_.back());
140   cricket::StunMessage message;
141 
142   // Random transaction ID, STUN_BINDING_REQUEST
143   message.SetTransactionID(
144       rtc::CreateRandomString(cricket::kStunTransactionIdLength));
145   message.SetType(cricket::STUN_BINDING_REQUEST);
146 
147   std::unique_ptr<rtc::ByteBufferWriter> request_packet(
148       new rtc::ByteBufferWriter(nullptr, kMaxUdpBufferSize));
149   if (!message.Write(request_packet.get())) {
150     prober_->ReportOnFinished(WRITE_FAILED);
151     return;
152   }
153 
154   auto addr = server_ips_[num_request_sent_];
155   request.server_addr = addr.ipaddr();
156 
157   // The write must succeed immediately. Otherwise, the calculating of the STUN
158   // request timing could become too complicated. Callback is ignored by passing
159   // empty AsyncCallback.
160   rtc::PacketOptions options;
161   int rv = socket_->SendTo(const_cast<char*>(request_packet->Data()),
162                            request_packet->Length(), addr, options);
163   if (rv < 0) {
164     prober_->ReportOnFinished(WRITE_FAILED);
165     return;
166   }
167 
168   request.sent_time_ms = rtc::TimeMillis();
169 
170   num_request_sent_++;
171   RTC_DCHECK(static_cast<size_t>(num_request_sent_) <= server_ips_.size());
172 }
173 
ProcessResponse(const char * buf,size_t buf_len)174 void StunProber::Requester::Request::ProcessResponse(const char* buf,
175                                                      size_t buf_len) {
176   int64_t now = rtc::TimeMillis();
177   rtc::ByteBufferReader message(buf, buf_len);
178   cricket::StunMessage stun_response;
179   if (!stun_response.Read(&message)) {
180     // Invalid or incomplete STUN packet.
181     received_time_ms = 0;
182     return;
183   }
184 
185   // Get external address of the socket.
186   const cricket::StunAddressAttribute* addr_attr =
187       stun_response.GetAddress(cricket::STUN_ATTR_MAPPED_ADDRESS);
188   if (addr_attr == nullptr) {
189     // Addresses not available to detect whether or not behind a NAT.
190     return;
191   }
192 
193   if (addr_attr->family() != cricket::STUN_ADDRESS_IPV4 &&
194       addr_attr->family() != cricket::STUN_ADDRESS_IPV6) {
195     return;
196   }
197 
198   received_time_ms = now;
199 
200   srflx_addr = addr_attr->GetAddress();
201 }
202 
OnStunResponseReceived(rtc::AsyncPacketSocket * socket,const char * buf,size_t size,const rtc::SocketAddress & addr,const int64_t &)203 void StunProber::Requester::OnStunResponseReceived(
204     rtc::AsyncPacketSocket* socket,
205     const char* buf,
206     size_t size,
207     const rtc::SocketAddress& addr,
208     const int64_t& /* packet_time_us */) {
209   RTC_DCHECK(thread_checker_.IsCurrent());
210   RTC_DCHECK(socket_);
211   Request* request = GetRequestByAddress(addr.ipaddr());
212   if (!request) {
213     // Something is wrong, finish the test.
214     prober_->ReportOnFinished(GENERIC_FAILURE);
215     return;
216   }
217 
218   num_response_received_++;
219   request->ProcessResponse(buf, size);
220 }
221 
GetRequestByAddress(const rtc::IPAddress & ipaddr)222 StunProber::Requester::Request* StunProber::Requester::GetRequestByAddress(
223     const rtc::IPAddress& ipaddr) {
224   RTC_DCHECK(thread_checker_.IsCurrent());
225   for (auto* request : requests_) {
226     if (request->server_addr == ipaddr) {
227       return request;
228     }
229   }
230 
231   return nullptr;
232 }
233 
// Stats has no hand-written construction or teardown logic; both special
// members are compiler-generated but defined out of line in this file.
StunProber::Stats::Stats() = default;

StunProber::Stats::~Stats() = default;

// Same for ObserverAdapter: defaulted, defined out of line.
StunProber::ObserverAdapter::ObserverAdapter() = default;

StunProber::ObserverAdapter::~ObserverAdapter() = default;
241 
OnPrepared(StunProber * stunprober,Status status)242 void StunProber::ObserverAdapter::OnPrepared(StunProber* stunprober,
243                                              Status status) {
244   if (status == SUCCESS) {
245     stunprober->Start(this);
246   } else {
247     callback_(stunprober, status);
248   }
249 }
250 
// Completion hook of the callback-based API: forwards the prober and its
// final |status| unchanged to the stored callback.
void StunProber::ObserverAdapter::OnFinished(StunProber* stunprober,
                                             Status status) {
  callback_(stunprober, status);
}
255 
// Stores the collaborators the prober works with; nothing is started here.
//   socket_factory: used later to create UDP sockets and async resolvers.
//   thread: target thread for the (delayed) invocations posted via invoker_.
//   networks: kept in networks_ and consulted by GetStats().
// The request interval starts at 0 until Prepare() sets it.
StunProber::StunProber(rtc::PacketSocketFactory* socket_factory,
                       rtc::Thread* thread,
                       const rtc::NetworkManager::NetworkList& networks)
    : interval_ms_(0),
      socket_factory_(socket_factory),
      thread_(thread),
      networks_(networks) {}
263 
~StunProber()264 StunProber::~StunProber() {
265   for (auto* req : requesters_) {
266     if (req) {
267       delete req;
268     }
269   }
270   for (auto* s : sockets_) {
271     if (s) {
272       delete s;
273     }
274   }
275 }
276 
Start(const std::vector<rtc::SocketAddress> & servers,bool shared_socket_mode,int interval_ms,int num_request_per_ip,int timeout_ms,const AsyncCallback callback)277 bool StunProber::Start(const std::vector<rtc::SocketAddress>& servers,
278                        bool shared_socket_mode,
279                        int interval_ms,
280                        int num_request_per_ip,
281                        int timeout_ms,
282                        const AsyncCallback callback) {
283   observer_adapter_.set_callback(callback);
284   return Prepare(servers, shared_socket_mode, interval_ms, num_request_per_ip,
285                  timeout_ms, &observer_adapter_);
286 }
287 
Prepare(const std::vector<rtc::SocketAddress> & servers,bool shared_socket_mode,int interval_ms,int num_request_per_ip,int timeout_ms,StunProber::Observer * observer)288 bool StunProber::Prepare(const std::vector<rtc::SocketAddress>& servers,
289                          bool shared_socket_mode,
290                          int interval_ms,
291                          int num_request_per_ip,
292                          int timeout_ms,
293                          StunProber::Observer* observer) {
294   RTC_DCHECK(thread_checker_.IsCurrent());
295   interval_ms_ = interval_ms;
296   shared_socket_mode_ = shared_socket_mode;
297 
298   requests_per_ip_ = num_request_per_ip;
299   if (requests_per_ip_ == 0 || servers.size() == 0) {
300     return false;
301   }
302 
303   timeout_ms_ = timeout_ms;
304   servers_ = servers;
305   observer_ = observer;
306   // Remove addresses that are already resolved.
307   for (auto it = servers_.begin(); it != servers_.end();) {
308     if (it->ipaddr().family() != AF_UNSPEC) {
309       all_servers_addrs_.push_back(*it);
310       it = servers_.erase(it);
311     } else {
312       ++it;
313     }
314   }
315   if (servers_.empty()) {
316     CreateSockets();
317     return true;
318   }
319   return ResolveServerName(servers_.back());
320 }
321 
Start(StunProber::Observer * observer)322 bool StunProber::Start(StunProber::Observer* observer) {
323   observer_ = observer;
324   if (total_ready_sockets_ != total_socket_required()) {
325     return false;
326   }
327   MaybeScheduleStunRequests();
328   return true;
329 }
330 
ResolveServerName(const rtc::SocketAddress & addr)331 bool StunProber::ResolveServerName(const rtc::SocketAddress& addr) {
332   rtc::AsyncResolverInterface* resolver =
333       socket_factory_->CreateAsyncResolver();
334   if (!resolver) {
335     return false;
336   }
337   resolver->SignalDone.connect(this, &StunProber::OnServerResolved);
338   resolver->Start(addr);
339   return true;
340 }
341 
OnSocketReady(rtc::AsyncPacketSocket * socket,const rtc::SocketAddress & addr)342 void StunProber::OnSocketReady(rtc::AsyncPacketSocket* socket,
343                                const rtc::SocketAddress& addr) {
344   total_ready_sockets_++;
345   if (total_ready_sockets_ == total_socket_required()) {
346     ReportOnPrepared(SUCCESS);
347   }
348 }
349 
OnServerResolved(rtc::AsyncResolverInterface * resolver)350 void StunProber::OnServerResolved(rtc::AsyncResolverInterface* resolver) {
351   RTC_DCHECK(thread_checker_.IsCurrent());
352 
353   if (resolver->GetError() == 0) {
354     rtc::SocketAddress addr(resolver->address().ipaddr(),
355                             resolver->address().port());
356     all_servers_addrs_.push_back(addr);
357   }
358 
359   // Deletion of AsyncResolverInterface can't be done in OnResolveResult which
360   // handles SignalDone.
361   invoker_.AsyncInvoke<void>(
362       RTC_FROM_HERE, thread_,
363       rtc::Bind(&rtc::AsyncResolverInterface::Destroy, resolver, false));
364   servers_.pop_back();
365 
366   if (servers_.size()) {
367     if (!ResolveServerName(servers_.back())) {
368       ReportOnPrepared(RESOLVE_FAILED);
369     }
370     return;
371   }
372 
373   if (all_servers_addrs_.size() == 0) {
374     ReportOnPrepared(RESOLVE_FAILED);
375     return;
376   }
377 
378   CreateSockets();
379 }
380 
CreateSockets()381 void StunProber::CreateSockets() {
382   // Dedupe.
383   std::set<rtc::SocketAddress> addrs(all_servers_addrs_.begin(),
384                                      all_servers_addrs_.end());
385   all_servers_addrs_.assign(addrs.begin(), addrs.end());
386 
387   // Prepare all the sockets beforehand. All of them will bind to "any" address.
388   while (sockets_.size() < total_socket_required()) {
389     std::unique_ptr<rtc::AsyncPacketSocket> socket(
390         socket_factory_->CreateUdpSocket(rtc::SocketAddress(INADDR_ANY, 0), 0,
391                                          0));
392     if (!socket) {
393       ReportOnPrepared(GENERIC_FAILURE);
394       return;
395     }
396     // Chrome and WebRTC behave differently in terms of the state of a socket
397     // once returned from PacketSocketFactory::CreateUdpSocket.
398     if (socket->GetState() == rtc::AsyncPacketSocket::STATE_BINDING) {
399       socket->SignalAddressReady.connect(this, &StunProber::OnSocketReady);
400     } else {
401       OnSocketReady(socket.get(), rtc::SocketAddress(INADDR_ANY, 0));
402     }
403     sockets_.push_back(socket.release());
404   }
405 }
406 
CreateRequester()407 StunProber::Requester* StunProber::CreateRequester() {
408   RTC_DCHECK(thread_checker_.IsCurrent());
409   if (!sockets_.size()) {
410     return nullptr;
411   }
412   StunProber::Requester* requester;
413   if (shared_socket_mode_) {
414     requester = new Requester(this, sockets_.back(), all_servers_addrs_);
415   } else {
416     std::vector<rtc::SocketAddress> server_ip;
417     server_ip.push_back(
418         all_servers_addrs_[(num_request_sent_ % all_servers_addrs_.size())]);
419     requester = new Requester(this, sockets_.back(), server_ip);
420   }
421 
422   sockets_.pop_back();
423   return requester;
424 }
425 
SendNextRequest()426 bool StunProber::SendNextRequest() {
427   if (!current_requester_ || current_requester_->Done()) {
428     current_requester_ = CreateRequester();
429     requesters_.push_back(current_requester_);
430   }
431   if (!current_requester_) {
432     return false;
433   }
434   current_requester_->SendStunRequest();
435   num_request_sent_++;
436   return true;
437 }
438 
should_send_next_request(int64_t now)439 bool StunProber::should_send_next_request(int64_t now) {
440   if (interval_ms_ < THREAD_WAKE_UP_INTERVAL_MS) {
441     return now >= next_request_time_ms_;
442   } else {
443     return (now + (THREAD_WAKE_UP_INTERVAL_MS / 2)) >= next_request_time_ms_;
444   }
445 }
446 
get_wake_up_interval_ms()447 int StunProber::get_wake_up_interval_ms() {
448   if (interval_ms_ < THREAD_WAKE_UP_INTERVAL_MS) {
449     return 1;
450   } else {
451     return THREAD_WAKE_UP_INTERVAL_MS;
452   }
453 }
454 
MaybeScheduleStunRequests()455 void StunProber::MaybeScheduleStunRequests() {
456   RTC_DCHECK(thread_checker_.IsCurrent());
457   int64_t now = rtc::TimeMillis();
458 
459   if (Done()) {
460     invoker_.AsyncInvokeDelayed<void>(
461         RTC_FROM_HERE, thread_,
462         rtc::Bind(&StunProber::ReportOnFinished, this, SUCCESS), timeout_ms_);
463     return;
464   }
465   if (should_send_next_request(now)) {
466     if (!SendNextRequest()) {
467       ReportOnFinished(GENERIC_FAILURE);
468       return;
469     }
470     next_request_time_ms_ = now + interval_ms_;
471   }
472   invoker_.AsyncInvokeDelayed<void>(
473       RTC_FROM_HERE, thread_,
474       rtc::Bind(&StunProber::MaybeScheduleStunRequests, this),
475       get_wake_up_interval_ms());
476 }
477 
GetStats(StunProber::Stats * prob_stats) const478 bool StunProber::GetStats(StunProber::Stats* prob_stats) const {
479   // No need to be on the same thread.
480   if (!prob_stats) {
481     return false;
482   }
483 
484   StunProber::Stats stats;
485 
486   int rtt_sum = 0;
487   int64_t first_sent_time = 0;
488   int64_t last_sent_time = 0;
489   NatType nat_type = NATTYPE_INVALID;
490 
491   // Track of how many srflx IP that we have seen.
492   std::set<rtc::IPAddress> srflx_ips;
493 
494   // If we're not receiving any response on a given IP, all requests sent to
495   // that IP should be ignored as this could just be an DNS error.
496   std::map<rtc::IPAddress, int> num_response_per_server;
497   std::map<rtc::IPAddress, int> num_request_per_server;
498 
499   for (auto* requester : requesters_) {
500     std::map<rtc::SocketAddress, int> num_response_per_srflx_addr;
501     for (auto* request : requester->requests()) {
502       if (request->sent_time_ms <= 0) {
503         continue;
504       }
505 
506       ++stats.raw_num_request_sent;
507       IncrementCounterByAddress(&num_request_per_server, request->server_addr);
508 
509       if (!first_sent_time) {
510         first_sent_time = request->sent_time_ms;
511       }
512       last_sent_time = request->sent_time_ms;
513 
514       if (request->received_time_ms < request->sent_time_ms) {
515         continue;
516       }
517 
518       IncrementCounterByAddress(&num_response_per_server, request->server_addr);
519       IncrementCounterByAddress(&num_response_per_srflx_addr,
520                                 request->srflx_addr);
521       rtt_sum += request->rtt();
522       stats.srflx_addrs.insert(request->srflx_addr.ToString());
523       srflx_ips.insert(request->srflx_addr.ipaddr());
524     }
525 
526     // If we're using shared mode and seeing >1 srflx addresses for a single
527     // requester, it's symmetric NAT.
528     if (shared_socket_mode_ && num_response_per_srflx_addr.size() > 1) {
529       nat_type = NATTYPE_SYMMETRIC;
530     }
531   }
532 
533   // We're probably not behind a regular NAT. We have more than 1 distinct
534   // server reflexive IPs.
535   if (srflx_ips.size() > 1) {
536     return false;
537   }
538 
539   int num_sent = 0;
540   int num_received = 0;
541   int num_server_ip_with_response = 0;
542 
543   for (const auto& kv : num_response_per_server) {
544     RTC_DCHECK_GT(kv.second, 0);
545     num_server_ip_with_response++;
546     num_received += kv.second;
547     num_sent += num_request_per_server[kv.first];
548   }
549 
550   // Shared mode is only true if we use the shared socket and there are more
551   // than 1 responding servers.
552   stats.shared_socket_mode =
553       shared_socket_mode_ && (num_server_ip_with_response > 1);
554 
555   if (stats.shared_socket_mode && nat_type == NATTYPE_INVALID) {
556     nat_type = NATTYPE_NON_SYMMETRIC;
557   }
558 
559   // If we could find a local IP matching srflx, we're not behind a NAT.
560   rtc::SocketAddress srflx_addr;
561   if (stats.srflx_addrs.size() &&
562       !srflx_addr.FromString(*(stats.srflx_addrs.begin()))) {
563     return false;
564   }
565   for (const auto* net : networks_) {
566     if (srflx_addr.ipaddr() == net->GetBestIP()) {
567       nat_type = stunprober::NATTYPE_NONE;
568       stats.host_ip = net->GetBestIP().ToString();
569       break;
570     }
571   }
572 
573   // Finally, we know we're behind a NAT but can't determine which type it is.
574   if (nat_type == NATTYPE_INVALID) {
575     nat_type = NATTYPE_UNKNOWN;
576   }
577 
578   stats.nat_type = nat_type;
579   stats.num_request_sent = num_sent;
580   stats.num_response_received = num_received;
581   stats.target_request_interval_ns = interval_ms_ * 1000;
582 
583   if (num_sent) {
584     stats.success_percent = static_cast<int>(100 * num_received / num_sent);
585   }
586 
587   if (stats.raw_num_request_sent > 1) {
588     stats.actual_request_interval_ns =
589         (1000 * (last_sent_time - first_sent_time)) /
590         (stats.raw_num_request_sent - 1);
591   }
592 
593   if (num_received) {
594     stats.average_rtt_ms = static_cast<int>((rtt_sum / num_received));
595   }
596 
597   *prob_stats = stats;
598   return true;
599 }
600 
ReportOnPrepared(StunProber::Status status)601 void StunProber::ReportOnPrepared(StunProber::Status status) {
602   if (observer_) {
603     observer_->OnPrepared(this, status);
604   }
605 }
606 
ReportOnFinished(StunProber::Status status)607 void StunProber::ReportOnFinished(StunProber::Status status) {
608   if (observer_) {
609     observer_->OnFinished(this, status);
610   }
611 }
612 
613 }  // namespace stunprober
614