1 // Copyright 2015 The Weave Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "examples/provider/ssl_stream.h"

#include <limits>

#include <openssl/err.h>

#include <base/bind.h>
#include <base/bind_helpers.h>
#include <weave/provider/task_runner.h>
12 
13 namespace weave {
14 namespace examples {
15 
16 namespace {
17 
AddSslError(ErrorPtr * error,const tracked_objects::Location & location,const std::string & error_code,unsigned long ssl_error_code)18 void AddSslError(ErrorPtr* error,
19                  const tracked_objects::Location& location,
20                  const std::string& error_code,
21                  unsigned long ssl_error_code) {
22   ERR_load_BIO_strings();
23   SSL_load_error_strings();
24   Error::AddToPrintf(error, location, error_code, "%s: %s",
25                      ERR_lib_error_string(ssl_error_code),
26                      ERR_reason_error_string(ssl_error_code));
27 }
28 
RetryAsyncTask(provider::TaskRunner * task_runner,const tracked_objects::Location & location,const base::Closure & task)29 void RetryAsyncTask(provider::TaskRunner* task_runner,
30                     const tracked_objects::Location& location,
31                     const base::Closure& task) {
32   task_runner->PostDelayedTask(FROM_HERE, task,
33                                base::TimeDelta::FromMilliseconds(100));
34 }
35 
36 }  // namespace
37 
operator ()(BIO * bio) const38 void SSLStream::SslDeleter::operator()(BIO* bio) const {
39   BIO_free(bio);
40 }
41 
operator ()(SSL * ssl) const42 void SSLStream::SslDeleter::operator()(SSL* ssl) const {
43   SSL_free(ssl);
44 }
45 
operator ()(SSL_CTX * ctx) const46 void SSLStream::SslDeleter::operator()(SSL_CTX* ctx) const {
47   SSL_CTX_free(ctx);
48 }
49 
SSLStream(provider::TaskRunner * task_runner,std::unique_ptr<BIO,SslDeleter> stream_bio)50 SSLStream::SSLStream(provider::TaskRunner* task_runner,
51                      std::unique_ptr<BIO, SslDeleter> stream_bio)
52     : task_runner_{task_runner} {
53   ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
54   CHECK(ctx_);
55   ssl_.reset(SSL_new(ctx_.get()));
56 
57   SSL_set_bio(ssl_.get(), stream_bio.get(), stream_bio.get());
58   stream_bio.release();  // Owned by ssl now.
59   SSL_set_connect_state(ssl_.get());
60 }
61 
~SSLStream()62 SSLStream::~SSLStream() {
63   CancelPendingOperations();
64 }
65 
RunTask(const base::Closure & task)66 void SSLStream::RunTask(const base::Closure& task) {
67   task.Run();
68 }
69 
Read(void * buffer,size_t size_to_read,const ReadCallback & callback)70 void SSLStream::Read(void* buffer,
71                      size_t size_to_read,
72                      const ReadCallback& callback) {
73   int res = SSL_read(ssl_.get(), buffer, size_to_read);
74   if (res > 0) {
75     task_runner_->PostDelayedTask(
76         FROM_HERE,
77         base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
78                    base::Bind(callback, res, nullptr)),
79         {});
80     return;
81   }
82 
83   int err = SSL_get_error(ssl_.get(), res);
84 
85   if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
86     return RetryAsyncTask(
87         task_runner_, FROM_HERE,
88         base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer,
89                    size_to_read, callback));
90   }
91 
92   ErrorPtr weave_error;
93   AddSslError(&weave_error, FROM_HERE, "read_failed", err);
94   return task_runner_->PostDelayedTask(
95       FROM_HERE,
96       base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
97                  base::Bind(callback, 0, base::Passed(&weave_error))),
98       {});
99 }
100 
Write(const void * buffer,size_t size_to_write,const WriteCallback & callback)101 void SSLStream::Write(const void* buffer,
102                       size_t size_to_write,
103                       const WriteCallback& callback) {
104   int res = SSL_write(ssl_.get(), buffer, size_to_write);
105   if (res > 0) {
106     buffer = static_cast<const char*>(buffer) + res;
107     size_to_write -= res;
108     if (size_to_write == 0) {
109       return task_runner_->PostDelayedTask(
110           FROM_HERE,
111           base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
112                      base::Bind(callback, nullptr)),
113           {});
114     }
115 
116     return RetryAsyncTask(
117         task_runner_, FROM_HERE,
118         base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
119                    size_to_write, callback));
120   }
121 
122   int err = SSL_get_error(ssl_.get(), res);
123 
124   if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
125     return RetryAsyncTask(
126         task_runner_, FROM_HERE,
127         base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
128                    size_to_write, callback));
129   }
130 
131   ErrorPtr weave_error;
132   AddSslError(&weave_error, FROM_HERE, "write_failed", err);
133   task_runner_->PostDelayedTask(
134       FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
135                             base::Bind(callback, base::Passed(&weave_error))),
136       {});
137 }
138 
CancelPendingOperations()139 void SSLStream::CancelPendingOperations() {
140   weak_ptr_factory_.InvalidateWeakPtrs();
141 }
142 
Connect(provider::TaskRunner * task_runner,const std::string & host,uint16_t port,const provider::Network::OpenSslSocketCallback & callback)143 void SSLStream::Connect(
144     provider::TaskRunner* task_runner,
145     const std::string& host,
146     uint16_t port,
147     const provider::Network::OpenSslSocketCallback& callback) {
148   SSL_library_init();
149 
150   char end_point[255];
151   snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port);
152 
153   std::unique_ptr<BIO, SslDeleter> stream_bio(BIO_new_connect(end_point));
154   CHECK(stream_bio);
155   BIO_set_nbio(stream_bio.get(), 1);
156 
157   std::unique_ptr<SSLStream> stream{
158       new SSLStream{task_runner, std::move(stream_bio)}};
159   ConnectBio(std::move(stream), callback);
160 }
161 
ConnectBio(std::unique_ptr<SSLStream> stream,const provider::Network::OpenSslSocketCallback & callback)162 void SSLStream::ConnectBio(
163     std::unique_ptr<SSLStream> stream,
164     const provider::Network::OpenSslSocketCallback& callback) {
165   BIO* bio = SSL_get_rbio(stream->ssl_.get());
166   if (BIO_do_connect(bio) == 1)
167     return DoHandshake(std::move(stream), callback);
168 
169   auto task_runner = stream->task_runner_;
170   if (BIO_should_retry(bio)) {
171     return RetryAsyncTask(
172         task_runner, FROM_HERE,
173         base::Bind(&SSLStream::ConnectBio, base::Passed(&stream), callback));
174   }
175 
176   ErrorPtr error;
177   AddSslError(&error, FROM_HERE, "connect_failed", ERR_get_error());
178   task_runner->PostDelayedTask(
179       FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
180 }
181 
DoHandshake(std::unique_ptr<SSLStream> stream,const provider::Network::OpenSslSocketCallback & callback)182 void SSLStream::DoHandshake(
183     std::unique_ptr<SSLStream> stream,
184     const provider::Network::OpenSslSocketCallback& callback) {
185   int res = SSL_do_handshake(stream->ssl_.get());
186   auto task_runner = stream->task_runner_;
187   if (res == 1) {
188     return task_runner->PostDelayedTask(
189         FROM_HERE, base::Bind(callback, base::Passed(&stream), nullptr), {});
190   }
191 
192   res = SSL_get_error(stream->ssl_.get(), res);
193 
194   if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE) {
195     return RetryAsyncTask(
196         task_runner, FROM_HERE,
197         base::Bind(&SSLStream::DoHandshake, base::Passed(&stream), callback));
198   }
199 
200   ErrorPtr error;
201   AddSslError(&error, FROM_HERE, "handshake_failed", res);
202   task_runner->PostDelayedTask(
203       FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
204 }
205 
206 }  // namespace examples
207 }  // namespace weave
208