1 // Copyright 2015 The Weave Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include "examples/provider/ssl_stream.h"

#include <algorithm>
#include <limits>
#include <string>

#include <openssl/err.h>

#include <base/bind.h>
#include <base/bind_helpers.h>
#include <weave/provider/task_runner.h>
12
13 namespace weave {
14 namespace examples {
15
16 namespace {
17
AddSslError(ErrorPtr * error,const tracked_objects::Location & location,const std::string & error_code,unsigned long ssl_error_code)18 void AddSslError(ErrorPtr* error,
19 const tracked_objects::Location& location,
20 const std::string& error_code,
21 unsigned long ssl_error_code) {
22 ERR_load_BIO_strings();
23 SSL_load_error_strings();
24 Error::AddToPrintf(error, location, error_code, "%s: %s",
25 ERR_lib_error_string(ssl_error_code),
26 ERR_reason_error_string(ssl_error_code));
27 }
28
RetryAsyncTask(provider::TaskRunner * task_runner,const tracked_objects::Location & location,const base::Closure & task)29 void RetryAsyncTask(provider::TaskRunner* task_runner,
30 const tracked_objects::Location& location,
31 const base::Closure& task) {
32 task_runner->PostDelayedTask(FROM_HERE, task,
33 base::TimeDelta::FromMilliseconds(100));
34 }
35
36 } // namespace
37
operator ()(BIO * bio) const38 void SSLStream::SslDeleter::operator()(BIO* bio) const {
39 BIO_free(bio);
40 }
41
operator ()(SSL * ssl) const42 void SSLStream::SslDeleter::operator()(SSL* ssl) const {
43 SSL_free(ssl);
44 }
45
operator ()(SSL_CTX * ctx) const46 void SSLStream::SslDeleter::operator()(SSL_CTX* ctx) const {
47 SSL_CTX_free(ctx);
48 }
49
SSLStream(provider::TaskRunner * task_runner,std::unique_ptr<BIO,SslDeleter> stream_bio)50 SSLStream::SSLStream(provider::TaskRunner* task_runner,
51 std::unique_ptr<BIO, SslDeleter> stream_bio)
52 : task_runner_{task_runner} {
53 ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
54 CHECK(ctx_);
55 ssl_.reset(SSL_new(ctx_.get()));
56
57 SSL_set_bio(ssl_.get(), stream_bio.get(), stream_bio.get());
58 stream_bio.release(); // Owned by ssl now.
59 SSL_set_connect_state(ssl_.get());
60 }
61
~SSLStream()62 SSLStream::~SSLStream() {
63 CancelPendingOperations();
64 }
65
RunTask(const base::Closure & task)66 void SSLStream::RunTask(const base::Closure& task) {
67 task.Run();
68 }
69
Read(void * buffer,size_t size_to_read,const ReadCallback & callback)70 void SSLStream::Read(void* buffer,
71 size_t size_to_read,
72 const ReadCallback& callback) {
73 int res = SSL_read(ssl_.get(), buffer, size_to_read);
74 if (res > 0) {
75 task_runner_->PostDelayedTask(
76 FROM_HERE,
77 base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
78 base::Bind(callback, res, nullptr)),
79 {});
80 return;
81 }
82
83 int err = SSL_get_error(ssl_.get(), res);
84
85 if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
86 return RetryAsyncTask(
87 task_runner_, FROM_HERE,
88 base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer,
89 size_to_read, callback));
90 }
91
92 ErrorPtr weave_error;
93 AddSslError(&weave_error, FROM_HERE, "read_failed", err);
94 return task_runner_->PostDelayedTask(
95 FROM_HERE,
96 base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
97 base::Bind(callback, 0, base::Passed(&weave_error))),
98 {});
99 }
100
Write(const void * buffer,size_t size_to_write,const WriteCallback & callback)101 void SSLStream::Write(const void* buffer,
102 size_t size_to_write,
103 const WriteCallback& callback) {
104 int res = SSL_write(ssl_.get(), buffer, size_to_write);
105 if (res > 0) {
106 buffer = static_cast<const char*>(buffer) + res;
107 size_to_write -= res;
108 if (size_to_write == 0) {
109 return task_runner_->PostDelayedTask(
110 FROM_HERE,
111 base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
112 base::Bind(callback, nullptr)),
113 {});
114 }
115
116 return RetryAsyncTask(
117 task_runner_, FROM_HERE,
118 base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
119 size_to_write, callback));
120 }
121
122 int err = SSL_get_error(ssl_.get(), res);
123
124 if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
125 return RetryAsyncTask(
126 task_runner_, FROM_HERE,
127 base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
128 size_to_write, callback));
129 }
130
131 ErrorPtr weave_error;
132 AddSslError(&weave_error, FROM_HERE, "write_failed", err);
133 task_runner_->PostDelayedTask(
134 FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
135 base::Bind(callback, base::Passed(&weave_error))),
136 {});
137 }
138
CancelPendingOperations()139 void SSLStream::CancelPendingOperations() {
140 weak_ptr_factory_.InvalidateWeakPtrs();
141 }
142
Connect(provider::TaskRunner * task_runner,const std::string & host,uint16_t port,const provider::Network::OpenSslSocketCallback & callback)143 void SSLStream::Connect(
144 provider::TaskRunner* task_runner,
145 const std::string& host,
146 uint16_t port,
147 const provider::Network::OpenSslSocketCallback& callback) {
148 SSL_library_init();
149
150 char end_point[255];
151 snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port);
152
153 std::unique_ptr<BIO, SslDeleter> stream_bio(BIO_new_connect(end_point));
154 CHECK(stream_bio);
155 BIO_set_nbio(stream_bio.get(), 1);
156
157 std::unique_ptr<SSLStream> stream{
158 new SSLStream{task_runner, std::move(stream_bio)}};
159 ConnectBio(std::move(stream), callback);
160 }
161
ConnectBio(std::unique_ptr<SSLStream> stream,const provider::Network::OpenSslSocketCallback & callback)162 void SSLStream::ConnectBio(
163 std::unique_ptr<SSLStream> stream,
164 const provider::Network::OpenSslSocketCallback& callback) {
165 BIO* bio = SSL_get_rbio(stream->ssl_.get());
166 if (BIO_do_connect(bio) == 1)
167 return DoHandshake(std::move(stream), callback);
168
169 auto task_runner = stream->task_runner_;
170 if (BIO_should_retry(bio)) {
171 return RetryAsyncTask(
172 task_runner, FROM_HERE,
173 base::Bind(&SSLStream::ConnectBio, base::Passed(&stream), callback));
174 }
175
176 ErrorPtr error;
177 AddSslError(&error, FROM_HERE, "connect_failed", ERR_get_error());
178 task_runner->PostDelayedTask(
179 FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
180 }
181
DoHandshake(std::unique_ptr<SSLStream> stream,const provider::Network::OpenSslSocketCallback & callback)182 void SSLStream::DoHandshake(
183 std::unique_ptr<SSLStream> stream,
184 const provider::Network::OpenSslSocketCallback& callback) {
185 int res = SSL_do_handshake(stream->ssl_.get());
186 auto task_runner = stream->task_runner_;
187 if (res == 1) {
188 return task_runner->PostDelayedTask(
189 FROM_HERE, base::Bind(callback, base::Passed(&stream), nullptr), {});
190 }
191
192 res = SSL_get_error(stream->ssl_.get(), res);
193
194 if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE) {
195 return RetryAsyncTask(
196 task_runner, FROM_HERE,
197 base::Bind(&SSLStream::DoHandshake, base::Passed(&stream), callback));
198 }
199
200 ErrorPtr error;
201 AddSslError(&error, FROM_HERE, "handshake_failed", res);
202 task_runner->PostDelayedTask(
203 FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
204 }
205
206 } // namespace examples
207 } // namespace weave
208