1 /*
2  *  Copyright 2014 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <string>
12 
13 #include "webrtc/base/gunit.h"
14 #include "webrtc/base/ipaddress.h"
15 #include "webrtc/base/socketstream.h"
16 #include "webrtc/base/ssladapter.h"
17 #include "webrtc/base/sslstreamadapter.h"
18 #include "webrtc/base/sslidentity.h"
19 #include "webrtc/base/stream.h"
20 #include "webrtc/base/virtualsocketserver.h"
21 
// Default timeout handed to the EXPECT_EQ_WAIT checks in this file
// (presumably milliseconds, per gunit.h's WAIT convention — confirm).
static const int kTimeout = 5 * 1000;
23 
// Creates an asynchronous socket on the current thread's socket server and
// binds it to the IPv4 wildcard address with port 0. DTLS gets a datagram
// socket; every other mode gets a stream socket.
// NOTE(review): neither a NULL result from CreateAsyncSocket() nor the return
// value of Bind() is checked here.
static rtc::AsyncSocket* CreateSocket(const rtc::SSLMode& ssl_mode) {
  const int socket_type =
      (ssl_mode == rtc::SSL_MODE_DTLS) ? SOCK_DGRAM : SOCK_STREAM;
  const rtc::SocketAddress wildcard(rtc::IPAddress(INADDR_ANY), 0);

  rtc::AsyncSocket* created = rtc::Thread::Current()->socketserver()
      ->CreateAsyncSocket(wildcard.family(), socket_type);
  created->Bind(wildcard);
  return created;
}
35 
GetSSLProtocolName(const rtc::SSLMode & ssl_mode)36 static std::string GetSSLProtocolName(const rtc::SSLMode& ssl_mode) {
37   return (ssl_mode == rtc::SSL_MODE_DTLS) ? "DTLS" : "TLS";
38 }
39 
40 class SSLAdapterTestDummyClient : public sigslot::has_slots<> {
41  public:
SSLAdapterTestDummyClient(const rtc::SSLMode & ssl_mode)42   explicit SSLAdapterTestDummyClient(const rtc::SSLMode& ssl_mode)
43       : ssl_mode_(ssl_mode) {
44     rtc::AsyncSocket* socket = CreateSocket(ssl_mode_);
45 
46     ssl_adapter_.reset(rtc::SSLAdapter::Create(socket));
47 
48     ssl_adapter_->SetMode(ssl_mode_);
49 
50     // Ignore any certificate errors for the purpose of testing.
51     // Note: We do this only because we don't have a real certificate.
52     // NEVER USE THIS IN PRODUCTION CODE!
53     ssl_adapter_->set_ignore_bad_cert(true);
54 
55     ssl_adapter_->SignalReadEvent.connect(this,
56         &SSLAdapterTestDummyClient::OnSSLAdapterReadEvent);
57     ssl_adapter_->SignalCloseEvent.connect(this,
58         &SSLAdapterTestDummyClient::OnSSLAdapterCloseEvent);
59   }
60 
GetAddress() const61   rtc::SocketAddress GetAddress() const {
62     return ssl_adapter_->GetLocalAddress();
63   }
64 
GetState() const65   rtc::AsyncSocket::ConnState GetState() const {
66     return ssl_adapter_->GetState();
67   }
68 
GetReceivedData() const69   const std::string& GetReceivedData() const {
70     return data_;
71   }
72 
Connect(const std::string & hostname,const rtc::SocketAddress & address)73   int Connect(const std::string& hostname, const rtc::SocketAddress& address) {
74     LOG(LS_INFO) << "Initiating connection with " << address;
75 
76     int rv = ssl_adapter_->Connect(address);
77 
78     if (rv == 0) {
79       LOG(LS_INFO) << "Starting " << GetSSLProtocolName(ssl_mode_)
80           << " handshake with " << hostname;
81 
82       if (ssl_adapter_->StartSSL(hostname.c_str(), false) != 0) {
83         return -1;
84       }
85     }
86 
87     return rv;
88   }
89 
Close()90   int Close() {
91     return ssl_adapter_->Close();
92   }
93 
Send(const std::string & message)94   int Send(const std::string& message) {
95     LOG(LS_INFO) << "Client sending '" << message << "'";
96 
97     return ssl_adapter_->Send(message.data(), message.length());
98   }
99 
OnSSLAdapterReadEvent(rtc::AsyncSocket * socket)100   void OnSSLAdapterReadEvent(rtc::AsyncSocket* socket) {
101     char buffer[4096] = "";
102 
103     // Read data received from the server and store it in our internal buffer.
104     int read = socket->Recv(buffer, sizeof(buffer) - 1);
105     if (read != -1) {
106       buffer[read] = '\0';
107 
108       LOG(LS_INFO) << "Client received '" << buffer << "'";
109 
110       data_ += buffer;
111     }
112   }
113 
OnSSLAdapterCloseEvent(rtc::AsyncSocket * socket,int error)114   void OnSSLAdapterCloseEvent(rtc::AsyncSocket* socket, int error) {
115     // OpenSSLAdapter signals handshake failure with a close event, but without
116     // closing the socket! Let's close the socket here. This way GetState() can
117     // return CS_CLOSED after failure.
118     if (socket->GetState() != rtc::AsyncSocket::CS_CLOSED) {
119       socket->Close();
120     }
121   }
122 
123  private:
124   const rtc::SSLMode ssl_mode_;
125 
126   rtc::scoped_ptr<rtc::SSLAdapter> ssl_adapter_;
127 
128   std::string data_;
129 };
130 
131 class SSLAdapterTestDummyServer : public sigslot::has_slots<> {
132  public:
SSLAdapterTestDummyServer(const rtc::SSLMode & ssl_mode,const rtc::KeyParams & key_params)133   explicit SSLAdapterTestDummyServer(const rtc::SSLMode& ssl_mode,
134                                      const rtc::KeyParams& key_params)
135       : ssl_mode_(ssl_mode) {
136     // Generate a key pair and a certificate for this host.
137     ssl_identity_.reset(rtc::SSLIdentity::Generate(GetHostname(), key_params));
138 
139     server_socket_.reset(CreateSocket(ssl_mode_));
140 
141     if (ssl_mode_ == rtc::SSL_MODE_TLS) {
142       server_socket_->SignalReadEvent.connect(this,
143           &SSLAdapterTestDummyServer::OnServerSocketReadEvent);
144 
145       server_socket_->Listen(1);
146     }
147 
148     LOG(LS_INFO) << ((ssl_mode_ == rtc::SSL_MODE_DTLS) ? "UDP" : "TCP")
149         << " server listening on " << server_socket_->GetLocalAddress();
150   }
151 
GetAddress() const152   rtc::SocketAddress GetAddress() const {
153     return server_socket_->GetLocalAddress();
154   }
155 
GetHostname() const156   std::string GetHostname() const {
157     // Since we don't have a real certificate anyway, the value here doesn't
158     // really matter.
159     return "example.com";
160   }
161 
GetReceivedData() const162   const std::string& GetReceivedData() const {
163     return data_;
164   }
165 
Send(const std::string & message)166   int Send(const std::string& message) {
167     if (ssl_stream_adapter_ == NULL
168         || ssl_stream_adapter_->GetState() != rtc::SS_OPEN) {
169       // No connection yet.
170       return -1;
171     }
172 
173     LOG(LS_INFO) << "Server sending '" << message << "'";
174 
175     size_t written;
176     int error;
177 
178     rtc::StreamResult r = ssl_stream_adapter_->Write(message.data(),
179         message.length(), &written, &error);
180     if (r == rtc::SR_SUCCESS) {
181       return written;
182     } else {
183       return -1;
184     }
185   }
186 
AcceptConnection(const rtc::SocketAddress & address)187   void AcceptConnection(const rtc::SocketAddress& address) {
188     // Only a single connection is supported.
189     ASSERT_TRUE(ssl_stream_adapter_ == NULL);
190 
191     // This is only for DTLS.
192     ASSERT_EQ(rtc::SSL_MODE_DTLS, ssl_mode_);
193 
194     // Transfer ownership of the socket to the SSLStreamAdapter object.
195     rtc::AsyncSocket* socket = server_socket_.release();
196 
197     socket->Connect(address);
198 
199     DoHandshake(socket);
200   }
201 
OnServerSocketReadEvent(rtc::AsyncSocket * socket)202   void OnServerSocketReadEvent(rtc::AsyncSocket* socket) {
203     // Only a single connection is supported.
204     ASSERT_TRUE(ssl_stream_adapter_ == NULL);
205 
206     DoHandshake(server_socket_->Accept(NULL));
207   }
208 
OnSSLStreamAdapterEvent(rtc::StreamInterface * stream,int sig,int err)209   void OnSSLStreamAdapterEvent(rtc::StreamInterface* stream, int sig, int err) {
210     if (sig & rtc::SE_READ) {
211       char buffer[4096] = "";
212 
213       size_t read;
214       int error;
215 
216       // Read data received from the client and store it in our internal
217       // buffer.
218       rtc::StreamResult r = stream->Read(buffer,
219           sizeof(buffer) - 1, &read, &error);
220       if (r == rtc::SR_SUCCESS) {
221         buffer[read] = '\0';
222 
223         LOG(LS_INFO) << "Server received '" << buffer << "'";
224 
225         data_ += buffer;
226       }
227     }
228   }
229 
230  private:
DoHandshake(rtc::AsyncSocket * socket)231   void DoHandshake(rtc::AsyncSocket* socket) {
232     rtc::SocketStream* stream = new rtc::SocketStream(socket);
233 
234     ssl_stream_adapter_.reset(rtc::SSLStreamAdapter::Create(stream));
235 
236     ssl_stream_adapter_->SetMode(ssl_mode_);
237     ssl_stream_adapter_->SetServerRole();
238 
239     // SSLStreamAdapter is normally used for peer-to-peer communication, but
240     // here we're testing communication between a client and a server
241     // (e.g. a WebRTC-based application and an RFC 5766 TURN server), where
242     // clients are not required to provide a certificate during handshake.
243     // Accordingly, we must disable client authentication here.
244     ssl_stream_adapter_->set_client_auth_enabled(false);
245 
246     ssl_stream_adapter_->SetIdentity(ssl_identity_->GetReference());
247 
248     // Set a bogus peer certificate digest.
249     unsigned char digest[20];
250     size_t digest_len = sizeof(digest);
251     ssl_stream_adapter_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest,
252         digest_len);
253 
254     ssl_stream_adapter_->StartSSLWithPeer();
255 
256     ssl_stream_adapter_->SignalEvent.connect(this,
257         &SSLAdapterTestDummyServer::OnSSLStreamAdapterEvent);
258   }
259 
260   const rtc::SSLMode ssl_mode_;
261 
262   rtc::scoped_ptr<rtc::AsyncSocket> server_socket_;
263   rtc::scoped_ptr<rtc::SSLStreamAdapter> ssl_stream_adapter_;
264 
265   rtc::scoped_ptr<rtc::SSLIdentity> ssl_identity_;
266 
267   std::string data_;
268 };
269 
270 class SSLAdapterTestBase : public testing::Test,
271                            public sigslot::has_slots<> {
272  public:
SSLAdapterTestBase(const rtc::SSLMode & ssl_mode,const rtc::KeyParams & key_params)273   explicit SSLAdapterTestBase(const rtc::SSLMode& ssl_mode,
274                               const rtc::KeyParams& key_params)
275       : ssl_mode_(ssl_mode),
276         ss_scope_(new rtc::VirtualSocketServer(NULL)),
277         server_(new SSLAdapterTestDummyServer(ssl_mode_, key_params)),
278         client_(new SSLAdapterTestDummyClient(ssl_mode_)),
279         handshake_wait_(kTimeout) {}
280 
SetHandshakeWait(int wait)281   void SetHandshakeWait(int wait) {
282     handshake_wait_ = wait;
283   }
284 
TestHandshake(bool expect_success)285   void TestHandshake(bool expect_success) {
286     int rv;
287 
288     // The initial state is CS_CLOSED
289     ASSERT_EQ(rtc::AsyncSocket::CS_CLOSED, client_->GetState());
290 
291     rv = client_->Connect(server_->GetHostname(), server_->GetAddress());
292     ASSERT_EQ(0, rv);
293 
294     // Now the state should be CS_CONNECTING
295     ASSERT_EQ(rtc::AsyncSocket::CS_CONNECTING, client_->GetState());
296 
297     if (ssl_mode_ == rtc::SSL_MODE_DTLS) {
298       // For DTLS, call AcceptConnection() with the client's address.
299       server_->AcceptConnection(client_->GetAddress());
300     }
301 
302     if (expect_success) {
303       // If expecting success, the client should end up in the CS_CONNECTED
304       // state after handshake.
305       EXPECT_EQ_WAIT(rtc::AsyncSocket::CS_CONNECTED, client_->GetState(),
306           handshake_wait_);
307 
308       LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) << " handshake complete.";
309 
310     } else {
311       // On handshake failure the client should end up in the CS_CLOSED state.
312       EXPECT_EQ_WAIT(rtc::AsyncSocket::CS_CLOSED, client_->GetState(),
313           handshake_wait_);
314 
315       LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) << " handshake failed.";
316     }
317   }
318 
TestTransfer(const std::string & message)319   void TestTransfer(const std::string& message) {
320     int rv;
321 
322     rv = client_->Send(message);
323     ASSERT_EQ(static_cast<int>(message.length()), rv);
324 
325     // The server should have received the client's message.
326     EXPECT_EQ_WAIT(message, server_->GetReceivedData(), kTimeout);
327 
328     rv = server_->Send(message);
329     ASSERT_EQ(static_cast<int>(message.length()), rv);
330 
331     // The client should have received the server's message.
332     EXPECT_EQ_WAIT(message, client_->GetReceivedData(), kTimeout);
333 
334     LOG(LS_INFO) << "Transfer complete.";
335   }
336 
337  private:
338   const rtc::SSLMode ssl_mode_;
339 
340   const rtc::SocketServerScope ss_scope_;
341 
342   rtc::scoped_ptr<SSLAdapterTestDummyServer> server_;
343   rtc::scoped_ptr<SSLAdapterTestDummyClient> client_;
344 
345   int handshake_wait_;
346 };
347 
348 class SSLAdapterTestTLS_RSA : public SSLAdapterTestBase {
349  public:
SSLAdapterTestTLS_RSA()350   SSLAdapterTestTLS_RSA()
351       : SSLAdapterTestBase(rtc::SSL_MODE_TLS, rtc::KeyParams::RSA()) {}
352 };
353 
354 class SSLAdapterTestTLS_ECDSA : public SSLAdapterTestBase {
355  public:
SSLAdapterTestTLS_ECDSA()356   SSLAdapterTestTLS_ECDSA()
357       : SSLAdapterTestBase(rtc::SSL_MODE_TLS, rtc::KeyParams::ECDSA()) {}
358 };
359 
360 class SSLAdapterTestDTLS_RSA : public SSLAdapterTestBase {
361  public:
SSLAdapterTestDTLS_RSA()362   SSLAdapterTestDTLS_RSA()
363       : SSLAdapterTestBase(rtc::SSL_MODE_DTLS, rtc::KeyParams::RSA()) {}
364 };
365 
366 class SSLAdapterTestDTLS_ECDSA : public SSLAdapterTestBase {
367  public:
SSLAdapterTestDTLS_ECDSA()368   SSLAdapterTestDTLS_ECDSA()
369       : SSLAdapterTestBase(rtc::SSL_MODE_DTLS, rtc::KeyParams::ECDSA()) {}
370 };
371 
#if SSL_USE_OPENSSL

// Payload sent client->server and then server->client by the transfer tests.
static const char kTransferMessage[] = "Hello, world!";

// ---- TLS ----

// The TLS handshake completes with an RSA server identity.
TEST_F(SSLAdapterTestTLS_RSA, TestTLSConnect) {
  TestHandshake(true);
}

// The TLS handshake completes with an ECDSA server identity.
TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSConnect) {
  TestHandshake(true);
}

// Data round-trips over TLS with an RSA server identity.
TEST_F(SSLAdapterTestTLS_RSA, TestTLSTransfer) {
  TestHandshake(true);
  TestTransfer(kTransferMessage);
}

// Data round-trips over TLS with an ECDSA server identity.
TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSTransfer) {
  TestHandshake(true);
  TestTransfer(kTransferMessage);
}

// ---- DTLS ----

// The DTLS handshake completes with an RSA server identity.
TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSConnect) {
  TestHandshake(true);
}

// The DTLS handshake completes with an ECDSA server identity.
TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSConnect) {
  TestHandshake(true);
}

// Data round-trips over DTLS with an RSA server identity.
TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSTransfer) {
  TestHandshake(true);
  TestTransfer(kTransferMessage);
}

// Data round-trips over DTLS with an ECDSA server identity.
TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSTransfer) {
  TestHandshake(true);
  TestTransfer(kTransferMessage);
}

#endif  // SSL_USE_OPENSSL
423