/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "adb/tls/tls_connection.h"

#include <algorithm>
#include <climits>
#include <cstdint>
#include <memory>
#include <string_view>
#include <utility>
#include <vector>

#include <android-base/logging.h>
#include <android-base/strings.h>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>

using android::base::borrowed_fd;

namespace adb {
namespace tls {

namespace {

// Label passed to SSL_export_keying_material() by ExportKeyingMaterial(). That call passes
// sizeof(kExportedKeyLabel), so the trailing NUL byte is part of the label bytes.
static constexpr char kExportedKeyLabel[] = "adb-label";

// TlsConnection implementation built on BoringSSL. The connection runs over |fd|, which is
// borrowed (not owned) by this object.
class TlsConnectionImpl : public TlsConnection {
  public:
    explicit TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key,
                               borrowed_fd fd);
    ~TlsConnectionImpl() override;

    bool AddTrustedCertificate(std::string_view cert) override;
    void SetCertVerifyCallback(CertVerifyCb cb) override;
    void SetCertificateCallback(SetCertCb cb) override;
    void SetClientCAList(STACK_OF(X509_NAME) * ca_list) override;
    std::vector<uint8_t> ExportKeyingMaterial(size_t length) override;
    void EnableClientPostHandshakeCheck(bool enable) override;
    TlsError DoHandshake() override;
    std::vector<uint8_t> ReadFully(size_t size) override;
    bool ReadFully(void* buf, size_t size) override;
    bool WriteFully(std::string_view data) override;

    static bssl::UniquePtr<EVP_PKEY> EvpPkeyFromPEM(std::string_view pem);
    static bssl::UniquePtr<CRYPTO_BUFFER> BufferFromPEM(std::string_view pem);

  private:
    static int SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque);
    static int SSLSetCertCb(SSL* ssl, void* opaque);

    static bssl::UniquePtr<X509> X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer);
    static const char* SSLErrorString();
    void Invalidate();
    TlsError GetFailureReason(uint32_t err);
    const char* RoleToString() const {
        return role_ == Role::Server ? kServerRoleStr : kClientRoleStr;
    }

    Role role_;
    bssl::UniquePtr<EVP_PKEY> priv_key_;
    bssl::UniquePtr<CRYPTO_BUFFER> cert_;

    bssl::UniquePtr<STACK_OF(X509_NAME)> ca_list_;
    bssl::UniquePtr<SSL_CTX> ssl_ctx_;
    bssl::UniquePtr<SSL> ssl_;
    std::vector<bssl::UniquePtr<X509>> known_certificates_;
    bool client_verify_post_handshake_ = false;

    CertVerifyCb cert_verify_cb_;
    SetCertCb set_cert_cb_;
    borrowed_fd fd_;
    static constexpr char kClientRoleStr[] = "[client]: ";
    static constexpr char kServerRoleStr[] = "[server]: ";
};  // TlsConnectionImpl

// Parses our own certificate and private key up front; both must be valid PEM (CHECK-fails
// otherwise).
TlsConnectionImpl::TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key,
                                     borrowed_fd fd)
    : role_(role), fd_(fd) {
    CHECK(!cert.empty() && !priv_key.empty());
    LOG(INFO) << RoleToString() << "Initializing adbwifi TlsConnection";
    cert_ = BufferFromPEM(cert);
    CHECK(cert_);
    priv_key_ = EvpPkeyFromPEM(priv_key);
    CHECK(priv_key_);
}

TlsConnectionImpl::~TlsConnectionImpl() {
    // shutdown the SSL connection
    if (ssl_ != nullptr) {
        SSL_shutdown(ssl_.get());
    }
}

// static
// Returns a human-readable reason for the most recent error on this thread's error queue.
// Never returns null: ERR_reason_error_string() can yield NULL for codes it has no string
// for, and streaming a null const char* into a log stream is undefined behavior.
const char* TlsConnectionImpl::SSLErrorString() {
    const char* reason = ERR_reason_error_string(ERR_peek_last_error());
    return reason != nullptr ? reason : "unknown error";
}

// static
// Parses a PEM-encoded private key. Returns null on failure.
bssl::UniquePtr<EVP_PKEY> TlsConnectionImpl::EvpPkeyFromPEM(std::string_view pem) {
    bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size()));
    if (!bio) {
        return nullptr;
    }
    return bssl::UniquePtr<EVP_PKEY>(
            PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
}

// static
// Decodes the first PEM block in |pem| into a CRYPTO_BUFFER holding its DER bytes.
// Returns null on failure. The PEM block type (name) is not checked.
bssl::UniquePtr<CRYPTO_BUFFER> TlsConnectionImpl::BufferFromPEM(std::string_view pem) {
    bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size()));
    if (!bio) {
        LOG(ERROR) << "Failed to read certificate";
        return nullptr;
    }
    char* name = nullptr;
    char* header = nullptr;
    uint8_t* data = nullptr;
    long data_len = 0;

    if (!PEM_read_bio(bio.get(), &name, &header, &data, &data_len)) {
        LOG(ERROR) << "Failed to read certificate";
        return nullptr;
    }
    OPENSSL_free(name);
    OPENSSL_free(header);

    // CRYPTO_BUFFER_new copies |data|, so the PEM-allocated buffer can be freed right after.
    auto ret = bssl::UniquePtr<CRYPTO_BUFFER>(CRYPTO_BUFFER_new(data, data_len, nullptr));
    OPENSSL_free(data);
    return ret;
}

// static
// Parses a DER CRYPTO_BUFFER into an X509. Returns null if |buffer| is null or unparsable.
bssl::UniquePtr<X509> TlsConnectionImpl::X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer) {
    if (!buffer) {
        return nullptr;
    }
    return bssl::UniquePtr<X509>(X509_parse_from_buffer(buffer.get()));
}

// static
// Trampoline registered with SSL_CTX_set_cert_verify_callback; |opaque| is |this|.
int TlsConnectionImpl::SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque) {
    auto* p = static_cast<TlsConnectionImpl*>(opaque);
    return p->cert_verify_cb_(ctx);
}

// static
// Trampoline registered with SSL_CTX_set_cert_cb; |opaque| is |this|.
int TlsConnectionImpl::SSLSetCertCb(SSL* ssl, void* opaque) {
    auto* p = static_cast<TlsConnectionImpl*>(opaque);
    return p->set_cert_cb_(ssl);
}

// Queues a PEM certificate to be added to the trust store at the next DoHandshake().
bool TlsConnectionImpl::AddTrustedCertificate(std::string_view cert) {
    // Create X509 buffer from the certificate string
    auto buf = X509FromBuffer(BufferFromPEM(cert));
    if (buf == nullptr) {
        LOG(ERROR) << RoleToString() << "Failed to create a X509 buffer for the certificate.";
        return false;
    }
    known_certificates_.push_back(std::move(buf));
    return true;
}

void TlsConnectionImpl::SetCertVerifyCallback(CertVerifyCb cb) {
    cert_verify_cb_ = std::move(cb);
}

void TlsConnectionImpl::SetCertificateCallback(SetCertCb cb) {
    set_cert_cb_ = std::move(cb);
}

// Server only: stores a private copy of |ca_list| (the caller keeps ownership of its own).
// Passing null clears the list.
void TlsConnectionImpl::SetClientCAList(STACK_OF(X509_NAME) * ca_list) {
    CHECK(role_ == Role::Server);
    ca_list_.reset(ca_list != nullptr ? SSL_dup_CA_list(ca_list) : nullptr);
}

// Exports |length| bytes of keying material from the established session. Returns an empty
// vector if there is no connection or the export fails.
std::vector<uint8_t> TlsConnectionImpl::ExportKeyingMaterial(size_t length) {
    if (ssl_.get() == nullptr) {
        return {};
    }

    std::vector<uint8_t> out(length);
    // sizeof() counts the label's trailing NUL. The label bytes are an input to the export,
    // so both peers must use the same length; do not switch this to strlen().
    if (SSL_export_keying_material(ssl_.get(), out.data(), out.size(), kExportedKeyLabel,
                                   sizeof(kExportedKeyLabel), nullptr, 0, false) == 0) {
        return {};
    }
    return out;
}

void TlsConnectionImpl::EnableClientPostHandshakeCheck(bool enable) {
    client_verify_post_handshake_ = enable;
}

// Maps a packed BoringSSL error code to the TlsError reported to callers.
TlsConnection::TlsError TlsConnectionImpl::GetFailureReason(uint32_t err) {
    switch (ERR_GET_REASON(err)) {
        case SSL_R_SSLV3_ALERT_BAD_CERTIFICATE:
        case SSL_R_SSLV3_ALERT_UNSUPPORTED_CERTIFICATE:
        case SSL_R_SSLV3_ALERT_CERTIFICATE_REVOKED:
        case SSL_R_SSLV3_ALERT_CERTIFICATE_EXPIRED:
        case SSL_R_SSLV3_ALERT_CERTIFICATE_UNKNOWN:
        case SSL_R_TLSV1_ALERT_ACCESS_DENIED:
        case SSL_R_TLSV1_ALERT_UNKNOWN_CA:
        case SSL_R_TLSV1_CERTIFICATE_REQUIRED:
            return TlsError::PeerRejectedCertificate;
        case SSL_R_CERTIFICATE_VERIFY_FAILED:
            return TlsError::CertificateRejected;
        default:
            return TlsError::UnknownFailure;
    }
}

// Builds a fresh TLS 1.3-only context, then performs the handshake on |fd_| in the role
// given at construction. Every failure path calls Invalidate(), so a failed handshake never
// leaves a half-initialized |ssl_| for ReadFully()/WriteFully() to use.
TlsConnection::TlsError TlsConnectionImpl::DoHandshake() {
    LOG(INFO) << RoleToString() << "Starting adbwifi tls handshake";
    ssl_ctx_.reset(SSL_CTX_new(TLS_method()));
    // TODO: Remove set_max_proto_version() once external/boringssl is updated
    // past
    // https://boringssl.googlesource.com/boringssl/+/58d56f4c59969a23e5f52014e2651c76fea2f877
    if (ssl_ctx_.get() == nullptr ||
        !SSL_CTX_set_min_proto_version(ssl_ctx_.get(), TLS1_3_VERSION) ||
        !SSL_CTX_set_max_proto_version(ssl_ctx_.get(), TLS1_3_VERSION)) {
        LOG(ERROR) << RoleToString() << "Failed to create SSL context";
        Invalidate();
        return TlsError::UnknownFailure;
    }

    // Register user-supplied known certificates
    X509_STORE* store = SSL_CTX_get_cert_store(ssl_ctx_.get());
    for (const auto& cert : known_certificates_) {
        if (X509_STORE_add_cert(store, cert.get()) == 0) {
            LOG(ERROR) << RoleToString() << "Unable to add certificates into the X509_STORE";
            Invalidate();
            return TlsError::UnknownFailure;
        }
    }

    // Custom certificate verification
    if (cert_verify_cb_) {
        SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), SSLSetCertVerifyCb, this);
    }

    // set select certificate callback, if any.
    if (set_cert_cb_) {
        SSL_CTX_set_cert_cb(ssl_ctx_.get(), SSLSetCertCb, this);
    }

    // Server-allowed client CA list. The context takes ownership of a copy, so |ca_list_|
    // stays usable for later handshakes.
    if (ca_list_ != nullptr) {
        bssl::UniquePtr<STACK_OF(X509_NAME)> names(SSL_dup_CA_list(ca_list_.get()));
        SSL_CTX_set_client_CA_list(ssl_ctx_.get(), names.release());
    }

    // Register our certificate and private key.
    std::vector<CRYPTO_BUFFER*> cert_chain = {
            cert_.get(),
    };
    if (!SSL_CTX_set_chain_and_key(ssl_ctx_.get(), cert_chain.data(), cert_chain.size(),
                                   priv_key_.get(), nullptr)) {
        LOG(ERROR) << RoleToString()
                   << "Unable to register the certificate chain file and private key ["
                   << SSLErrorString() << "]";
        Invalidate();
        return TlsError::UnknownFailure;
    }

    SSL_CTX_set_verify(ssl_ctx_.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);

    // Okay! Let's try to do the handshake!
    ssl_.reset(SSL_new(ssl_ctx_.get()));
    if (ssl_ == nullptr) {
        LOG(ERROR) << RoleToString() << "SSL_new failed. [" << SSLErrorString() << "]";
        Invalidate();
        return TlsError::UnknownFailure;
    }
    if (!SSL_set_fd(ssl_.get(), fd_.get())) {
        LOG(ERROR) << RoleToString() << "SSL_set_fd failed. [" << SSLErrorString() << "]";
        Invalidate();
        return TlsError::UnknownFailure;
    }

    switch (role_) {
        case Role::Server:
            SSL_set_accept_state(ssl_.get());
            break;
        case Role::Client:
            SSL_set_connect_state(ssl_.get());
            break;
    }

    // ERR_get_error() below returns the *oldest* queued error. Start from an empty queue so
    // a stale entry left by an earlier call on this thread (e.g. a failed PEM parse in
    // AddTrustedCertificate()) cannot be mistaken for the handshake's failure reason.
    ERR_clear_error();
    if (SSL_do_handshake(ssl_.get()) != 1) {
        LOG(ERROR) << RoleToString() << "Handshake failed in SSL_accept/SSL_connect ["
                   << SSLErrorString() << "]";
        auto sslerr = ERR_get_error();
        Invalidate();
        return GetFailureReason(sslerr);
    }

    if (client_verify_post_handshake_ && role_ == Role::Client) {
        uint8_t check;
        // Try to peek one byte for any failures. This assumes on success that
        // the server actually sends something.
        ERR_clear_error();
        if (SSL_peek(ssl_.get(), &check, 1) <= 0) {
            LOG(ERROR) << RoleToString() << "Post-handshake SSL_peek failed [" << SSLErrorString()
                       << "]";
            auto sslerr = ERR_get_error();
            Invalidate();
            return GetFailureReason(sslerr);
        }
    }

    LOG(INFO) << RoleToString() << "Handshake succeeded.";
    return TlsError::Success;
}

// Drops the SSL connection and its context.
void TlsConnectionImpl::Invalidate() {
    ssl_.reset();
    ssl_ctx_.reset();
}

// Reads exactly |size| bytes; returns an empty vector on any failure.
std::vector<uint8_t> TlsConnectionImpl::ReadFully(size_t size) {
    std::vector<uint8_t> buf(size);
    if (!ReadFully(buf.data(), buf.size())) {
        return {};
    }

    return buf;
}

// Reads exactly |size| (> 0) bytes into |buf|, looping over partial SSL_read() results.
// Returns false if there is no connection or any SSL_read() call returns <= 0.
bool TlsConnectionImpl::ReadFully(void* buf, size_t size) {
    CHECK_GT(size, 0U);
    if (!ssl_) {
        LOG(ERROR) << RoleToString() << "Tried to read on a null SSL connection";
        return false;
    }

    size_t offset = 0;
    auto* p8 = static_cast<uint8_t*>(buf);
    while (size > 0) {
        // SSL_read() takes an int length, so cap each request at INT_MAX.
        const int chunk = static_cast<int>(std::min(static_cast<size_t>(INT_MAX), size));
        int bytes_read = SSL_read(ssl_.get(), p8 + offset, chunk);
        if (bytes_read <= 0) {
            LOG(ERROR) << RoleToString() << "SSL_read failed [" << SSLErrorString() << "]";
            return false;
        }
        size -= bytes_read;
        offset += bytes_read;
    }
    return true;
}

// Writes all of |data| (non-empty), looping over partial SSL_write() results.
// Returns false if there is no connection or any SSL_write() call returns <= 0.
bool TlsConnectionImpl::WriteFully(std::string_view data) {
    CHECK(!data.empty());
    if (!ssl_) {
        LOG(ERROR) << RoleToString() << "Tried to write on a null SSL connection";
        return false;
    }

    while (!data.empty()) {
        // SSL_write() takes an int length, so cap each request at INT_MAX.
        const int chunk = static_cast<int>(std::min(static_cast<size_t>(INT_MAX), data.size()));
        int bytes_out = SSL_write(ssl_.get(), data.data(), chunk);
        if (bytes_out <= 0) {
            LOG(ERROR) << RoleToString() << "SSL_write failed [" << SSLErrorString() << "]";
            return false;
        }
        data = data.substr(bytes_out);
    }
    return true;
}

}  // namespace

// static
std::unique_ptr<TlsConnection> TlsConnection::Create(TlsConnection::Role role,
                                                     std::string_view cert,
                                                     std::string_view priv_key, borrowed_fd fd) {
    CHECK(!cert.empty());
    CHECK(!priv_key.empty());

    return std::make_unique<TlsConnectionImpl>(role, cert, priv_key, fd);
}

// static
// Installs a PEM certificate and private key on |ssl|. Returns false if either fails to
// parse or BoringSSL rejects them.
bool TlsConnection::SetCertAndKey(SSL* ssl, std::string_view cert, std::string_view priv_key) {
    CHECK(ssl);
    // Note: declaring these in local scope is okay because
    // SSL_set_chain_and_key will increase the refcount (bssl::UpRef).
    auto x509_cert = TlsConnectionImpl::BufferFromPEM(cert);
    auto evp_pkey = TlsConnectionImpl::EvpPkeyFromPEM(priv_key);
    if (x509_cert == nullptr || evp_pkey == nullptr) {
        return false;
    }

    std::vector<CRYPTO_BUFFER*> cert_chain = {
            x509_cert.get(),
    };
    if (!SSL_set_chain_and_key(ssl, cert_chain.data(), cert_chain.size(), evp_pkey.get(),
                               nullptr)) {
        LOG(ERROR) << "SSL_set_chain_and_key failed";
        return false;
    }

    return true;
}

}  // namespace tls
}  // namespace adb