1 // Copyright 2015 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <brillo/streams/tls_stream.h>
6 
7 #include <algorithm>
8 #include <limits>
9 #include <string>
10 #include <vector>
11 
12 #include <openssl/err.h>
13 #include <openssl/ssl.h>
14 
15 #include <base/bind.h>
16 #include <base/memory/weak_ptr.h>
17 #include <brillo/message_loops/message_loop.h>
18 #include <brillo/secure_blob.h>
19 #include <brillo/streams/openssl_stream_bio.h>
20 #include <brillo/streams/stream_utils.h>
21 #include <brillo/strings/string_utils.h>
22 
23 namespace {
24 
25 // SSL info callback which is called by OpenSSL when we enable logging level of
26 // at least 3. This logs the information about the internal TLS handshake.
TlsInfoCallback(const SSL *,int where,int ret)27 void TlsInfoCallback(const SSL* /* ssl */, int where, int ret) {
28   std::string reason;
29   std::vector<std::string> info;
30   if (where & SSL_CB_LOOP)
31     info.push_back("loop");
32   if (where & SSL_CB_EXIT)
33     info.push_back("exit");
34   if (where & SSL_CB_READ)
35     info.push_back("read");
36   if (where & SSL_CB_WRITE)
37     info.push_back("write");
38   if (where & SSL_CB_ALERT) {
39     info.push_back("alert");
40     reason = ", reason: ";
41     reason += SSL_alert_type_string_long(ret);
42     reason += "/";
43     reason += SSL_alert_desc_string_long(ret);
44   }
45   if (where & SSL_CB_HANDSHAKE_START)
46     info.push_back("handshake_start");
47   if (where & SSL_CB_HANDSHAKE_DONE)
48     info.push_back("handshake_done");
49 
50   VLOG(3) << "TLS progress info: " << brillo::string_utils::Join(",", info)
51           << ", with status: " << ret << reason;
52 }
53 
// Index of the SSL_CTX ex_data slot in which a TlsStreamImpl pointer is kept
// so that OnCertVerifyResultsStatic() can find the owning instance. Remains
// negative until a slot has been requested from OpenSSL.
int ssl_ctx_private_data_index{-1};
57 
// Directory handed to SSL_CTX_load_verify_locations() as the trusted CA
// certificate store. The location depends on the target platform.
#ifdef __ANDROID__
const char kCACertificatePath[] = "/system/etc/security/cacerts_google";
#else
const char kCACertificatePath[] = "/usr/share/chromeos-ca-certificates";
#endif
65 
66 }  // anonymous namespace
67 
68 namespace brillo {
69 
70 // Helper implementation of TLS stream used to hide most of OpenSSL inner
71 // workings from the users of brillo::TlsStream.
72 class TlsStream::TlsStreamImpl {
73  public:
74   TlsStreamImpl();
75   ~TlsStreamImpl();
76 
77   bool Init(StreamPtr socket,
78             const std::string& host,
79             const base::Closure& success_callback,
80             const Stream::ErrorCallback& error_callback,
81             ErrorPtr* error);
82 
83   bool ReadNonBlocking(void* buffer,
84                        size_t size_to_read,
85                        size_t* size_read,
86                        bool* end_of_stream,
87                        ErrorPtr* error);
88 
89   bool WriteNonBlocking(const void* buffer,
90                         size_t size_to_write,
91                         size_t* size_written,
92                         ErrorPtr* error);
93 
94   bool Flush(ErrorPtr* error);
95   bool Close(ErrorPtr* error);
96   bool WaitForData(AccessMode mode,
97                    const base::Callback<void(AccessMode)>& callback,
98                    ErrorPtr* error);
99   bool WaitForDataBlocking(AccessMode in_mode,
100                            base::TimeDelta timeout,
101                            AccessMode* out_mode,
102                            ErrorPtr* error);
103   void CancelPendingAsyncOperations();
104 
105  private:
106   bool ReportError(ErrorPtr* error,
107                    const tracked_objects::Location& location,
108                    const std::string& message);
109   void DoHandshake(const base::Closure& success_callback,
110                    const Stream::ErrorCallback& error_callback);
111   void RetryHandshake(const base::Closure& success_callback,
112                       const Stream::ErrorCallback& error_callback,
113                       Stream::AccessMode mode);
114 
115   int OnCertVerifyResults(int ok, X509_STORE_CTX* ctx);
116   static int OnCertVerifyResultsStatic(int ok, X509_STORE_CTX* ctx);
117 
118   StreamPtr socket_;
119   std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
120   std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
121   BIO* stream_bio_{nullptr};
122   bool need_more_read_{false};
123   bool need_more_write_{false};
124 
125   base::WeakPtrFactory<TlsStreamImpl> weak_ptr_factory_{this};
126   DISALLOW_COPY_AND_ASSIGN(TlsStreamImpl);
127 };
128 
TlsStreamImpl()129 TlsStream::TlsStreamImpl::TlsStreamImpl() {
130   SSL_load_error_strings();
131   SSL_library_init();
132   if (ssl_ctx_private_data_index < 0) {
133     ssl_ctx_private_data_index =
134         SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
135   }
136 }
137 
~TlsStreamImpl()138 TlsStream::TlsStreamImpl::~TlsStreamImpl() {
139   ssl_.reset();
140   ctx_.reset();
141 }
142 
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)143 bool TlsStream::TlsStreamImpl::ReadNonBlocking(void* buffer,
144                                                size_t size_to_read,
145                                                size_t* size_read,
146                                                bool* end_of_stream,
147                                                ErrorPtr* error) {
148   const size_t max_int = std::numeric_limits<int>::max();
149   int size_int = static_cast<int>(std::min(size_to_read, max_int));
150   int ret = SSL_read(ssl_.get(), buffer, size_int);
151   if (ret > 0) {
152     *size_read = static_cast<size_t>(ret);
153     if (end_of_stream)
154       *end_of_stream = false;
155     return true;
156   }
157 
158   int err = SSL_get_error(ssl_.get(), ret);
159   if (err == SSL_ERROR_ZERO_RETURN) {
160     *size_read = 0;
161     if (end_of_stream)
162       *end_of_stream = true;
163     return true;
164   }
165 
166   if (err == SSL_ERROR_WANT_READ) {
167     need_more_read_ = true;
168   } else if (err == SSL_ERROR_WANT_WRITE) {
169     // Writes might be required for SSL_read() because of possible TLS
170     // re-negotiations which can happen at any time.
171     need_more_write_ = true;
172   } else {
173     return ReportError(error, FROM_HERE, "Error reading from TLS socket");
174   }
175   *size_read = 0;
176   if (end_of_stream)
177     *end_of_stream = false;
178   return true;
179 }
180 
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)181 bool TlsStream::TlsStreamImpl::WriteNonBlocking(const void* buffer,
182                                                 size_t size_to_write,
183                                                 size_t* size_written,
184                                                 ErrorPtr* error) {
185   const size_t max_int = std::numeric_limits<int>::max();
186   int size_int = static_cast<int>(std::min(size_to_write, max_int));
187   int ret = SSL_write(ssl_.get(), buffer, size_int);
188   if (ret > 0) {
189     *size_written = static_cast<size_t>(ret);
190     return true;
191   }
192 
193   int err = SSL_get_error(ssl_.get(), ret);
194   if (err == SSL_ERROR_WANT_READ) {
195     // Reads might be required for SSL_write() because of possible TLS
196     // re-negotiations which can happen at any time.
197     need_more_read_ = true;
198   } else if (err == SSL_ERROR_WANT_WRITE) {
199     need_more_write_ = true;
200   } else {
201     return ReportError(error, FROM_HERE, "Error writing to TLS socket");
202   }
203   *size_written = 0;
204   return true;
205 }
206 
Flush(ErrorPtr * error)207 bool TlsStream::TlsStreamImpl::Flush(ErrorPtr* error) {
208   return socket_->FlushBlocking(error);
209 }
210 
Close(ErrorPtr * error)211 bool TlsStream::TlsStreamImpl::Close(ErrorPtr* error) {
212   // 2 seconds should be plenty here.
213   const base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(2);
214   // The retry count of 4 below is just arbitrary, to ensure we don't get stuck
215   // here forever. We should rarely need to repeat SSL_shutdown anyway.
216   for (int retry_count = 0; retry_count < 4; retry_count++) {
217     int ret = SSL_shutdown(ssl_.get());
218     // We really don't care for bi-directional shutdown here.
219     // Just make sure we only send the "close notify" alert to the remote peer.
220     if (ret >= 0)
221       break;
222 
223     int err = SSL_get_error(ssl_.get(), ret);
224     if (err == SSL_ERROR_WANT_READ) {
225       if (!socket_->WaitForDataBlocking(AccessMode::READ, kTimeout, nullptr,
226                                         error)) {
227         break;
228       }
229     } else if (err == SSL_ERROR_WANT_WRITE) {
230       if (!socket_->WaitForDataBlocking(AccessMode::WRITE, kTimeout, nullptr,
231                                         error)) {
232         break;
233       }
234     } else {
235       LOG(ERROR) << "SSL_shutdown returned error #" << err;
236       ReportError(error, FROM_HERE, "Failed to shut down TLS socket");
237       break;
238     }
239   }
240   return socket_->CloseBlocking(error);
241 }
242 
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)243 bool TlsStream::TlsStreamImpl::WaitForData(
244     AccessMode mode,
245     const base::Callback<void(AccessMode)>& callback,
246     ErrorPtr* error) {
247   bool is_read = stream_utils::IsReadAccessMode(mode);
248   bool is_write = stream_utils::IsWriteAccessMode(mode);
249   is_read |= need_more_read_;
250   is_write |= need_more_write_;
251   need_more_read_ = false;
252   need_more_write_ = false;
253   if (is_read && SSL_pending(ssl_.get()) > 0) {
254     callback.Run(AccessMode::READ);
255     return true;
256   }
257   mode = stream_utils::MakeAccessMode(is_read, is_write);
258   return socket_->WaitForData(mode, callback, error);
259 }
260 
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)261 bool TlsStream::TlsStreamImpl::WaitForDataBlocking(AccessMode in_mode,
262                                                    base::TimeDelta timeout,
263                                                    AccessMode* out_mode,
264                                                    ErrorPtr* error) {
265   bool is_read = stream_utils::IsReadAccessMode(in_mode);
266   bool is_write = stream_utils::IsWriteAccessMode(in_mode);
267   is_read |= need_more_read_;
268   is_write |= need_more_write_;
269   need_more_read_ = need_more_write_ = false;
270   if (is_read && SSL_pending(ssl_.get()) > 0) {
271     if (out_mode)
272       *out_mode = AccessMode::READ;
273     return true;
274   }
275   in_mode = stream_utils::MakeAccessMode(is_read, is_write);
276   return socket_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
277 }
278 
CancelPendingAsyncOperations()279 void TlsStream::TlsStreamImpl::CancelPendingAsyncOperations() {
280   socket_->CancelPendingAsyncOperations();
281   weak_ptr_factory_.InvalidateWeakPtrs();
282 }
283 
ReportError(ErrorPtr * error,const tracked_objects::Location & location,const std::string & message)284 bool TlsStream::TlsStreamImpl::ReportError(
285     ErrorPtr* error,
286     const tracked_objects::Location& location,
287     const std::string& message) {
288   const char* file = nullptr;
289   int line = 0;
290   const char* data = 0;
291   int flags = 0;
292   while (auto errnum = ERR_get_error_line_data(&file, &line, &data, &flags)) {
293     char buf[256];
294     ERR_error_string_n(errnum, buf, sizeof(buf));
295     tracked_objects::Location ssl_location{"Unknown", file, line, nullptr};
296     std::string ssl_message = buf;
297     if (flags & ERR_TXT_STRING) {
298       ssl_message += ": ";
299       ssl_message += data;
300     }
301     Error::AddTo(error, ssl_location, "openssl", std::to_string(errnum),
302                  ssl_message);
303   }
304   Error::AddTo(error, location, "tls_stream", "failed", message);
305   return false;
306 }
307 
OnCertVerifyResults(int ok,X509_STORE_CTX * ctx)308 int TlsStream::TlsStreamImpl::OnCertVerifyResults(int ok, X509_STORE_CTX* ctx) {
309   // OpenSSL already performs a comprehensive check of the certificate chain
310   // (using X509_verify_cert() function) and calls back with the result of its
311   // verification.
312   // |ok| is set to 1 if the verification passed and 0 if an error was detected.
313   // Here we can perform some additional checks if we need to, or simply log
314   // the issues found.
315 
316   // For now, just log an error if it occurred.
317   if (!ok) {
318     LOG(ERROR) << "Server certificate validation failed: "
319                << X509_verify_cert_error_string(X509_STORE_CTX_get_error(ctx));
320   }
321   return ok;
322 }
323 
OnCertVerifyResultsStatic(int ok,X509_STORE_CTX * ctx)324 int TlsStream::TlsStreamImpl::OnCertVerifyResultsStatic(int ok,
325                                                         X509_STORE_CTX* ctx) {
326   // Obtain the pointer to the instance of TlsStream::TlsStreamImpl from the
327   // SSL CTX object referenced by |ctx|.
328   SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(
329       ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
330   SSL_CTX* ssl_ctx = ssl ? SSL_get_SSL_CTX(ssl) : nullptr;
331   TlsStream::TlsStreamImpl* self = nullptr;
332   if (ssl_ctx) {
333     self = static_cast<TlsStream::TlsStreamImpl*>(SSL_CTX_get_ex_data(
334         ssl_ctx, ssl_ctx_private_data_index));
335   }
336   return self ? self->OnCertVerifyResults(ok, ctx) : ok;
337 }
338 
Init(StreamPtr socket,const std::string & host,const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,ErrorPtr * error)339 bool TlsStream::TlsStreamImpl::Init(StreamPtr socket,
340                                     const std::string& host,
341                                     const base::Closure& success_callback,
342                                     const Stream::ErrorCallback& error_callback,
343                                     ErrorPtr* error) {
344   ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
345   if (!ctx_)
346     return ReportError(error, FROM_HERE, "Cannot create SSL_CTX");
347 
348   // Top cipher suites supported by both Google GFEs and OpenSSL (in server
349   // preferred order).
350   int res = SSL_CTX_set_cipher_list(ctx_.get(),
351                                     "ECDHE-ECDSA-AES128-GCM-SHA256:"
352                                     "ECDHE-ECDSA-AES256-GCM-SHA384:"
353                                     "ECDHE-RSA-AES128-GCM-SHA256:"
354                                     "ECDHE-RSA-AES256-GCM-SHA384");
355   if (res != 1)
356     return ReportError(error, FROM_HERE, "Cannot set the cipher list");
357 
358   res = SSL_CTX_load_verify_locations(ctx_.get(), nullptr, kCACertificatePath);
359   if (res != 1) {
360     return ReportError(error, FROM_HERE,
361                        "Failed to specify trusted certificate location");
362   }
363 
364   // Store a pointer to "this" into SSL_CTX instance.
365   SSL_CTX_set_ex_data(ctx_.get(), ssl_ctx_private_data_index, this);
366 
367   // Ask OpenSSL to validate the server host from the certificate to match
368   // the expected host name we are given:
369   X509_VERIFY_PARAM* param = SSL_CTX_get0_param(ctx_.get());
370   X509_VERIFY_PARAM_set1_host(param, host.c_str(), host.size());
371 
372   SSL_CTX_set_verify(ctx_.get(), SSL_VERIFY_PEER,
373                      &TlsStreamImpl::OnCertVerifyResultsStatic);
374 
375   socket_ = std::move(socket);
376   ssl_.reset(SSL_new(ctx_.get()));
377 
378   // Enable TLS progress callback if VLOG level is >=3.
379   if (VLOG_IS_ON(3))
380     SSL_set_info_callback(ssl_.get(), TlsInfoCallback);
381 
382   stream_bio_ = BIO_new_stream(socket_.get());
383   SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_);
384   SSL_set_connect_state(ssl_.get());
385 
386   // We might have no message loop (e.g. we are in unit tests).
387   if (MessageLoop::ThreadHasCurrent()) {
388     MessageLoop::current()->PostTask(
389         FROM_HERE,
390         base::Bind(&TlsStreamImpl::DoHandshake,
391                    weak_ptr_factory_.GetWeakPtr(),
392                    success_callback,
393                    error_callback));
394   } else {
395     DoHandshake(success_callback, error_callback);
396   }
397   return true;
398 }
399 
RetryHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,Stream::AccessMode)400 void TlsStream::TlsStreamImpl::RetryHandshake(
401     const base::Closure& success_callback,
402     const Stream::ErrorCallback& error_callback,
403     Stream::AccessMode /* mode */) {
404   VLOG(1) << "Retrying TLS handshake";
405   DoHandshake(success_callback, error_callback);
406 }
407 
DoHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback)408 void TlsStream::TlsStreamImpl::DoHandshake(
409     const base::Closure& success_callback,
410     const Stream::ErrorCallback& error_callback) {
411   VLOG(1) << "Begin TLS handshake";
412   int res = SSL_do_handshake(ssl_.get());
413   if (res == 1) {
414     VLOG(1) << "Handshake successful";
415     success_callback.Run();
416     return;
417   }
418   ErrorPtr error;
419   int err = SSL_get_error(ssl_.get(), res);
420   if (err == SSL_ERROR_WANT_READ) {
421     VLOG(1) << "Waiting for read data...";
422     bool ok = socket_->WaitForData(
423         Stream::AccessMode::READ,
424         base::Bind(&TlsStreamImpl::RetryHandshake,
425                    weak_ptr_factory_.GetWeakPtr(),
426                    success_callback, error_callback),
427         &error);
428     if (ok)
429       return;
430   } else if (err == SSL_ERROR_WANT_WRITE) {
431     VLOG(1) << "Waiting for write data...";
432     bool ok = socket_->WaitForData(
433         Stream::AccessMode::WRITE,
434         base::Bind(&TlsStreamImpl::RetryHandshake,
435                    weak_ptr_factory_.GetWeakPtr(),
436                    success_callback, error_callback),
437         &error);
438     if (ok)
439       return;
440   } else {
441     ReportError(&error, FROM_HERE, "TLS handshake failed.");
442   }
443   error_callback.Run(error.get());
444 }
445 
446 /////////////////////////////////////////////////////////////////////////////
TlsStream(std::unique_ptr<TlsStreamImpl> impl)447 TlsStream::TlsStream(std::unique_ptr<TlsStreamImpl> impl)
448     : impl_{std::move(impl)} {}
449 
~TlsStream()450 TlsStream::~TlsStream() {
451   if (impl_) {
452     impl_->Close(nullptr);
453   }
454 }
455 
Connect(StreamPtr socket,const std::string & host,const base::Callback<void (StreamPtr)> & success_callback,const Stream::ErrorCallback & error_callback)456 void TlsStream::Connect(StreamPtr socket,
457                         const std::string& host,
458                         const base::Callback<void(StreamPtr)>& success_callback,
459                         const Stream::ErrorCallback& error_callback) {
460   std::unique_ptr<TlsStreamImpl> impl{new TlsStreamImpl};
461   std::unique_ptr<TlsStream> stream{new TlsStream{std::move(impl)}};
462 
463   TlsStreamImpl* pimpl = stream->impl_.get();
464   ErrorPtr error;
465   bool success = pimpl->Init(std::move(socket), host,
466                              base::Bind(success_callback,
467                                         base::Passed(std::move(stream))),
468                              error_callback, &error);
469 
470   if (!success)
471     error_callback.Run(error.get());
472 }
473 
IsOpen() const474 bool TlsStream::IsOpen() const {
475   return impl_ ? true : false;
476 }
477 
SetSizeBlocking(uint64_t,ErrorPtr * error)478 bool TlsStream::SetSizeBlocking(uint64_t /* size */, ErrorPtr* error) {
479   return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
480 }
481 
Seek(int64_t,Whence,uint64_t *,ErrorPtr * error)482 bool TlsStream::Seek(int64_t /* offset */,
483                      Whence /* whence */,
484                      uint64_t* /* new_position*/,
485                      ErrorPtr* error) {
486   return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
487 }
488 
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)489 bool TlsStream::ReadNonBlocking(void* buffer,
490                                 size_t size_to_read,
491                                 size_t* size_read,
492                                 bool* end_of_stream,
493                                 ErrorPtr* error) {
494   if (!impl_)
495     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
496   return impl_->ReadNonBlocking(buffer, size_to_read, size_read, end_of_stream,
497                                 error);
498 }
499 
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)500 bool TlsStream::WriteNonBlocking(const void* buffer,
501                                  size_t size_to_write,
502                                  size_t* size_written,
503                                  ErrorPtr* error) {
504   if (!impl_)
505     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
506   return impl_->WriteNonBlocking(buffer, size_to_write, size_written, error);
507 }
508 
FlushBlocking(ErrorPtr * error)509 bool TlsStream::FlushBlocking(ErrorPtr* error) {
510   if (!impl_)
511     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
512   return impl_->Flush(error);
513 }
514 
CloseBlocking(ErrorPtr * error)515 bool TlsStream::CloseBlocking(ErrorPtr* error) {
516   if (impl_ && !impl_->Close(error))
517     return false;
518   impl_.reset();
519   return true;
520 }
521 
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)522 bool TlsStream::WaitForData(AccessMode mode,
523                             const base::Callback<void(AccessMode)>& callback,
524                             ErrorPtr* error) {
525   if (!impl_)
526     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
527   return impl_->WaitForData(mode, callback, error);
528 }
529 
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)530 bool TlsStream::WaitForDataBlocking(AccessMode in_mode,
531                                     base::TimeDelta timeout,
532                                     AccessMode* out_mode,
533                                     ErrorPtr* error) {
534   if (!impl_)
535     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
536   return impl_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
537 }
538 
CancelPendingAsyncOperations()539 void TlsStream::CancelPendingAsyncOperations() {
540   if (impl_)
541     impl_->CancelPendingAsyncOperations();
542   Stream::CancelPendingAsyncOperations();
543 }
544 
545 }  // namespace brillo
546