1 /*
2  *  Copyright 2009 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_
12 #define WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_
13 
14 #include <map>
15 #include <string>
16 #include <vector>
17 
18 #include "webrtc/p2p/base/transport.h"
19 #include "webrtc/p2p/base/transportchannel.h"
20 #include "webrtc/p2p/base/transportcontroller.h"
21 #include "webrtc/p2p/base/transportchannelimpl.h"
22 #include "webrtc/base/bind.h"
23 #include "webrtc/base/buffer.h"
24 #include "webrtc/base/fakesslidentity.h"
25 #include "webrtc/base/messagequeue.h"
26 #include "webrtc/base/sigslot.h"
27 #include "webrtc/base/sslfingerprint.h"
28 #include "webrtc/base/thread.h"
29 
30 namespace cricket {
31 
32 class FakeTransport;
33 
34 namespace {
35 struct PacketMessageData : public rtc::MessageData {
PacketMessageDataPacketMessageData36   PacketMessageData(const char* data, size_t len) : packet(data, len) {}
37   rtc::Buffer packet;
38 };
39 }  // namespace
40 
41 // Fake transport channel class, which can be passed to anything that needs a
42 // transport channel. Can be informed of another FakeTransportChannel via
43 // SetDestination.
44 // TODO(hbos): Move implementation to .cc file, this and other classes in file.
45 class FakeTransportChannel : public TransportChannelImpl,
46                              public rtc::MessageHandler {
47  public:
FakeTransportChannel(Transport * transport,const std::string & name,int component)48   explicit FakeTransportChannel(Transport* transport,
49                                 const std::string& name,
50                                 int component)
51       : TransportChannelImpl(name, component),
52         transport_(transport),
53         dtls_fingerprint_("", nullptr, 0) {}
~FakeTransportChannel()54   ~FakeTransportChannel() { Reset(); }
55 
IceTiebreaker()56   uint64_t IceTiebreaker() const { return tiebreaker_; }
remote_ice_mode()57   IceMode remote_ice_mode() const { return remote_ice_mode_; }
ice_ufrag()58   const std::string& ice_ufrag() const { return ice_ufrag_; }
ice_pwd()59   const std::string& ice_pwd() const { return ice_pwd_; }
remote_ice_ufrag()60   const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; }
remote_ice_pwd()61   const std::string& remote_ice_pwd() const { return remote_ice_pwd_; }
dtls_fingerprint()62   const rtc::SSLFingerprint& dtls_fingerprint() const {
63     return dtls_fingerprint_;
64   }
65 
66   // If async, will send packets by "Post"-ing to message queue instead of
67   // synchronously "Send"-ing.
SetAsync(bool async)68   void SetAsync(bool async) { async_ = async; }
69 
GetTransport()70   Transport* GetTransport() override { return transport_; }
71 
GetState()72   TransportChannelState GetState() const override {
73     if (connection_count_ == 0) {
74       return had_connection_ ? TransportChannelState::STATE_FAILED
75                              : TransportChannelState::STATE_INIT;
76     }
77 
78     if (connection_count_ == 1) {
79       return TransportChannelState::STATE_COMPLETED;
80     }
81 
82     return TransportChannelState::STATE_CONNECTING;
83   }
84 
SetIceRole(IceRole role)85   void SetIceRole(IceRole role) override { role_ = role; }
GetIceRole()86   IceRole GetIceRole() const override { return role_; }
SetIceTiebreaker(uint64_t tiebreaker)87   void SetIceTiebreaker(uint64_t tiebreaker) override {
88     tiebreaker_ = tiebreaker;
89   }
SetIceCredentials(const std::string & ice_ufrag,const std::string & ice_pwd)90   void SetIceCredentials(const std::string& ice_ufrag,
91                          const std::string& ice_pwd) override {
92     ice_ufrag_ = ice_ufrag;
93     ice_pwd_ = ice_pwd;
94   }
SetRemoteIceCredentials(const std::string & ice_ufrag,const std::string & ice_pwd)95   void SetRemoteIceCredentials(const std::string& ice_ufrag,
96                                const std::string& ice_pwd) override {
97     remote_ice_ufrag_ = ice_ufrag;
98     remote_ice_pwd_ = ice_pwd;
99   }
100 
SetRemoteIceMode(IceMode mode)101   void SetRemoteIceMode(IceMode mode) override { remote_ice_mode_ = mode; }
SetRemoteFingerprint(const std::string & alg,const uint8_t * digest,size_t digest_len)102   bool SetRemoteFingerprint(const std::string& alg,
103                             const uint8_t* digest,
104                             size_t digest_len) override {
105     dtls_fingerprint_ = rtc::SSLFingerprint(alg, digest, digest_len);
106     return true;
107   }
SetSslRole(rtc::SSLRole role)108   bool SetSslRole(rtc::SSLRole role) override {
109     ssl_role_ = role;
110     return true;
111   }
GetSslRole(rtc::SSLRole * role)112   bool GetSslRole(rtc::SSLRole* role) const override {
113     *role = ssl_role_;
114     return true;
115   }
116 
Connect()117   void Connect() override {
118     if (state_ == STATE_INIT) {
119       state_ = STATE_CONNECTING;
120     }
121   }
122 
MaybeStartGathering()123   void MaybeStartGathering() override {
124     if (gathering_state_ == kIceGatheringNew) {
125       gathering_state_ = kIceGatheringGathering;
126       SignalGatheringState(this);
127     }
128   }
129 
gathering_state()130   IceGatheringState gathering_state() const override {
131     return gathering_state_;
132   }
133 
Reset()134   void Reset() {
135     if (state_ != STATE_INIT) {
136       state_ = STATE_INIT;
137       if (dest_) {
138         dest_->state_ = STATE_INIT;
139         dest_->dest_ = nullptr;
140         dest_ = nullptr;
141       }
142     }
143   }
144 
SetWritable(bool writable)145   void SetWritable(bool writable) { set_writable(writable); }
146 
SetDestination(FakeTransportChannel * dest)147   void SetDestination(FakeTransportChannel* dest) {
148     if (state_ == STATE_CONNECTING && dest) {
149       // This simulates the delivery of candidates.
150       dest_ = dest;
151       dest_->dest_ = this;
152       if (local_cert_ && dest_->local_cert_) {
153         do_dtls_ = true;
154         dest_->do_dtls_ = true;
155         NegotiateSrtpCiphers();
156       }
157       state_ = STATE_CONNECTED;
158       dest_->state_ = STATE_CONNECTED;
159       set_writable(true);
160       dest_->set_writable(true);
161     } else if (state_ == STATE_CONNECTED && !dest) {
162       // Simulates loss of connectivity, by asymmetrically forgetting dest_.
163       dest_ = nullptr;
164       state_ = STATE_CONNECTING;
165       set_writable(false);
166     }
167   }
168 
SetConnectionCount(size_t connection_count)169   void SetConnectionCount(size_t connection_count) {
170     size_t old_connection_count = connection_count_;
171     connection_count_ = connection_count;
172     if (connection_count)
173       had_connection_ = true;
174     if (connection_count_ < old_connection_count)
175       SignalConnectionRemoved(this);
176   }
177 
SetCandidatesGatheringComplete()178   void SetCandidatesGatheringComplete() {
179     if (gathering_state_ != kIceGatheringComplete) {
180       gathering_state_ = kIceGatheringComplete;
181       SignalGatheringState(this);
182     }
183   }
184 
SetReceiving(bool receiving)185   void SetReceiving(bool receiving) { set_receiving(receiving); }
186 
SetIceConfig(const IceConfig & config)187   void SetIceConfig(const IceConfig& config) override {
188     receiving_timeout_ = config.receiving_timeout_ms;
189     gather_continually_ = config.gather_continually;
190   }
191 
receiving_timeout()192   int receiving_timeout() const { return receiving_timeout_; }
gather_continually()193   bool gather_continually() const { return gather_continually_; }
194 
SendPacket(const char * data,size_t len,const rtc::PacketOptions & options,int flags)195   int SendPacket(const char* data,
196                  size_t len,
197                  const rtc::PacketOptions& options,
198                  int flags) override {
199     if (state_ != STATE_CONNECTED) {
200       return -1;
201     }
202 
203     if (flags != PF_SRTP_BYPASS && flags != 0) {
204       return -1;
205     }
206 
207     PacketMessageData* packet = new PacketMessageData(data, len);
208     if (async_) {
209       rtc::Thread::Current()->Post(this, 0, packet);
210     } else {
211       rtc::Thread::Current()->Send(this, 0, packet);
212     }
213     rtc::SentPacket sent_packet(options.packet_id, rtc::Time());
214     SignalSentPacket(this, sent_packet);
215     return static_cast<int>(len);
216   }
SetOption(rtc::Socket::Option opt,int value)217   int SetOption(rtc::Socket::Option opt, int value) override { return true; }
GetOption(rtc::Socket::Option opt,int * value)218   bool GetOption(rtc::Socket::Option opt, int* value) override { return true; }
GetError()219   int GetError() override { return 0; }
220 
AddRemoteCandidate(const Candidate & candidate)221   void AddRemoteCandidate(const Candidate& candidate) override {
222     remote_candidates_.push_back(candidate);
223   }
remote_candidates()224   const Candidates& remote_candidates() const { return remote_candidates_; }
225 
OnMessage(rtc::Message * msg)226   void OnMessage(rtc::Message* msg) override {
227     PacketMessageData* data = static_cast<PacketMessageData*>(msg->pdata);
228     dest_->SignalReadPacket(dest_, data->packet.data<char>(),
229                             data->packet.size(), rtc::CreatePacketTime(0), 0);
230     delete data;
231   }
232 
SetLocalCertificate(const rtc::scoped_refptr<rtc::RTCCertificate> & certificate)233   bool SetLocalCertificate(
234       const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
235     local_cert_ = certificate;
236     return true;
237   }
238 
SetRemoteSSLCertificate(rtc::FakeSSLCertificate * cert)239   void SetRemoteSSLCertificate(rtc::FakeSSLCertificate* cert) {
240     remote_cert_ = cert;
241   }
242 
IsDtlsActive()243   bool IsDtlsActive() const override { return do_dtls_; }
244 
SetSrtpCryptoSuites(const std::vector<int> & ciphers)245   bool SetSrtpCryptoSuites(const std::vector<int>& ciphers) override {
246     srtp_ciphers_ = ciphers;
247     return true;
248   }
249 
GetSrtpCryptoSuite(int * crypto_suite)250   bool GetSrtpCryptoSuite(int* crypto_suite) override {
251     if (chosen_crypto_suite_ != rtc::SRTP_INVALID_CRYPTO_SUITE) {
252       *crypto_suite = chosen_crypto_suite_;
253       return true;
254     }
255     return false;
256   }
257 
GetSslCipherSuite(int * cipher_suite)258   bool GetSslCipherSuite(int* cipher_suite) override { return false; }
259 
GetLocalCertificate()260   rtc::scoped_refptr<rtc::RTCCertificate> GetLocalCertificate() const {
261     return local_cert_;
262   }
263 
GetRemoteSSLCertificate(rtc::SSLCertificate ** cert)264   bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const override {
265     if (!remote_cert_)
266       return false;
267 
268     *cert = remote_cert_->GetReference();
269     return true;
270   }
271 
ExportKeyingMaterial(const std::string & label,const uint8_t * context,size_t context_len,bool use_context,uint8_t * result,size_t result_len)272   bool ExportKeyingMaterial(const std::string& label,
273                             const uint8_t* context,
274                             size_t context_len,
275                             bool use_context,
276                             uint8_t* result,
277                             size_t result_len) override {
278     if (chosen_crypto_suite_ != rtc::SRTP_INVALID_CRYPTO_SUITE) {
279       memset(result, 0xff, result_len);
280       return true;
281     }
282 
283     return false;
284   }
285 
NegotiateSrtpCiphers()286   void NegotiateSrtpCiphers() {
287     for (std::vector<int>::const_iterator it1 = srtp_ciphers_.begin();
288          it1 != srtp_ciphers_.end(); ++it1) {
289       for (std::vector<int>::const_iterator it2 = dest_->srtp_ciphers_.begin();
290            it2 != dest_->srtp_ciphers_.end(); ++it2) {
291         if (*it1 == *it2) {
292           chosen_crypto_suite_ = *it1;
293           dest_->chosen_crypto_suite_ = *it2;
294           return;
295         }
296       }
297     }
298   }
299 
GetStats(ConnectionInfos * infos)300   bool GetStats(ConnectionInfos* infos) override {
301     ConnectionInfo info;
302     infos->clear();
303     infos->push_back(info);
304     return true;
305   }
306 
set_ssl_max_protocol_version(rtc::SSLProtocolVersion version)307   void set_ssl_max_protocol_version(rtc::SSLProtocolVersion version) {
308     ssl_max_version_ = version;
309   }
ssl_max_protocol_version()310   rtc::SSLProtocolVersion ssl_max_protocol_version() const {
311     return ssl_max_version_;
312   }
313 
314  private:
315   enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED };
316   Transport* transport_;
317   FakeTransportChannel* dest_ = nullptr;
318   State state_ = STATE_INIT;
319   bool async_ = false;
320   Candidates remote_candidates_;
321   rtc::scoped_refptr<rtc::RTCCertificate> local_cert_;
322   rtc::FakeSSLCertificate* remote_cert_ = nullptr;
323   bool do_dtls_ = false;
324   std::vector<int> srtp_ciphers_;
325   int chosen_crypto_suite_ = rtc::SRTP_INVALID_CRYPTO_SUITE;
326   int receiving_timeout_ = -1;
327   bool gather_continually_ = false;
328   IceRole role_ = ICEROLE_UNKNOWN;
329   uint64_t tiebreaker_ = 0;
330   std::string ice_ufrag_;
331   std::string ice_pwd_;
332   std::string remote_ice_ufrag_;
333   std::string remote_ice_pwd_;
334   IceMode remote_ice_mode_ = ICEMODE_FULL;
335   rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12;
336   rtc::SSLFingerprint dtls_fingerprint_;
337   rtc::SSLRole ssl_role_ = rtc::SSL_CLIENT;
338   size_t connection_count_ = 0;
339   IceGatheringState gathering_state_ = kIceGatheringNew;
340   bool had_connection_ = false;
341 };
342 
343 // Fake transport class, which can be passed to anything that needs a Transport.
344 // Can be informed of another FakeTransport via SetDestination (low-tech way
345 // of doing candidates)
346 class FakeTransport : public Transport {
347  public:
348   typedef std::map<int, FakeTransportChannel*> ChannelMap;
349 
FakeTransport(const std::string & name)350   explicit FakeTransport(const std::string& name) : Transport(name, nullptr) {}
351 
352   // Note that we only have a constructor with the allocator parameter so it can
353   // be wrapped by a DtlsTransport.
FakeTransport(const std::string & name,PortAllocator * allocator)354   FakeTransport(const std::string& name, PortAllocator* allocator)
355       : Transport(name, nullptr) {}
356 
~FakeTransport()357   ~FakeTransport() { DestroyAllChannels(); }
358 
channels()359   const ChannelMap& channels() const { return channels_; }
360 
361   // If async, will send packets by "Post"-ing to message queue instead of
362   // synchronously "Send"-ing.
SetAsync(bool async)363   void SetAsync(bool async) { async_ = async; }
SetDestination(FakeTransport * dest)364   void SetDestination(FakeTransport* dest) {
365     dest_ = dest;
366     for (const auto& kv : channels_) {
367       kv.second->SetLocalCertificate(certificate_);
368       SetChannelDestination(kv.first, kv.second);
369     }
370   }
371 
SetWritable(bool writable)372   void SetWritable(bool writable) {
373     for (const auto& kv : channels_) {
374       kv.second->SetWritable(writable);
375     }
376   }
377 
SetLocalCertificate(const rtc::scoped_refptr<rtc::RTCCertificate> & certificate)378   void SetLocalCertificate(
379       const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override {
380     certificate_ = certificate;
381   }
GetLocalCertificate(rtc::scoped_refptr<rtc::RTCCertificate> * certificate)382   bool GetLocalCertificate(
383       rtc::scoped_refptr<rtc::RTCCertificate>* certificate) override {
384     if (!certificate_)
385       return false;
386 
387     *certificate = certificate_;
388     return true;
389   }
390 
GetSslRole(rtc::SSLRole * role)391   bool GetSslRole(rtc::SSLRole* role) const override {
392     if (channels_.empty()) {
393       return false;
394     }
395     return channels_.begin()->second->GetSslRole(role);
396   }
397 
SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version)398   bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) override {
399     ssl_max_version_ = version;
400     for (const auto& kv : channels_) {
401       kv.second->set_ssl_max_protocol_version(ssl_max_version_);
402     }
403     return true;
404   }
ssl_max_protocol_version()405   rtc::SSLProtocolVersion ssl_max_protocol_version() const {
406     return ssl_max_version_;
407   }
408 
409   using Transport::local_description;
410   using Transport::remote_description;
411 
412  protected:
CreateTransportChannel(int component)413   TransportChannelImpl* CreateTransportChannel(int component) override {
414     if (channels_.find(component) != channels_.end()) {
415       return nullptr;
416     }
417     FakeTransportChannel* channel =
418         new FakeTransportChannel(this, name(), component);
419     channel->set_ssl_max_protocol_version(ssl_max_version_);
420     channel->SetAsync(async_);
421     SetChannelDestination(component, channel);
422     channels_[component] = channel;
423     return channel;
424   }
425 
DestroyTransportChannel(TransportChannelImpl * channel)426   void DestroyTransportChannel(TransportChannelImpl* channel) override {
427     channels_.erase(channel->component());
428     delete channel;
429   }
430 
431  private:
GetFakeChannel(int component)432   FakeTransportChannel* GetFakeChannel(int component) {
433     auto it = channels_.find(component);
434     return (it != channels_.end()) ? it->second : nullptr;
435   }
436 
SetChannelDestination(int component,FakeTransportChannel * channel)437   void SetChannelDestination(int component, FakeTransportChannel* channel) {
438     FakeTransportChannel* dest_channel = nullptr;
439     if (dest_) {
440       dest_channel = dest_->GetFakeChannel(component);
441       if (dest_channel) {
442         dest_channel->SetLocalCertificate(dest_->certificate_);
443       }
444     }
445     channel->SetDestination(dest_channel);
446   }
447 
448   // Note, this is distinct from the Channel map owned by Transport.
449   // This map just tracks the FakeTransportChannels created by this class.
450   // It's mainly needed so that we can access a FakeTransportChannel directly,
451   // even if wrapped by a DtlsTransportChannelWrapper.
452   ChannelMap channels_;
453   FakeTransport* dest_ = nullptr;
454   bool async_ = false;
455   rtc::scoped_refptr<rtc::RTCCertificate> certificate_;
456   rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12;
457 };
458 
459 // Fake TransportController class, which can be passed into a BaseChannel object
460 // for test purposes. Can be connected to other FakeTransportControllers via
461 // Connect().
462 //
463 // This fake is unusual in that for the most part, it's implemented with the
464 // real TransportController code, but with fake TransportChannels underneath.
465 class FakeTransportController : public TransportController {
466  public:
FakeTransportController()467   FakeTransportController()
468       : TransportController(rtc::Thread::Current(),
469                             rtc::Thread::Current(),
470                             nullptr),
471         fail_create_channel_(false) {}
472 
FakeTransportController(IceRole role)473   explicit FakeTransportController(IceRole role)
474       : TransportController(rtc::Thread::Current(),
475                             rtc::Thread::Current(),
476                             nullptr),
477         fail_create_channel_(false) {
478     SetIceRole(role);
479   }
480 
FakeTransportController(rtc::Thread * worker_thread)481   explicit FakeTransportController(rtc::Thread* worker_thread)
482       : TransportController(rtc::Thread::Current(), worker_thread, nullptr),
483         fail_create_channel_(false) {}
484 
FakeTransportController(rtc::Thread * worker_thread,IceRole role)485   FakeTransportController(rtc::Thread* worker_thread, IceRole role)
486       : TransportController(rtc::Thread::Current(), worker_thread, nullptr),
487         fail_create_channel_(false) {
488     SetIceRole(role);
489   }
490 
GetTransport_w(const std::string & transport_name)491   FakeTransport* GetTransport_w(const std::string& transport_name) {
492     return static_cast<FakeTransport*>(
493         TransportController::GetTransport_w(transport_name));
494   }
495 
Connect(FakeTransportController * dest)496   void Connect(FakeTransportController* dest) {
497     worker_thread()->Invoke<void>(
498         rtc::Bind(&FakeTransportController::Connect_w, this, dest));
499   }
500 
CreateTransportChannel_w(const std::string & transport_name,int component)501   TransportChannel* CreateTransportChannel_w(const std::string& transport_name,
502                                              int component) override {
503     if (fail_create_channel_) {
504       return nullptr;
505     }
506     return TransportController::CreateTransportChannel_w(transport_name,
507                                                          component);
508   }
509 
set_fail_channel_creation(bool fail_channel_creation)510   void set_fail_channel_creation(bool fail_channel_creation) {
511     fail_create_channel_ = fail_channel_creation;
512   }
513 
514  protected:
CreateTransport_w(const std::string & transport_name)515   Transport* CreateTransport_w(const std::string& transport_name) override {
516     return new FakeTransport(transport_name);
517   }
518 
Connect_w(FakeTransportController * dest)519   void Connect_w(FakeTransportController* dest) {
520     // Simulate the exchange of candidates.
521     ConnectChannels_w();
522     dest->ConnectChannels_w();
523     for (auto& kv : transports()) {
524       FakeTransport* transport = static_cast<FakeTransport*>(kv.second);
525       transport->SetDestination(dest->GetTransport_w(kv.first));
526     }
527   }
528 
ConnectChannels_w()529   void ConnectChannels_w() {
530     for (auto& kv : transports()) {
531       FakeTransport* transport = static_cast<FakeTransport*>(kv.second);
532       transport->ConnectChannels();
533       transport->MaybeStartGathering();
534     }
535   }
536 
537  private:
538   bool fail_create_channel_;
539 };
540 
541 }  // namespace cricket
542 
543 #endif  // WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_
544