1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except
6  * in compliance with the License. You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "adb/tls/tls_connection.h"
18 
19 #include <limits.h>
20 
21 #include <algorithm>
22 #include <vector>
23 
24 #include <android-base/logging.h>
25 #include <android-base/strings.h>
26 #include <openssl/err.h>
27 #include <openssl/ssl.h>
28 
29 using android::base::borrowed_fd;
30 
31 namespace adb {
32 namespace tls {
33 
34 namespace {
35 
// Label fed to the TLS keying-material exporter (see ExportKeyingMaterial()).
// Callers use sizeof(kExportedKeyLabel), which counts the trailing NUL, so the
// exact byte sequence (including the terminator) is part of the wire contract.
constexpr char kExportedKeyLabel[] = "adb-label";
37 
38 class TlsConnectionImpl : public TlsConnection {
39   public:
40     explicit TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key,
41                                borrowed_fd fd);
42     ~TlsConnectionImpl() override;
43 
44     bool AddTrustedCertificate(std::string_view cert) override;
45     void SetCertVerifyCallback(CertVerifyCb cb) override;
46     void SetCertificateCallback(SetCertCb cb) override;
47     void SetClientCAList(STACK_OF(X509_NAME) * ca_list) override;
48     std::vector<uint8_t> ExportKeyingMaterial(size_t length) override;
49     void EnableClientPostHandshakeCheck(bool enable) override;
50     TlsError DoHandshake() override;
51     std::vector<uint8_t> ReadFully(size_t size) override;
52     bool ReadFully(void* buf, size_t size) override;
53     bool WriteFully(std::string_view data) override;
54 
55     static bssl::UniquePtr<EVP_PKEY> EvpPkeyFromPEM(std::string_view pem);
56     static bssl::UniquePtr<CRYPTO_BUFFER> BufferFromPEM(std::string_view pem);
57 
58   private:
59     static int SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque);
60     static int SSLSetCertCb(SSL* ssl, void* opaque);
61 
62     static bssl::UniquePtr<X509> X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer);
63     static const char* SSLErrorString();
64     void Invalidate();
65     TlsError GetFailureReason(int err);
RoleToString()66     const char* RoleToString() { return role_ == Role::Server ? kServerRoleStr : kClientRoleStr; }
67 
68     Role role_;
69     bssl::UniquePtr<EVP_PKEY> priv_key_;
70     bssl::UniquePtr<CRYPTO_BUFFER> cert_;
71 
72     bssl::UniquePtr<STACK_OF(X509_NAME)> ca_list_;
73     bssl::UniquePtr<SSL_CTX> ssl_ctx_;
74     bssl::UniquePtr<SSL> ssl_;
75     std::vector<bssl::UniquePtr<X509>> known_certificates_;
76     bool client_verify_post_handshake_ = false;
77 
78     CertVerifyCb cert_verify_cb_;
79     SetCertCb set_cert_cb_;
80     borrowed_fd fd_;
81     static constexpr char kClientRoleStr[] = "[client]: ";
82     static constexpr char kServerRoleStr[] = "[server]: ";
83 };  // TlsConnectionImpl
84 
TlsConnectionImpl(Role role,std::string_view cert,std::string_view priv_key,borrowed_fd fd)85 TlsConnectionImpl::TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key,
86                                      borrowed_fd fd)
87     : role_(role), fd_(fd) {
88     CHECK(!cert.empty() && !priv_key.empty());
89     LOG(INFO) << RoleToString() << "Initializing adbwifi TlsConnection";
90     cert_ = BufferFromPEM(cert);
91     CHECK(cert_);
92     priv_key_ = EvpPkeyFromPEM(priv_key);
93     CHECK(priv_key_);
94 }
95 
~TlsConnectionImpl()96 TlsConnectionImpl::~TlsConnectionImpl() {
97     // shutdown the SSL connection
98     if (ssl_ != nullptr) {
99         SSL_shutdown(ssl_.get());
100     }
101 }
102 
103 // static
SSLErrorString()104 const char* TlsConnectionImpl::SSLErrorString() {
105     auto sslerr = ERR_peek_last_error();
106     return ERR_reason_error_string(sslerr);
107 }
108 
109 // static
EvpPkeyFromPEM(std::string_view pem)110 bssl::UniquePtr<EVP_PKEY> TlsConnectionImpl::EvpPkeyFromPEM(std::string_view pem) {
111     bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size()));
112     return bssl::UniquePtr<EVP_PKEY>(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
113 }
114 
115 // static
BufferFromPEM(std::string_view pem)116 bssl::UniquePtr<CRYPTO_BUFFER> TlsConnectionImpl::BufferFromPEM(std::string_view pem) {
117     bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size()));
118     char* name = nullptr;
119     char* header = nullptr;
120     uint8_t* data = nullptr;
121     long data_len = 0;
122 
123     if (!PEM_read_bio(bio.get(), &name, &header, &data, &data_len)) {
124         LOG(ERROR) << "Failed to read certificate";
125         return nullptr;
126     }
127     OPENSSL_free(name);
128     OPENSSL_free(header);
129 
130     auto ret = bssl::UniquePtr<CRYPTO_BUFFER>(CRYPTO_BUFFER_new(data, data_len, nullptr));
131     OPENSSL_free(data);
132     return ret;
133 }
134 
135 // static
X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer)136 bssl::UniquePtr<X509> TlsConnectionImpl::X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer) {
137     if (!buffer) {
138         return nullptr;
139     }
140     return bssl::UniquePtr<X509>(X509_parse_from_buffer(buffer.get()));
141 }
142 
143 // static
SSLSetCertVerifyCb(X509_STORE_CTX * ctx,void * opaque)144 int TlsConnectionImpl::SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque) {
145     auto* p = reinterpret_cast<TlsConnectionImpl*>(opaque);
146     return p->cert_verify_cb_(ctx);
147 }
148 
149 // static
SSLSetCertCb(SSL * ssl,void * opaque)150 int TlsConnectionImpl::SSLSetCertCb(SSL* ssl, void* opaque) {
151     auto* p = reinterpret_cast<TlsConnectionImpl*>(opaque);
152     return p->set_cert_cb_(ssl);
153 }
154 
AddTrustedCertificate(std::string_view cert)155 bool TlsConnectionImpl::AddTrustedCertificate(std::string_view cert) {
156     // Create X509 buffer from the certificate string
157     auto buf = X509FromBuffer(BufferFromPEM(cert));
158     if (buf == nullptr) {
159         LOG(ERROR) << RoleToString() << "Failed to create a X509 buffer for the certificate.";
160         return false;
161     }
162     known_certificates_.push_back(std::move(buf));
163     return true;
164 }
165 
SetCertVerifyCallback(CertVerifyCb cb)166 void TlsConnectionImpl::SetCertVerifyCallback(CertVerifyCb cb) {
167     cert_verify_cb_ = cb;
168 }
169 
SetCertificateCallback(SetCertCb cb)170 void TlsConnectionImpl::SetCertificateCallback(SetCertCb cb) {
171     set_cert_cb_ = cb;
172 }
173 
SetClientCAList(STACK_OF (X509_NAME)* ca_list)174 void TlsConnectionImpl::SetClientCAList(STACK_OF(X509_NAME) * ca_list) {
175     CHECK(role_ == Role::Server);
176     ca_list_.reset(ca_list != nullptr ? SSL_dup_CA_list(ca_list) : nullptr);
177 }
178 
ExportKeyingMaterial(size_t length)179 std::vector<uint8_t> TlsConnectionImpl::ExportKeyingMaterial(size_t length) {
180     if (ssl_.get() == nullptr) {
181         return {};
182     }
183 
184     std::vector<uint8_t> out(length);
185     if (SSL_export_keying_material(ssl_.get(), out.data(), out.size(), kExportedKeyLabel,
186                                    sizeof(kExportedKeyLabel), nullptr, 0, false) == 0) {
187         return {};
188     }
189     return out;
190 }
191 
EnableClientPostHandshakeCheck(bool enable)192 void TlsConnectionImpl::EnableClientPostHandshakeCheck(bool enable) {
193     client_verify_post_handshake_ = enable;
194 }
195 
GetFailureReason(int err)196 TlsConnection::TlsError TlsConnectionImpl::GetFailureReason(int err) {
197     switch (ERR_GET_REASON(err)) {
198         case SSL_R_SSLV3_ALERT_BAD_CERTIFICATE:
199         case SSL_R_SSLV3_ALERT_UNSUPPORTED_CERTIFICATE:
200         case SSL_R_SSLV3_ALERT_CERTIFICATE_REVOKED:
201         case SSL_R_SSLV3_ALERT_CERTIFICATE_EXPIRED:
202         case SSL_R_SSLV3_ALERT_CERTIFICATE_UNKNOWN:
203         case SSL_R_TLSV1_ALERT_ACCESS_DENIED:
204         case SSL_R_TLSV1_ALERT_UNKNOWN_CA:
205         case SSL_R_TLSV1_CERTIFICATE_REQUIRED:
206             return TlsError::PeerRejectedCertificate;
207         case SSL_R_CERTIFICATE_VERIFY_FAILED:
208             return TlsError::CertificateRejected;
209         default:
210             return TlsError::UnknownFailure;
211     }
212 }
213 
DoHandshake()214 TlsConnection::TlsError TlsConnectionImpl::DoHandshake() {
215     LOG(INFO) << RoleToString() << "Starting adbwifi tls handshake";
216     ssl_ctx_.reset(SSL_CTX_new(TLS_method()));
217     // TODO: Remove set_max_proto_version() once external/boringssl is updated
218     // past
219     // https://boringssl.googlesource.com/boringssl/+/58d56f4c59969a23e5f52014e2651c76fea2f877
220     if (ssl_ctx_.get() == nullptr ||
221         !SSL_CTX_set_min_proto_version(ssl_ctx_.get(), TLS1_3_VERSION) ||
222         !SSL_CTX_set_max_proto_version(ssl_ctx_.get(), TLS1_3_VERSION)) {
223         LOG(ERROR) << RoleToString() << "Failed to create SSL context";
224         return TlsError::UnknownFailure;
225     }
226 
227     // Register user-supplied known certificates
228     for (auto const& cert : known_certificates_) {
229         if (X509_STORE_add_cert(SSL_CTX_get_cert_store(ssl_ctx_.get()), cert.get()) == 0) {
230             LOG(ERROR) << RoleToString() << "Unable to add certificates into the X509_STORE";
231             return TlsError::UnknownFailure;
232         }
233     }
234 
235     // Custom certificate verification
236     if (cert_verify_cb_) {
237         SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), SSLSetCertVerifyCb, this);
238     }
239 
240     // set select certificate callback, if any.
241     if (set_cert_cb_) {
242         SSL_CTX_set_cert_cb(ssl_ctx_.get(), SSLSetCertCb, this);
243     }
244 
245     // Server-allowed client CA list
246     if (ca_list_ != nullptr) {
247         bssl::UniquePtr<STACK_OF(X509_NAME)> names(SSL_dup_CA_list(ca_list_.get()));
248         SSL_CTX_set_client_CA_list(ssl_ctx_.get(), names.release());
249     }
250 
251     // Register our certificate and private key.
252     std::vector<CRYPTO_BUFFER*> cert_chain = {
253             cert_.get(),
254     };
255     if (!SSL_CTX_set_chain_and_key(ssl_ctx_.get(), cert_chain.data(), cert_chain.size(),
256                                    priv_key_.get(), nullptr)) {
257         LOG(ERROR) << RoleToString()
258                    << "Unable to register the certificate chain file and private key ["
259                    << SSLErrorString() << "]";
260         Invalidate();
261         return TlsError::UnknownFailure;
262     }
263 
264     SSL_CTX_set_verify(ssl_ctx_.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
265 
266     // Okay! Let's try to do the handshake!
267     ssl_.reset(SSL_new(ssl_ctx_.get()));
268     if (!SSL_set_fd(ssl_.get(), fd_.get())) {
269         LOG(ERROR) << RoleToString() << "SSL_set_fd failed. [" << SSLErrorString() << "]";
270         return TlsError::UnknownFailure;
271     }
272 
273     switch (role_) {
274         case Role::Server:
275             SSL_set_accept_state(ssl_.get());
276             break;
277         case Role::Client:
278             SSL_set_connect_state(ssl_.get());
279             break;
280     }
281     if (SSL_do_handshake(ssl_.get()) != 1) {
282         LOG(ERROR) << RoleToString() << "Handshake failed in SSL_accept/SSL_connect ["
283                    << SSLErrorString() << "]";
284         auto sslerr = ERR_get_error();
285         Invalidate();
286         return GetFailureReason(sslerr);
287     }
288 
289     if (client_verify_post_handshake_ && role_ == Role::Client) {
290         uint8_t check;
291         // Try to peek one byte for any failures. This assumes on success that
292         // the server actually sends something.
293         if (SSL_peek(ssl_.get(), &check, 1) <= 0) {
294             LOG(ERROR) << RoleToString() << "Post-handshake SSL_peek failed [" << SSLErrorString()
295                        << "]";
296             auto sslerr = ERR_get_error();
297             Invalidate();
298             return GetFailureReason(sslerr);
299         }
300     }
301 
302     LOG(INFO) << RoleToString() << "Handshake succeeded.";
303     return TlsError::Success;
304 }
305 
Invalidate()306 void TlsConnectionImpl::Invalidate() {
307     ssl_.reset();
308     ssl_ctx_.reset();
309 }
310 
ReadFully(size_t size)311 std::vector<uint8_t> TlsConnectionImpl::ReadFully(size_t size) {
312     std::vector<uint8_t> buf(size);
313     if (!ReadFully(buf.data(), buf.size())) {
314         return {};
315     }
316 
317     return buf;
318 }
319 
ReadFully(void * buf,size_t size)320 bool TlsConnectionImpl::ReadFully(void* buf, size_t size) {
321     CHECK_GT(size, 0U);
322     if (!ssl_) {
323         LOG(ERROR) << RoleToString() << "Tried to read on a null SSL connection";
324         return false;
325     }
326 
327     size_t offset = 0;
328     uint8_t* p8 = reinterpret_cast<uint8_t*>(buf);
329     while (size > 0) {
330         int bytes_read =
331                 SSL_read(ssl_.get(), p8 + offset, std::min(static_cast<size_t>(INT_MAX), size));
332         if (bytes_read <= 0) {
333             LOG(ERROR) << RoleToString() << "SSL_read failed [" << SSLErrorString() << "]";
334             return false;
335         }
336         size -= bytes_read;
337         offset += bytes_read;
338     }
339     return true;
340 }
341 
WriteFully(std::string_view data)342 bool TlsConnectionImpl::WriteFully(std::string_view data) {
343     CHECK(!data.empty());
344     if (!ssl_) {
345         LOG(ERROR) << RoleToString() << "Tried to read on a null SSL connection";
346         return false;
347     }
348 
349     while (!data.empty()) {
350         int bytes_out = SSL_write(ssl_.get(), data.data(),
351                                   std::min(static_cast<size_t>(INT_MAX), data.size()));
352         if (bytes_out <= 0) {
353             LOG(ERROR) << RoleToString() << "SSL_write failed [" << SSLErrorString() << "]";
354             return false;
355         }
356         data = data.substr(bytes_out);
357     }
358     return true;
359 }
360 }  // namespace
361 
362 // static
Create(TlsConnection::Role role,std::string_view cert,std::string_view priv_key,borrowed_fd fd)363 std::unique_ptr<TlsConnection> TlsConnection::Create(TlsConnection::Role role,
364                                                      std::string_view cert,
365                                                      std::string_view priv_key, borrowed_fd fd) {
366     CHECK(!cert.empty());
367     CHECK(!priv_key.empty());
368 
369     return std::make_unique<TlsConnectionImpl>(role, cert, priv_key, fd);
370 }
371 
372 // static
SetCertAndKey(SSL * ssl,std::string_view cert,std::string_view priv_key)373 bool TlsConnection::SetCertAndKey(SSL* ssl, std::string_view cert, std::string_view priv_key) {
374     CHECK(ssl);
375     // Note: declaring these in local scope is okay because
376     // SSL_set_chain_and_key will increase the refcount (bssl::UpRef).
377     auto x509_cert = TlsConnectionImpl::BufferFromPEM(cert);
378     auto evp_pkey = TlsConnectionImpl::EvpPkeyFromPEM(priv_key);
379     if (x509_cert == nullptr || evp_pkey == nullptr) {
380         return false;
381     }
382 
383     std::vector<CRYPTO_BUFFER*> cert_chain = {
384             x509_cert.get(),
385     };
386     if (!SSL_set_chain_and_key(ssl, cert_chain.data(), cert_chain.size(), evp_pkey.get(),
387                                nullptr)) {
388         LOG(ERROR) << "SSL_set_chain_and_key failed";
389         return false;
390     }
391 
392     return true;
393 }
394 
395 }  // namespace tls
396 }  // namespace adb
397