1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "adb/pairing/pairing_connection.h" 18 19 #include <stddef.h> 20 #include <stdint.h> 21 22 #include <functional> 23 #include <memory> 24 #include <string_view> 25 #include <thread> 26 #include <vector> 27 28 #include <adb/pairing/pairing_auth.h> 29 #include <adb/tls/tls_connection.h> 30 #include <android-base/endian.h> 31 #include <android-base/logging.h> 32 #include <android-base/macros.h> 33 #include <android-base/unique_fd.h> 34 35 #include "pairing.pb.h" 36 37 using namespace adb; 38 using android::base::unique_fd; 39 using TlsError = tls::TlsConnection::TlsError; 40 41 const uint8_t kCurrentKeyHeaderVersion = 1; 42 const uint8_t kMinSupportedKeyHeaderVersion = 1; 43 const uint8_t kMaxSupportedKeyHeaderVersion = 1; 44 const uint32_t kMaxPayloadSize = kMaxPeerInfoSize * 2; 45 46 struct PairingPacketHeader { 47 uint8_t version; // PairingPacket version 48 uint8_t type; // the type of packet (PairingPacket.Type) 49 uint32_t payload; // Size of the payload in bytes 50 } __attribute__((packed)); 51 52 struct PairingAuthDeleter { 53 void operator()(PairingAuthCtx* p) { pairing_auth_destroy(p); } 54 }; // PairingAuthDeleter 55 using PairingAuthPtr = std::unique_ptr<PairingAuthCtx, PairingAuthDeleter>; 56 57 // PairingConnectionCtx encapsulates the protocol to authenticate two peers with 58 // each other. This class will open the tcp sockets and handle the pairing 59 // process. On completion, both sides will have each other's public key 60 // (certificate) if successful, otherwise, the pairing failed. The tcp port 61 // number is hardcoded (see pairing_connection.cpp). 62 // 63 // Each PairingConnectionCtx instance represents a different device trying to 64 // pair. So for the device, we can have multiple PairingConnectionCtxs while the 65 // host may have only one (unless host has a PairingServer). 66 // 67 // See pairing_connection_test.cpp for example usage. 68 // 69 struct PairingConnectionCtx { 70 public: 71 using Data = std::vector<uint8_t>; 72 using ResultCallback = pairing_result_cb; 73 enum class Role { 74 Client, 75 Server, 76 }; 77 78 explicit PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info, 79 const Data& certificate, const Data& priv_key); 80 virtual ~PairingConnectionCtx(); 81 82 // Starts the pairing connection on a separate thread. 83 // Upon completion, if the pairing was successful, 84 // |cb| will be called with the peer information and certificate. 85 // Otherwise, |cb| will be called with empty data. |fd| should already 86 // be opened. PairingConnectionCtx will take ownership of the |fd|. 87 // 88 // Pairing is successful if both server/client uses the same non-empty 89 // |pswd|, and they are able to exchange the information. |pswd| and 90 // |certificate| must be non-empty. Start() can only be called once in the 91 // lifetime of this object. 92 // 93 // Returns true if the thread was successfully started, false otherwise. 94 bool Start(int fd, ResultCallback cb, void* opaque); 95 96 private: 97 // Setup the tls connection. 98 bool SetupTlsConnection(); 99 100 /************ PairingPacketHeader methods ****************/ 101 // Tries to write out the header and payload. 102 bool WriteHeader(const PairingPacketHeader* header, std::string_view payload); 103 // Tries to parse incoming data into the |header|. Returns true if header 104 // is valid and header version is supported. |header| is filled on success. 105 // |header| may contain garbage if unsuccessful. 106 bool ReadHeader(PairingPacketHeader* header); 107 // Creates a PairingPacketHeader. 108 void CreateHeader(PairingPacketHeader* header, adb::proto::PairingPacket::Type type, 109 uint32_t payload_size); 110 // Checks if actual matches expected. 111 bool CheckHeaderType(adb::proto::PairingPacket::Type expected, uint8_t actual); 112 113 /*********** State related methods **************/ 114 // Handles the State::ExchangingMsgs state. 115 bool DoExchangeMsgs(); 116 // Handles the State::ExchangingPeerInfo state. 117 bool DoExchangePeerInfo(); 118 119 // The background task to do the pairing. 120 void StartWorker(); 121 122 // Calls |cb_| and sets the state to Stopped. 123 void NotifyResult(const PeerInfo* p); 124 125 static PairingAuthPtr CreatePairingAuthPtr(Role role, const Data& pswd); 126 127 enum class State { 128 Ready, 129 ExchangingMsgs, 130 ExchangingPeerInfo, 131 Stopped, 132 }; 133 134 std::atomic<State> state_{State::Ready}; 135 Role role_; 136 Data pswd_; 137 PeerInfo peer_info_; 138 Data cert_; 139 Data priv_key_; 140 141 // Peer's info 142 PeerInfo their_info_; 143 144 ResultCallback cb_; 145 void* opaque_ = nullptr; 146 std::unique_ptr<tls::TlsConnection> tls_; 147 PairingAuthPtr auth_; 148 unique_fd fd_; 149 std::thread thread_; 150 static constexpr size_t kExportedKeySize = 64; 151 }; // PairingConnectionCtx 152 153 PairingConnectionCtx::PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info, 154 const Data& cert, const Data& priv_key) 155 : role_(role), pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key) { 156 CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty()); 157 } 158 159 PairingConnectionCtx::~PairingConnectionCtx() { 160 // Force close the fd and wait for the worker thread to finish. 161 fd_.reset(); 162 if (thread_.joinable()) { 163 thread_.join(); 164 } 165 } 166 167 bool PairingConnectionCtx::SetupTlsConnection() { 168 tls_ = tls::TlsConnection::Create( 169 role_ == Role::Server ? tls::TlsConnection::Role::Server 170 : tls::TlsConnection::Role::Client, 171 std::string_view(reinterpret_cast<const char*>(cert_.data()), cert_.size()), 172 std::string_view(reinterpret_cast<const char*>(priv_key_.data()), priv_key_.size()), 173 fd_); 174 175 if (tls_ == nullptr) { 176 LOG(ERROR) << "Unable to start TlsConnection. Unable to pair fd=" << fd_.get(); 177 return false; 178 } 179 180 // Allow any peer certificate 181 tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; }); 182 183 // SSL doesn't seem to behave correctly with fdevents so just do a blocking 184 // read for the pairing data. 185 if (tls_->DoHandshake() != TlsError::Success) { 186 LOG(ERROR) << "Failed to handshake with the peer fd=" << fd_.get(); 187 return false; 188 } 189 190 // To ensure the connection is not stolen while we do the PAKE, append the 191 // exported key material from the tls connection to the password. 192 std::vector<uint8_t> exportedKeyMaterial = tls_->ExportKeyingMaterial(kExportedKeySize); 193 if (exportedKeyMaterial.empty()) { 194 LOG(ERROR) << "Failed to export key material"; 195 return false; 196 } 197 pswd_.insert(pswd_.end(), std::make_move_iterator(exportedKeyMaterial.begin()), 198 std::make_move_iterator(exportedKeyMaterial.end())); 199 auth_ = CreatePairingAuthPtr(role_, pswd_); 200 201 return true; 202 } 203 204 bool PairingConnectionCtx::WriteHeader(const PairingPacketHeader* header, 205 std::string_view payload) { 206 PairingPacketHeader network_header = *header; 207 network_header.payload = htonl(network_header.payload); 208 if (!tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(&network_header), 209 sizeof(PairingPacketHeader))) || 210 !tls_->WriteFully(payload)) { 211 LOG(ERROR) << "Failed to write out PairingPacketHeader"; 212 state_ = State::Stopped; 213 return false; 214 } 215 return true; 216 } 217 218 bool PairingConnectionCtx::ReadHeader(PairingPacketHeader* header) { 219 auto data = tls_->ReadFully(sizeof(PairingPacketHeader)); 220 if (data.empty()) { 221 return false; 222 } 223 224 uint8_t* p = data.data(); 225 // First byte is always PairingPacketHeader version 226 header->version = *p; 227 ++p; 228 if (header->version < kMinSupportedKeyHeaderVersion || 229 header->version > kMaxSupportedKeyHeaderVersion) { 230 LOG(ERROR) << "PairingPacketHeader version mismatch (us=" << kCurrentKeyHeaderVersion 231 << " them=" << header->version << ")"; 232 return false; 233 } 234 // Next byte is the PairingPacket::Type 235 if (!adb::proto::PairingPacket::Type_IsValid(*p)) { 236 LOG(ERROR) << "Unknown PairingPacket type=" << static_cast<uint32_t>(*p); 237 return false; 238 } 239 header->type = *p; 240 ++p; 241 // Last, the payload size 242 header->payload = ntohl(*(reinterpret_cast<uint32_t*>(p))); 243 if (header->payload == 0 || header->payload > kMaxPayloadSize) { 244 LOG(ERROR) << "header payload not within a safe payload size (size=" << header->payload 245 << ")"; 246 return false; 247 } 248 249 return true; 250 } 251 252 void PairingConnectionCtx::CreateHeader(PairingPacketHeader* header, 253 adb::proto::PairingPacket::Type type, 254 uint32_t payload_size) { 255 header->version = kCurrentKeyHeaderVersion; 256 uint8_t type8 = static_cast<uint8_t>(static_cast<int>(type)); 257 header->type = type8; 258 header->payload = payload_size; 259 } 260 261 bool PairingConnectionCtx::CheckHeaderType(adb::proto::PairingPacket::Type expected_type, 262 uint8_t actual) { 263 uint8_t expected = *reinterpret_cast<uint8_t*>(&expected_type); 264 if (actual != expected) { 265 LOG(ERROR) << "Unexpected header type (expected=" << static_cast<uint32_t>(expected) 266 << " actual=" << static_cast<uint32_t>(actual) << ")"; 267 return false; 268 } 269 return true; 270 } 271 272 void PairingConnectionCtx::NotifyResult(const PeerInfo* p) { 273 cb_(p, fd_.get(), opaque_); 274 state_ = State::Stopped; 275 } 276 277 bool PairingConnectionCtx::Start(int fd, ResultCallback cb, void* opaque) { 278 if (fd < 0) { 279 return false; 280 } 281 fd_.reset(fd); 282 283 State expected = State::Ready; 284 if (!state_.compare_exchange_strong(expected, State::ExchangingMsgs)) { 285 return false; 286 } 287 288 cb_ = cb; 289 opaque_ = opaque; 290 291 thread_ = std::thread([this] { StartWorker(); }); 292 return true; 293 } 294 295 bool PairingConnectionCtx::DoExchangeMsgs() { 296 uint32_t payload = pairing_auth_msg_size(auth_.get()); 297 std::vector<uint8_t> msg(payload); 298 pairing_auth_get_spake2_msg(auth_.get(), msg.data()); 299 300 PairingPacketHeader header; 301 CreateHeader(&header, adb::proto::PairingPacket::SPAKE2_MSG, payload); 302 303 // Write our SPAKE2 msg 304 if (!WriteHeader(&header, 305 std::string_view(reinterpret_cast<const char*>(msg.data()), msg.size()))) { 306 LOG(ERROR) << "Failed to write SPAKE2 msg."; 307 return false; 308 } 309 310 // Read the peer's SPAKE2 msg header 311 if (!ReadHeader(&header)) { 312 LOG(ERROR) << "Invalid PairingPacketHeader."; 313 return false; 314 } 315 if (!CheckHeaderType(adb::proto::PairingPacket::SPAKE2_MSG, header.type)) { 316 return false; 317 } 318 319 // Read the SPAKE2 msg payload and initialize the cipher for 320 // encrypting the PeerInfo and certificate. 321 auto their_msg = tls_->ReadFully(header.payload); 322 if (their_msg.empty() || 323 !pairing_auth_init_cipher(auth_.get(), their_msg.data(), their_msg.size())) { 324 LOG(ERROR) << "Unable to initialize pairing cipher [their_msg.size=" << their_msg.size() 325 << "]"; 326 return false; 327 } 328 329 return true; 330 } 331 332 bool PairingConnectionCtx::DoExchangePeerInfo() { 333 // Encrypt PeerInfo 334 std::vector<uint8_t> buf; 335 uint8_t* p = reinterpret_cast<uint8_t*>(&peer_info_); 336 buf.assign(p, p + sizeof(peer_info_)); 337 std::vector<uint8_t> outbuf(pairing_auth_safe_encrypted_size(auth_.get(), buf.size())); 338 CHECK(!outbuf.empty()); 339 size_t outsize; 340 if (!pairing_auth_encrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) { 341 LOG(ERROR) << "Failed to encrypt peer info"; 342 return false; 343 } 344 outbuf.resize(outsize); 345 346 // Write out the packet header 347 PairingPacketHeader out_header; 348 out_header.version = kCurrentKeyHeaderVersion; 349 out_header.type = static_cast<uint8_t>(static_cast<int>(adb::proto::PairingPacket::PEER_INFO)); 350 out_header.payload = htonl(outbuf.size()); 351 if (!tls_->WriteFully( 352 std::string_view(reinterpret_cast<const char*>(&out_header), sizeof(out_header)))) { 353 LOG(ERROR) << "Unable to write PairingPacketHeader"; 354 return false; 355 } 356 357 // Write out the encrypted payload 358 if (!tls_->WriteFully( 359 std::string_view(reinterpret_cast<const char*>(outbuf.data()), outbuf.size()))) { 360 LOG(ERROR) << "Unable to write encrypted peer info"; 361 return false; 362 } 363 364 // Read in the peer's packet header 365 PairingPacketHeader header; 366 if (!ReadHeader(&header)) { 367 LOG(ERROR) << "Invalid PairingPacketHeader."; 368 return false; 369 } 370 371 if (!CheckHeaderType(adb::proto::PairingPacket::PEER_INFO, header.type)) { 372 return false; 373 } 374 375 // Read in the encrypted peer certificate 376 buf = tls_->ReadFully(header.payload); 377 if (buf.empty()) { 378 return false; 379 } 380 381 // Try to decrypt the certificate 382 outbuf.resize(pairing_auth_safe_decrypted_size(auth_.get(), buf.data(), buf.size())); 383 if (outbuf.empty()) { 384 LOG(ERROR) << "Unsupported payload while decrypting peer info."; 385 return false; 386 } 387 388 if (!pairing_auth_decrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) { 389 LOG(ERROR) << "Failed to decrypt"; 390 return false; 391 } 392 outbuf.resize(outsize); 393 394 // The decrypted message should contain the PeerInfo. 395 if (outbuf.size() != sizeof(PeerInfo)) { 396 LOG(ERROR) << "Got size=" << outbuf.size() << "PeerInfo.size=" << sizeof(PeerInfo); 397 return false; 398 } 399 400 p = outbuf.data(); 401 ::memcpy(&their_info_, p, sizeof(PeerInfo)); 402 p += sizeof(PeerInfo); 403 404 return true; 405 } 406 407 void PairingConnectionCtx::StartWorker() { 408 // Setup the secure transport 409 if (!SetupTlsConnection()) { 410 NotifyResult(nullptr); 411 return; 412 } 413 414 for (;;) { 415 switch (state_) { 416 case State::ExchangingMsgs: 417 if (!DoExchangeMsgs()) { 418 NotifyResult(nullptr); 419 return; 420 } 421 state_ = State::ExchangingPeerInfo; 422 break; 423 case State::ExchangingPeerInfo: 424 if (!DoExchangePeerInfo()) { 425 NotifyResult(nullptr); 426 return; 427 } 428 NotifyResult(&their_info_); 429 return; 430 case State::Ready: 431 case State::Stopped: 432 LOG(FATAL) << __func__ << ": Got invalid state"; 433 return; 434 } 435 } 436 } 437 438 // static 439 PairingAuthPtr PairingConnectionCtx::CreatePairingAuthPtr(Role role, const Data& pswd) { 440 switch (role) { 441 case Role::Client: 442 return PairingAuthPtr(pairing_auth_client_new(pswd.data(), pswd.size())); 443 break; 444 case Role::Server: 445 return PairingAuthPtr(pairing_auth_server_new(pswd.data(), pswd.size())); 446 break; 447 } 448 } 449 450 static PairingConnectionCtx* CreateConnection(PairingConnectionCtx::Role role, const uint8_t* pswd, 451 size_t pswd_len, const PeerInfo* peer_info, 452 const uint8_t* x509_cert_pem, size_t x509_size, 453 const uint8_t* priv_key_pem, size_t priv_size) { 454 CHECK(pswd); 455 CHECK_GT(pswd_len, 0U); 456 CHECK(x509_cert_pem); 457 CHECK_GT(x509_size, 0U); 458 CHECK(priv_key_pem); 459 CHECK_GT(priv_size, 0U); 460 CHECK(peer_info); 461 std::vector<uint8_t> vec_pswd(pswd, pswd + pswd_len); 462 std::vector<uint8_t> vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size); 463 std::vector<uint8_t> vec_priv_key(priv_key_pem, priv_key_pem + priv_size); 464 return new PairingConnectionCtx(role, vec_pswd, *peer_info, vec_x509_cert, vec_priv_key); 465 } 466 467 PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len, 468 const PeerInfo* peer_info, 469 const uint8_t* x509_cert_pem, size_t x509_size, 470 const uint8_t* priv_key_pem, size_t priv_size) { 471 return CreateConnection(PairingConnectionCtx::Role::Client, pswd, pswd_len, peer_info, 472 x509_cert_pem, x509_size, priv_key_pem, priv_size); 473 } 474 475 PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len, 476 const PeerInfo* peer_info, 477 const uint8_t* x509_cert_pem, size_t x509_size, 478 const uint8_t* priv_key_pem, size_t priv_size) { 479 return CreateConnection(PairingConnectionCtx::Role::Server, pswd, pswd_len, peer_info, 480 x509_cert_pem, x509_size, priv_key_pem, priv_size); 481 } 482 483 void pairing_connection_destroy(PairingConnectionCtx* ctx) { 484 CHECK(ctx); 485 delete ctx; 486 } 487 488 bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb, 489 void* opaque) { 490 return ctx->Start(fd, cb, opaque); 491 } 492