1 /*
2  *  Copyright 2015 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef WEBRTC_P2P_BASE_TRANSPORTCONTROLLER_H_
12 #define WEBRTC_P2P_BASE_TRANSPORTCONTROLLER_H_
13 
14 #include <map>
15 #include <string>
16 #include <vector>
17 
18 #include "webrtc/base/sigslot.h"
19 #include "webrtc/base/sslstreamadapter.h"
20 #include "webrtc/p2p/base/candidate.h"
21 #include "webrtc/p2p/base/transport.h"
22 
23 namespace rtc {
24 class Thread;
25 }
26 
27 namespace cricket {
28 
29 class TransportController : public sigslot::has_slots<>,
30                             public rtc::MessageHandler {
31  public:
32   TransportController(rtc::Thread* signaling_thread,
33                       rtc::Thread* worker_thread,
34                       PortAllocator* port_allocator);
35 
36   virtual ~TransportController();
37 
signaling_thread()38   rtc::Thread* signaling_thread() const { return signaling_thread_; }
worker_thread()39   rtc::Thread* worker_thread() const { return worker_thread_; }
40 
port_allocator()41   PortAllocator* port_allocator() const { return port_allocator_; }
42 
43   // Can only be set before transports are created.
44   // TODO(deadbeef): Make this an argument to the constructor once BaseSession
45   // and WebRtcSession are combined
46   bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version);
47 
48   void SetIceConfig(const IceConfig& config);
49   void SetIceRole(IceRole ice_role);
50 
51   bool GetSslRole(const std::string& transport_name, rtc::SSLRole* role);
52 
53   // Specifies the identity to use in this session.
54   // Can only be called once.
55   bool SetLocalCertificate(
56       const rtc::scoped_refptr<rtc::RTCCertificate>& certificate);
57   bool GetLocalCertificate(
58       const std::string& transport_name,
59       rtc::scoped_refptr<rtc::RTCCertificate>* certificate);
60   // Caller owns returned certificate
61   bool GetRemoteSSLCertificate(const std::string& transport_name,
62                                rtc::SSLCertificate** cert);
63   bool SetLocalTransportDescription(const std::string& transport_name,
64                                     const TransportDescription& tdesc,
65                                     ContentAction action,
66                                     std::string* err);
67   bool SetRemoteTransportDescription(const std::string& transport_name,
68                                      const TransportDescription& tdesc,
69                                      ContentAction action,
70                                      std::string* err);
71   // Start gathering candidates for any new transports, or transports doing an
72   // ICE restart.
73   void MaybeStartGathering();
74   bool AddRemoteCandidates(const std::string& transport_name,
75                            const Candidates& candidates,
76                            std::string* err);
77   bool ReadyForRemoteCandidates(const std::string& transport_name);
78   bool GetStats(const std::string& transport_name, TransportStats* stats);
79 
80   // Creates a channel if it doesn't exist. Otherwise, increments a reference
81   // count and returns an existing channel.
82   virtual TransportChannel* CreateTransportChannel_w(
83       const std::string& transport_name,
84       int component);
85 
86   // Decrements a channel's reference count, and destroys the channel if
87   // nothing is referencing it.
88   virtual void DestroyTransportChannel_w(const std::string& transport_name,
89                                          int component);
90 
91   // All of these signals are fired on the signalling thread.
92 
93   // If any transport failed => failed,
94   // Else if all completed => completed,
95   // Else if all connected => connected,
96   // Else => connecting
97   sigslot::signal1<IceConnectionState> SignalConnectionState;
98 
99   // Receiving if any transport is receiving
100   sigslot::signal1<bool> SignalReceiving;
101 
102   // If all transports done gathering => complete,
103   // Else if any are gathering => gathering,
104   // Else => new
105   sigslot::signal1<IceGatheringState> SignalGatheringState;
106 
107   // (transport_name, candidates)
108   sigslot::signal2<const std::string&, const Candidates&>
109       SignalCandidatesGathered;
110 
111   // for unit test
112   const rtc::scoped_refptr<rtc::RTCCertificate>& certificate_for_testing();
113 
114  protected:
115   // Protected and virtual so we can override it in unit tests.
116   virtual Transport* CreateTransport_w(const std::string& transport_name);
117 
118   // For unit tests
transports()119   const std::map<std::string, Transport*>& transports() { return transports_; }
120   Transport* GetTransport_w(const std::string& transport_name);
121 
122  private:
123   void OnMessage(rtc::Message* pmsg) override;
124 
125   // It's the Transport that's currently responsible for creating/destroying
126   // channels, but the TransportController keeps track of how many external
127   // objects (BaseChannels) reference each channel.
128   struct RefCountedChannel {
RefCountedChannelRefCountedChannel129     RefCountedChannel() : impl_(nullptr), ref_(0) {}
RefCountedChannelRefCountedChannel130     explicit RefCountedChannel(TransportChannelImpl* impl)
131         : impl_(impl), ref_(0) {}
132 
AddRefRefCountedChannel133     void AddRef() { ++ref_; }
DecRefRefCountedChannel134     void DecRef() {
135       ASSERT(ref_ > 0);
136       --ref_;
137     }
refRefCountedChannel138     int ref() const { return ref_; }
139 
getRefCountedChannel140     TransportChannelImpl* get() const { return impl_; }
141     TransportChannelImpl* operator->() const { return impl_; }
142 
143    private:
144     TransportChannelImpl* impl_;
145     int ref_;
146   };
147 
148   std::vector<RefCountedChannel>::iterator FindChannel_w(
149       const std::string& transport_name,
150       int component);
151 
152   Transport* GetOrCreateTransport_w(const std::string& transport_name);
153   void DestroyTransport_w(const std::string& transport_name);
154   void DestroyAllTransports_w();
155 
156   bool SetSslMaxProtocolVersion_w(rtc::SSLProtocolVersion version);
157   void SetIceConfig_w(const IceConfig& config);
158   void SetIceRole_w(IceRole ice_role);
159   bool GetSslRole_w(const std::string& transport_name, rtc::SSLRole* role);
160   bool SetLocalCertificate_w(
161       const rtc::scoped_refptr<rtc::RTCCertificate>& certificate);
162   bool GetLocalCertificate_w(
163       const std::string& transport_name,
164       rtc::scoped_refptr<rtc::RTCCertificate>* certificate);
165   bool GetRemoteSSLCertificate_w(const std::string& transport_name,
166                                  rtc::SSLCertificate** cert);
167   bool SetLocalTransportDescription_w(const std::string& transport_name,
168                                       const TransportDescription& tdesc,
169                                       ContentAction action,
170                                       std::string* err);
171   bool SetRemoteTransportDescription_w(const std::string& transport_name,
172                                        const TransportDescription& tdesc,
173                                        ContentAction action,
174                                        std::string* err);
175   void MaybeStartGathering_w();
176   bool AddRemoteCandidates_w(const std::string& transport_name,
177                              const Candidates& candidates,
178                              std::string* err);
179   bool ReadyForRemoteCandidates_w(const std::string& transport_name);
180   bool GetStats_w(const std::string& transport_name, TransportStats* stats);
181 
182   // Handlers for signals from Transport.
183   void OnChannelWritableState_w(TransportChannel* channel);
184   void OnChannelReceivingState_w(TransportChannel* channel);
185   void OnChannelGatheringState_w(TransportChannelImpl* channel);
186   void OnChannelCandidateGathered_w(TransportChannelImpl* channel,
187                                     const Candidate& candidate);
188   void OnChannelRoleConflict_w(TransportChannelImpl* channel);
189   void OnChannelConnectionRemoved_w(TransportChannelImpl* channel);
190 
191   void UpdateAggregateStates_w();
192 
193   rtc::Thread* const signaling_thread_ = nullptr;
194   rtc::Thread* const worker_thread_ = nullptr;
195   typedef std::map<std::string, Transport*> TransportMap;
196   TransportMap transports_;
197 
198   std::vector<RefCountedChannel> channels_;
199 
200   PortAllocator* const port_allocator_ = nullptr;
201   rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12;
202 
203   // Aggregate state for TransportChannelImpls.
204   IceConnectionState connection_state_ = kIceConnectionConnecting;
205   bool receiving_ = false;
206   IceGatheringState gathering_state_ = kIceGatheringNew;
207 
208   // TODO(deadbeef): Move the fields below down to the transports themselves
209   IceConfig ice_config_;
210   IceRole ice_role_ = ICEROLE_CONTROLLING;
211   // Flag which will be set to true after the first role switch
212   bool ice_role_switch_ = false;
213   uint64_t ice_tiebreaker_ = rtc::CreateRandomId64();
214   rtc::scoped_refptr<rtc::RTCCertificate> certificate_;
215 };
216 
217 }  // namespace cricket
218 
219 #endif  // WEBRTC_P2P_BASE_TRANSPORTCONTROLLER_H_
220