1 // Copyright 2015 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <brillo/streams/tls_stream.h>
6
7 #include <algorithm>
8 #include <limits>
9 #include <string>
10 #include <vector>
11
12 #include <openssl/err.h>
13 #include <openssl/ssl.h>
14
15 #include <base/bind.h>
16 #include <base/memory/weak_ptr.h>
17 #include <brillo/message_loops/message_loop.h>
18 #include <brillo/secure_blob.h>
19 #include <brillo/streams/openssl_stream_bio.h>
20 #include <brillo/streams/stream_utils.h>
21 #include <brillo/strings/string_utils.h>
22
23 namespace {
24
25 // SSL info callback which is called by OpenSSL when we enable logging level of
26 // at least 3. This logs the information about the internal TLS handshake.
TlsInfoCallback(const SSL *,int where,int ret)27 void TlsInfoCallback(const SSL* /* ssl */, int where, int ret) {
28 std::string reason;
29 std::vector<std::string> info;
30 if (where & SSL_CB_LOOP)
31 info.push_back("loop");
32 if (where & SSL_CB_EXIT)
33 info.push_back("exit");
34 if (where & SSL_CB_READ)
35 info.push_back("read");
36 if (where & SSL_CB_WRITE)
37 info.push_back("write");
38 if (where & SSL_CB_ALERT) {
39 info.push_back("alert");
40 reason = ", reason: ";
41 reason += SSL_alert_type_string_long(ret);
42 reason += "/";
43 reason += SSL_alert_desc_string_long(ret);
44 }
45 if (where & SSL_CB_HANDSHAKE_START)
46 info.push_back("handshake_start");
47 if (where & SSL_CB_HANDSHAKE_DONE)
48 info.push_back("handshake_done");
49
50 VLOG(3) << "TLS progress info: " << brillo::string_utils::Join(",", info)
51 << ", with status: " << ret << reason;
52 }
53
// Index of the SSL_CTX "ex data" slot in which a TlsStreamImpl stores its own
// |this| pointer. The static certificate-verification callback uses it to find
// the owning instance. Holds -1 until the slot has been allocated.
int ssl_ctx_private_data_index = -1;

// Directory of trusted CA certificates. It is passed to OpenSSL as the CApath
// argument when verifying the server certificate.
constexpr char kCACertificatePath[] =
#ifdef __ANDROID__
    "/system/etc/security/cacerts_google";
#else
    "/usr/share/chromeos-ca-certificates";
#endif
65
66 } // anonymous namespace
67
68 namespace brillo {
69
70 // Helper implementation of TLS stream used to hide most of OpenSSL inner
71 // workings from the users of brillo::TlsStream.
72 class TlsStream::TlsStreamImpl {
73 public:
74 TlsStreamImpl();
75 ~TlsStreamImpl();
76
77 bool Init(StreamPtr socket,
78 const std::string& host,
79 const base::Closure& success_callback,
80 const Stream::ErrorCallback& error_callback,
81 ErrorPtr* error);
82
83 bool ReadNonBlocking(void* buffer,
84 size_t size_to_read,
85 size_t* size_read,
86 bool* end_of_stream,
87 ErrorPtr* error);
88
89 bool WriteNonBlocking(const void* buffer,
90 size_t size_to_write,
91 size_t* size_written,
92 ErrorPtr* error);
93
94 bool Flush(ErrorPtr* error);
95 bool Close(ErrorPtr* error);
96 bool WaitForData(AccessMode mode,
97 const base::Callback<void(AccessMode)>& callback,
98 ErrorPtr* error);
99 bool WaitForDataBlocking(AccessMode in_mode,
100 base::TimeDelta timeout,
101 AccessMode* out_mode,
102 ErrorPtr* error);
103 void CancelPendingAsyncOperations();
104
105 private:
106 bool ReportError(ErrorPtr* error,
107 const tracked_objects::Location& location,
108 const std::string& message);
109 void DoHandshake(const base::Closure& success_callback,
110 const Stream::ErrorCallback& error_callback);
111 void RetryHandshake(const base::Closure& success_callback,
112 const Stream::ErrorCallback& error_callback,
113 Stream::AccessMode mode);
114
115 int OnCertVerifyResults(int ok, X509_STORE_CTX* ctx);
116 static int OnCertVerifyResultsStatic(int ok, X509_STORE_CTX* ctx);
117
118 StreamPtr socket_;
119 std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
120 std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
121 BIO* stream_bio_{nullptr};
122 bool need_more_read_{false};
123 bool need_more_write_{false};
124
125 base::WeakPtrFactory<TlsStreamImpl> weak_ptr_factory_{this};
126 DISALLOW_COPY_AND_ASSIGN(TlsStreamImpl);
127 };
128
TlsStreamImpl()129 TlsStream::TlsStreamImpl::TlsStreamImpl() {
130 SSL_load_error_strings();
131 SSL_library_init();
132 if (ssl_ctx_private_data_index < 0) {
133 ssl_ctx_private_data_index =
134 SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
135 }
136 }
137
~TlsStreamImpl()138 TlsStream::TlsStreamImpl::~TlsStreamImpl() {
139 ssl_.reset();
140 ctx_.reset();
141 }
142
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)143 bool TlsStream::TlsStreamImpl::ReadNonBlocking(void* buffer,
144 size_t size_to_read,
145 size_t* size_read,
146 bool* end_of_stream,
147 ErrorPtr* error) {
148 const size_t max_int = std::numeric_limits<int>::max();
149 int size_int = static_cast<int>(std::min(size_to_read, max_int));
150 int ret = SSL_read(ssl_.get(), buffer, size_int);
151 if (ret > 0) {
152 *size_read = static_cast<size_t>(ret);
153 if (end_of_stream)
154 *end_of_stream = false;
155 return true;
156 }
157
158 int err = SSL_get_error(ssl_.get(), ret);
159 if (err == SSL_ERROR_ZERO_RETURN) {
160 *size_read = 0;
161 if (end_of_stream)
162 *end_of_stream = true;
163 return true;
164 }
165
166 if (err == SSL_ERROR_WANT_READ) {
167 need_more_read_ = true;
168 } else if (err == SSL_ERROR_WANT_WRITE) {
169 // Writes might be required for SSL_read() because of possible TLS
170 // re-negotiations which can happen at any time.
171 need_more_write_ = true;
172 } else {
173 return ReportError(error, FROM_HERE, "Error reading from TLS socket");
174 }
175 *size_read = 0;
176 if (end_of_stream)
177 *end_of_stream = false;
178 return true;
179 }
180
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)181 bool TlsStream::TlsStreamImpl::WriteNonBlocking(const void* buffer,
182 size_t size_to_write,
183 size_t* size_written,
184 ErrorPtr* error) {
185 const size_t max_int = std::numeric_limits<int>::max();
186 int size_int = static_cast<int>(std::min(size_to_write, max_int));
187 int ret = SSL_write(ssl_.get(), buffer, size_int);
188 if (ret > 0) {
189 *size_written = static_cast<size_t>(ret);
190 return true;
191 }
192
193 int err = SSL_get_error(ssl_.get(), ret);
194 if (err == SSL_ERROR_WANT_READ) {
195 // Reads might be required for SSL_write() because of possible TLS
196 // re-negotiations which can happen at any time.
197 need_more_read_ = true;
198 } else if (err == SSL_ERROR_WANT_WRITE) {
199 need_more_write_ = true;
200 } else {
201 return ReportError(error, FROM_HERE, "Error writing to TLS socket");
202 }
203 *size_written = 0;
204 return true;
205 }
206
Flush(ErrorPtr * error)207 bool TlsStream::TlsStreamImpl::Flush(ErrorPtr* error) {
208 return socket_->FlushBlocking(error);
209 }
210
Close(ErrorPtr * error)211 bool TlsStream::TlsStreamImpl::Close(ErrorPtr* error) {
212 // 2 seconds should be plenty here.
213 const base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(2);
214 // The retry count of 4 below is just arbitrary, to ensure we don't get stuck
215 // here forever. We should rarely need to repeat SSL_shutdown anyway.
216 for (int retry_count = 0; retry_count < 4; retry_count++) {
217 int ret = SSL_shutdown(ssl_.get());
218 // We really don't care for bi-directional shutdown here.
219 // Just make sure we only send the "close notify" alert to the remote peer.
220 if (ret >= 0)
221 break;
222
223 int err = SSL_get_error(ssl_.get(), ret);
224 if (err == SSL_ERROR_WANT_READ) {
225 if (!socket_->WaitForDataBlocking(AccessMode::READ, kTimeout, nullptr,
226 error)) {
227 break;
228 }
229 } else if (err == SSL_ERROR_WANT_WRITE) {
230 if (!socket_->WaitForDataBlocking(AccessMode::WRITE, kTimeout, nullptr,
231 error)) {
232 break;
233 }
234 } else {
235 LOG(ERROR) << "SSL_shutdown returned error #" << err;
236 ReportError(error, FROM_HERE, "Failed to shut down TLS socket");
237 break;
238 }
239 }
240 return socket_->CloseBlocking(error);
241 }
242
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)243 bool TlsStream::TlsStreamImpl::WaitForData(
244 AccessMode mode,
245 const base::Callback<void(AccessMode)>& callback,
246 ErrorPtr* error) {
247 bool is_read = stream_utils::IsReadAccessMode(mode);
248 bool is_write = stream_utils::IsWriteAccessMode(mode);
249 is_read |= need_more_read_;
250 is_write |= need_more_write_;
251 need_more_read_ = false;
252 need_more_write_ = false;
253 if (is_read && SSL_pending(ssl_.get()) > 0) {
254 callback.Run(AccessMode::READ);
255 return true;
256 }
257 mode = stream_utils::MakeAccessMode(is_read, is_write);
258 return socket_->WaitForData(mode, callback, error);
259 }
260
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)261 bool TlsStream::TlsStreamImpl::WaitForDataBlocking(AccessMode in_mode,
262 base::TimeDelta timeout,
263 AccessMode* out_mode,
264 ErrorPtr* error) {
265 bool is_read = stream_utils::IsReadAccessMode(in_mode);
266 bool is_write = stream_utils::IsWriteAccessMode(in_mode);
267 is_read |= need_more_read_;
268 is_write |= need_more_write_;
269 need_more_read_ = need_more_write_ = false;
270 if (is_read && SSL_pending(ssl_.get()) > 0) {
271 if (out_mode)
272 *out_mode = AccessMode::READ;
273 return true;
274 }
275 in_mode = stream_utils::MakeAccessMode(is_read, is_write);
276 return socket_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
277 }
278
CancelPendingAsyncOperations()279 void TlsStream::TlsStreamImpl::CancelPendingAsyncOperations() {
280 socket_->CancelPendingAsyncOperations();
281 weak_ptr_factory_.InvalidateWeakPtrs();
282 }
283
ReportError(ErrorPtr * error,const tracked_objects::Location & location,const std::string & message)284 bool TlsStream::TlsStreamImpl::ReportError(
285 ErrorPtr* error,
286 const tracked_objects::Location& location,
287 const std::string& message) {
288 const char* file = nullptr;
289 int line = 0;
290 const char* data = 0;
291 int flags = 0;
292 while (auto errnum = ERR_get_error_line_data(&file, &line, &data, &flags)) {
293 char buf[256];
294 ERR_error_string_n(errnum, buf, sizeof(buf));
295 tracked_objects::Location ssl_location{"Unknown", file, line, nullptr};
296 std::string ssl_message = buf;
297 if (flags & ERR_TXT_STRING) {
298 ssl_message += ": ";
299 ssl_message += data;
300 }
301 Error::AddTo(error, ssl_location, "openssl", std::to_string(errnum),
302 ssl_message);
303 }
304 Error::AddTo(error, location, "tls_stream", "failed", message);
305 return false;
306 }
307
OnCertVerifyResults(int ok,X509_STORE_CTX * ctx)308 int TlsStream::TlsStreamImpl::OnCertVerifyResults(int ok, X509_STORE_CTX* ctx) {
309 // OpenSSL already performs a comprehensive check of the certificate chain
310 // (using X509_verify_cert() function) and calls back with the result of its
311 // verification.
312 // |ok| is set to 1 if the verification passed and 0 if an error was detected.
313 // Here we can perform some additional checks if we need to, or simply log
314 // the issues found.
315
316 // For now, just log an error if it occurred.
317 if (!ok) {
318 LOG(ERROR) << "Server certificate validation failed: "
319 << X509_verify_cert_error_string(X509_STORE_CTX_get_error(ctx));
320 }
321 return ok;
322 }
323
OnCertVerifyResultsStatic(int ok,X509_STORE_CTX * ctx)324 int TlsStream::TlsStreamImpl::OnCertVerifyResultsStatic(int ok,
325 X509_STORE_CTX* ctx) {
326 // Obtain the pointer to the instance of TlsStream::TlsStreamImpl from the
327 // SSL CTX object referenced by |ctx|.
328 SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(
329 ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
330 SSL_CTX* ssl_ctx = ssl ? SSL_get_SSL_CTX(ssl) : nullptr;
331 TlsStream::TlsStreamImpl* self = nullptr;
332 if (ssl_ctx) {
333 self = static_cast<TlsStream::TlsStreamImpl*>(SSL_CTX_get_ex_data(
334 ssl_ctx, ssl_ctx_private_data_index));
335 }
336 return self ? self->OnCertVerifyResults(ok, ctx) : ok;
337 }
338
Init(StreamPtr socket,const std::string & host,const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,ErrorPtr * error)339 bool TlsStream::TlsStreamImpl::Init(StreamPtr socket,
340 const std::string& host,
341 const base::Closure& success_callback,
342 const Stream::ErrorCallback& error_callback,
343 ErrorPtr* error) {
344 ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
345 if (!ctx_)
346 return ReportError(error, FROM_HERE, "Cannot create SSL_CTX");
347
348 // Top cipher suites supported by both Google GFEs and OpenSSL (in server
349 // preferred order).
350 int res = SSL_CTX_set_cipher_list(ctx_.get(),
351 "ECDHE-ECDSA-AES128-GCM-SHA256:"
352 "ECDHE-ECDSA-AES256-GCM-SHA384:"
353 "ECDHE-RSA-AES128-GCM-SHA256:"
354 "ECDHE-RSA-AES256-GCM-SHA384");
355 if (res != 1)
356 return ReportError(error, FROM_HERE, "Cannot set the cipher list");
357
358 res = SSL_CTX_load_verify_locations(ctx_.get(), nullptr, kCACertificatePath);
359 if (res != 1) {
360 return ReportError(error, FROM_HERE,
361 "Failed to specify trusted certificate location");
362 }
363
364 // Store a pointer to "this" into SSL_CTX instance.
365 SSL_CTX_set_ex_data(ctx_.get(), ssl_ctx_private_data_index, this);
366
367 // Ask OpenSSL to validate the server host from the certificate to match
368 // the expected host name we are given:
369 X509_VERIFY_PARAM* param = SSL_CTX_get0_param(ctx_.get());
370 X509_VERIFY_PARAM_set1_host(param, host.c_str(), host.size());
371
372 SSL_CTX_set_verify(ctx_.get(), SSL_VERIFY_PEER,
373 &TlsStreamImpl::OnCertVerifyResultsStatic);
374
375 socket_ = std::move(socket);
376 ssl_.reset(SSL_new(ctx_.get()));
377
378 // Enable TLS progress callback if VLOG level is >=3.
379 if (VLOG_IS_ON(3))
380 SSL_set_info_callback(ssl_.get(), TlsInfoCallback);
381
382 stream_bio_ = BIO_new_stream(socket_.get());
383 SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_);
384 SSL_set_connect_state(ssl_.get());
385
386 // We might have no message loop (e.g. we are in unit tests).
387 if (MessageLoop::ThreadHasCurrent()) {
388 MessageLoop::current()->PostTask(
389 FROM_HERE,
390 base::Bind(&TlsStreamImpl::DoHandshake,
391 weak_ptr_factory_.GetWeakPtr(),
392 success_callback,
393 error_callback));
394 } else {
395 DoHandshake(success_callback, error_callback);
396 }
397 return true;
398 }
399
RetryHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,Stream::AccessMode)400 void TlsStream::TlsStreamImpl::RetryHandshake(
401 const base::Closure& success_callback,
402 const Stream::ErrorCallback& error_callback,
403 Stream::AccessMode /* mode */) {
404 VLOG(1) << "Retrying TLS handshake";
405 DoHandshake(success_callback, error_callback);
406 }
407
DoHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback)408 void TlsStream::TlsStreamImpl::DoHandshake(
409 const base::Closure& success_callback,
410 const Stream::ErrorCallback& error_callback) {
411 VLOG(1) << "Begin TLS handshake";
412 int res = SSL_do_handshake(ssl_.get());
413 if (res == 1) {
414 VLOG(1) << "Handshake successful";
415 success_callback.Run();
416 return;
417 }
418 ErrorPtr error;
419 int err = SSL_get_error(ssl_.get(), res);
420 if (err == SSL_ERROR_WANT_READ) {
421 VLOG(1) << "Waiting for read data...";
422 bool ok = socket_->WaitForData(
423 Stream::AccessMode::READ,
424 base::Bind(&TlsStreamImpl::RetryHandshake,
425 weak_ptr_factory_.GetWeakPtr(),
426 success_callback, error_callback),
427 &error);
428 if (ok)
429 return;
430 } else if (err == SSL_ERROR_WANT_WRITE) {
431 VLOG(1) << "Waiting for write data...";
432 bool ok = socket_->WaitForData(
433 Stream::AccessMode::WRITE,
434 base::Bind(&TlsStreamImpl::RetryHandshake,
435 weak_ptr_factory_.GetWeakPtr(),
436 success_callback, error_callback),
437 &error);
438 if (ok)
439 return;
440 } else {
441 ReportError(&error, FROM_HERE, "TLS handshake failed.");
442 }
443 error_callback.Run(error.get());
444 }
445
446 /////////////////////////////////////////////////////////////////////////////
TlsStream(std::unique_ptr<TlsStreamImpl> impl)447 TlsStream::TlsStream(std::unique_ptr<TlsStreamImpl> impl)
448 : impl_{std::move(impl)} {}
449
~TlsStream()450 TlsStream::~TlsStream() {
451 if (impl_) {
452 impl_->Close(nullptr);
453 }
454 }
455
Connect(StreamPtr socket,const std::string & host,const base::Callback<void (StreamPtr)> & success_callback,const Stream::ErrorCallback & error_callback)456 void TlsStream::Connect(StreamPtr socket,
457 const std::string& host,
458 const base::Callback<void(StreamPtr)>& success_callback,
459 const Stream::ErrorCallback& error_callback) {
460 std::unique_ptr<TlsStreamImpl> impl{new TlsStreamImpl};
461 std::unique_ptr<TlsStream> stream{new TlsStream{std::move(impl)}};
462
463 TlsStreamImpl* pimpl = stream->impl_.get();
464 ErrorPtr error;
465 bool success = pimpl->Init(std::move(socket), host,
466 base::Bind(success_callback,
467 base::Passed(std::move(stream))),
468 error_callback, &error);
469
470 if (!success)
471 error_callback.Run(error.get());
472 }
473
IsOpen() const474 bool TlsStream::IsOpen() const {
475 return impl_ ? true : false;
476 }
477
SetSizeBlocking(uint64_t,ErrorPtr * error)478 bool TlsStream::SetSizeBlocking(uint64_t /* size */, ErrorPtr* error) {
479 return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
480 }
481
Seek(int64_t,Whence,uint64_t *,ErrorPtr * error)482 bool TlsStream::Seek(int64_t /* offset */,
483 Whence /* whence */,
484 uint64_t* /* new_position*/,
485 ErrorPtr* error) {
486 return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
487 }
488
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)489 bool TlsStream::ReadNonBlocking(void* buffer,
490 size_t size_to_read,
491 size_t* size_read,
492 bool* end_of_stream,
493 ErrorPtr* error) {
494 if (!impl_)
495 return stream_utils::ErrorStreamClosed(FROM_HERE, error);
496 return impl_->ReadNonBlocking(buffer, size_to_read, size_read, end_of_stream,
497 error);
498 }
499
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)500 bool TlsStream::WriteNonBlocking(const void* buffer,
501 size_t size_to_write,
502 size_t* size_written,
503 ErrorPtr* error) {
504 if (!impl_)
505 return stream_utils::ErrorStreamClosed(FROM_HERE, error);
506 return impl_->WriteNonBlocking(buffer, size_to_write, size_written, error);
507 }
508
FlushBlocking(ErrorPtr * error)509 bool TlsStream::FlushBlocking(ErrorPtr* error) {
510 if (!impl_)
511 return stream_utils::ErrorStreamClosed(FROM_HERE, error);
512 return impl_->Flush(error);
513 }
514
CloseBlocking(ErrorPtr * error)515 bool TlsStream::CloseBlocking(ErrorPtr* error) {
516 if (impl_ && !impl_->Close(error))
517 return false;
518 impl_.reset();
519 return true;
520 }
521
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)522 bool TlsStream::WaitForData(AccessMode mode,
523 const base::Callback<void(AccessMode)>& callback,
524 ErrorPtr* error) {
525 if (!impl_)
526 return stream_utils::ErrorStreamClosed(FROM_HERE, error);
527 return impl_->WaitForData(mode, callback, error);
528 }
529
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)530 bool TlsStream::WaitForDataBlocking(AccessMode in_mode,
531 base::TimeDelta timeout,
532 AccessMode* out_mode,
533 ErrorPtr* error) {
534 if (!impl_)
535 return stream_utils::ErrorStreamClosed(FROM_HERE, error);
536 return impl_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
537 }
538
CancelPendingAsyncOperations()539 void TlsStream::CancelPendingAsyncOperations() {
540 if (impl_)
541 impl_->CancelPendingAsyncOperations();
542 Stream::CancelPendingAsyncOperations();
543 }
544
545 } // namespace brillo
546