1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "adb/pairing/pairing_connection.h"
18 
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <atomic>
#include <functional>
#include <memory>
#include <string_view>
#include <thread>
#include <vector>
27 
28 #include <adb/pairing/pairing_auth.h>
29 #include <adb/tls/tls_connection.h>
30 #include <android-base/endian.h>
31 #include <android-base/logging.h>
32 #include <android-base/macros.h>
33 #include <android-base/unique_fd.h>
34 
35 #include "pairing.pb.h"
36 
37 using namespace adb;
38 using android::base::unique_fd;
39 using TlsError = tls::TlsConnection::TlsError;
40 
41 const uint8_t kCurrentKeyHeaderVersion = 1;
42 const uint8_t kMinSupportedKeyHeaderVersion = 1;
43 const uint8_t kMaxSupportedKeyHeaderVersion = 1;
44 const uint32_t kMaxPayloadSize = kMaxPeerInfoSize * 2;
45 
// Fixed-size header that precedes every pairing packet. The struct is packed
// because it is written to the connection byte-for-byte (see WriteHeader()).
// |payload| is held in host byte order in memory and converted with
// htonl()/ntohl() when the header is written/read.
struct PairingPacketHeader {
    uint8_t version;   // PairingPacket version
    uint8_t type;      // the type of packet (PairingPacket.Type)
    uint32_t payload;  // Size of the payload in bytes
} __attribute__((packed));
// The wire format is 1 + 1 + 4 bytes; fail the build if padding ever creeps in.
static_assert(sizeof(PairingPacketHeader) == 6, "PairingPacketHeader must be 6 bytes on the wire");
51 
52 struct PairingAuthDeleter {
53     void operator()(PairingAuthCtx* p) { pairing_auth_destroy(p); }
54 };  // PairingAuthDeleter
55 using PairingAuthPtr = std::unique_ptr<PairingAuthCtx, PairingAuthDeleter>;
56 
57 // PairingConnectionCtx encapsulates the protocol to authenticate two peers with
58 // each other. This class will open the tcp sockets and handle the pairing
59 // process. On completion, both sides will have each other's public key
60 // (certificate) if successful, otherwise, the pairing failed. The tcp port
61 // number is hardcoded (see pairing_connection.cpp).
62 //
63 // Each PairingConnectionCtx instance represents a different device trying to
64 // pair. So for the device, we can have multiple PairingConnectionCtxs while the
65 // host may have only one (unless host has a PairingServer).
66 //
67 // See pairing_connection_test.cpp for example usage.
68 //
69 struct PairingConnectionCtx {
70   public:
71     using Data = std::vector<uint8_t>;
72     using ResultCallback = pairing_result_cb;
73     enum class Role {
74         Client,
75         Server,
76     };
77 
78     explicit PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
79                                   const Data& certificate, const Data& priv_key);
80     virtual ~PairingConnectionCtx();
81 
82     // Starts the pairing connection on a separate thread.
83     // Upon completion, if the pairing was successful,
84     // |cb| will be called with the peer information and certificate.
85     // Otherwise, |cb| will be called with empty data. |fd| should already
86     // be opened. PairingConnectionCtx will take ownership of the |fd|.
87     //
88     // Pairing is successful if both server/client uses the same non-empty
89     // |pswd|, and they are able to exchange the information. |pswd| and
90     // |certificate| must be non-empty. Start() can only be called once in the
91     // lifetime of this object.
92     //
93     // Returns true if the thread was successfully started, false otherwise.
94     bool Start(int fd, ResultCallback cb, void* opaque);
95 
96   private:
97     // Setup the tls connection.
98     bool SetupTlsConnection();
99 
100     /************ PairingPacketHeader methods ****************/
101     // Tries to write out the header and payload.
102     bool WriteHeader(const PairingPacketHeader* header, std::string_view payload);
103     // Tries to parse incoming data into the |header|. Returns true if header
104     // is valid and header version is supported. |header| is filled on success.
105     // |header| may contain garbage if unsuccessful.
106     bool ReadHeader(PairingPacketHeader* header);
107     // Creates a PairingPacketHeader.
108     void CreateHeader(PairingPacketHeader* header, adb::proto::PairingPacket::Type type,
109                       uint32_t payload_size);
110     // Checks if actual matches expected.
111     bool CheckHeaderType(adb::proto::PairingPacket::Type expected, uint8_t actual);
112 
113     /*********** State related methods **************/
114     // Handles the State::ExchangingMsgs state.
115     bool DoExchangeMsgs();
116     // Handles the State::ExchangingPeerInfo state.
117     bool DoExchangePeerInfo();
118 
119     // The background task to do the pairing.
120     void StartWorker();
121 
122     // Calls |cb_| and sets the state to Stopped.
123     void NotifyResult(const PeerInfo* p);
124 
125     static PairingAuthPtr CreatePairingAuthPtr(Role role, const Data& pswd);
126 
127     enum class State {
128         Ready,
129         ExchangingMsgs,
130         ExchangingPeerInfo,
131         Stopped,
132     };
133 
134     std::atomic<State> state_{State::Ready};
135     Role role_;
136     Data pswd_;
137     PeerInfo peer_info_;
138     Data cert_;
139     Data priv_key_;
140 
141     // Peer's info
142     PeerInfo their_info_;
143 
144     ResultCallback cb_;
145     void* opaque_ = nullptr;
146     std::unique_ptr<tls::TlsConnection> tls_;
147     PairingAuthPtr auth_;
148     unique_fd fd_;
149     std::thread thread_;
150     static constexpr size_t kExportedKeySize = 64;
151 };  // PairingConnectionCtx
152 
153 PairingConnectionCtx::PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
154                                            const Data& cert, const Data& priv_key)
155     : role_(role), pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key) {
156     CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty());
157 }
158 
159 PairingConnectionCtx::~PairingConnectionCtx() {
160     // Force close the fd and wait for the worker thread to finish.
161     fd_.reset();
162     if (thread_.joinable()) {
163         thread_.join();
164     }
165 }
166 
167 bool PairingConnectionCtx::SetupTlsConnection() {
168     tls_ = tls::TlsConnection::Create(
169             role_ == Role::Server ? tls::TlsConnection::Role::Server
170                                   : tls::TlsConnection::Role::Client,
171             std::string_view(reinterpret_cast<const char*>(cert_.data()), cert_.size()),
172             std::string_view(reinterpret_cast<const char*>(priv_key_.data()), priv_key_.size()),
173             fd_);
174 
175     if (tls_ == nullptr) {
176         LOG(ERROR) << "Unable to start TlsConnection. Unable to pair fd=" << fd_.get();
177         return false;
178     }
179 
180     // Allow any peer certificate
181     tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
182 
183     // SSL doesn't seem to behave correctly with fdevents so just do a blocking
184     // read for the pairing data.
185     if (tls_->DoHandshake() != TlsError::Success) {
186         LOG(ERROR) << "Failed to handshake with the peer fd=" << fd_.get();
187         return false;
188     }
189 
190     // To ensure the connection is not stolen while we do the PAKE, append the
191     // exported key material from the tls connection to the password.
192     std::vector<uint8_t> exportedKeyMaterial = tls_->ExportKeyingMaterial(kExportedKeySize);
193     if (exportedKeyMaterial.empty()) {
194         LOG(ERROR) << "Failed to export key material";
195         return false;
196     }
197     pswd_.insert(pswd_.end(), std::make_move_iterator(exportedKeyMaterial.begin()),
198                  std::make_move_iterator(exportedKeyMaterial.end()));
199     auth_ = CreatePairingAuthPtr(role_, pswd_);
200 
201     return true;
202 }
203 
204 bool PairingConnectionCtx::WriteHeader(const PairingPacketHeader* header,
205                                        std::string_view payload) {
206     PairingPacketHeader network_header = *header;
207     network_header.payload = htonl(network_header.payload);
208     if (!tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(&network_header),
209                                            sizeof(PairingPacketHeader))) ||
210         !tls_->WriteFully(payload)) {
211         LOG(ERROR) << "Failed to write out PairingPacketHeader";
212         state_ = State::Stopped;
213         return false;
214     }
215     return true;
216 }
217 
218 bool PairingConnectionCtx::ReadHeader(PairingPacketHeader* header) {
219     auto data = tls_->ReadFully(sizeof(PairingPacketHeader));
220     if (data.empty()) {
221         return false;
222     }
223 
224     uint8_t* p = data.data();
225     // First byte is always PairingPacketHeader version
226     header->version = *p;
227     ++p;
228     if (header->version < kMinSupportedKeyHeaderVersion ||
229         header->version > kMaxSupportedKeyHeaderVersion) {
230         LOG(ERROR) << "PairingPacketHeader version mismatch (us=" << kCurrentKeyHeaderVersion
231                    << " them=" << header->version << ")";
232         return false;
233     }
234     // Next byte is the PairingPacket::Type
235     if (!adb::proto::PairingPacket::Type_IsValid(*p)) {
236         LOG(ERROR) << "Unknown PairingPacket type=" << static_cast<uint32_t>(*p);
237         return false;
238     }
239     header->type = *p;
240     ++p;
241     // Last, the payload size
242     header->payload = ntohl(*(reinterpret_cast<uint32_t*>(p)));
243     if (header->payload == 0 || header->payload > kMaxPayloadSize) {
244         LOG(ERROR) << "header payload not within a safe payload size (size=" << header->payload
245                    << ")";
246         return false;
247     }
248 
249     return true;
250 }
251 
252 void PairingConnectionCtx::CreateHeader(PairingPacketHeader* header,
253                                         adb::proto::PairingPacket::Type type,
254                                         uint32_t payload_size) {
255     header->version = kCurrentKeyHeaderVersion;
256     uint8_t type8 = static_cast<uint8_t>(static_cast<int>(type));
257     header->type = type8;
258     header->payload = payload_size;
259 }
260 
261 bool PairingConnectionCtx::CheckHeaderType(adb::proto::PairingPacket::Type expected_type,
262                                            uint8_t actual) {
263     uint8_t expected = *reinterpret_cast<uint8_t*>(&expected_type);
264     if (actual != expected) {
265         LOG(ERROR) << "Unexpected header type (expected=" << static_cast<uint32_t>(expected)
266                    << " actual=" << static_cast<uint32_t>(actual) << ")";
267         return false;
268     }
269     return true;
270 }
271 
272 void PairingConnectionCtx::NotifyResult(const PeerInfo* p) {
273     cb_(p, fd_.get(), opaque_);
274     state_ = State::Stopped;
275 }
276 
277 bool PairingConnectionCtx::Start(int fd, ResultCallback cb, void* opaque) {
278     if (fd < 0) {
279         return false;
280     }
281     fd_.reset(fd);
282 
283     State expected = State::Ready;
284     if (!state_.compare_exchange_strong(expected, State::ExchangingMsgs)) {
285         return false;
286     }
287 
288     cb_ = cb;
289     opaque_ = opaque;
290 
291     thread_ = std::thread([this] { StartWorker(); });
292     return true;
293 }
294 
295 bool PairingConnectionCtx::DoExchangeMsgs() {
296     uint32_t payload = pairing_auth_msg_size(auth_.get());
297     std::vector<uint8_t> msg(payload);
298     pairing_auth_get_spake2_msg(auth_.get(), msg.data());
299 
300     PairingPacketHeader header;
301     CreateHeader(&header, adb::proto::PairingPacket::SPAKE2_MSG, payload);
302 
303     // Write our SPAKE2 msg
304     if (!WriteHeader(&header,
305                      std::string_view(reinterpret_cast<const char*>(msg.data()), msg.size()))) {
306         LOG(ERROR) << "Failed to write SPAKE2 msg.";
307         return false;
308     }
309 
310     // Read the peer's SPAKE2 msg header
311     if (!ReadHeader(&header)) {
312         LOG(ERROR) << "Invalid PairingPacketHeader.";
313         return false;
314     }
315     if (!CheckHeaderType(adb::proto::PairingPacket::SPAKE2_MSG, header.type)) {
316         return false;
317     }
318 
319     // Read the SPAKE2 msg payload and initialize the cipher for
320     // encrypting the PeerInfo and certificate.
321     auto their_msg = tls_->ReadFully(header.payload);
322     if (their_msg.empty() ||
323         !pairing_auth_init_cipher(auth_.get(), their_msg.data(), their_msg.size())) {
324         LOG(ERROR) << "Unable to initialize pairing cipher [their_msg.size=" << their_msg.size()
325                    << "]";
326         return false;
327     }
328 
329     return true;
330 }
331 
332 bool PairingConnectionCtx::DoExchangePeerInfo() {
333     // Encrypt PeerInfo
334     std::vector<uint8_t> buf;
335     uint8_t* p = reinterpret_cast<uint8_t*>(&peer_info_);
336     buf.assign(p, p + sizeof(peer_info_));
337     std::vector<uint8_t> outbuf(pairing_auth_safe_encrypted_size(auth_.get(), buf.size()));
338     CHECK(!outbuf.empty());
339     size_t outsize;
340     if (!pairing_auth_encrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
341         LOG(ERROR) << "Failed to encrypt peer info";
342         return false;
343     }
344     outbuf.resize(outsize);
345 
346     // Write out the packet header
347     PairingPacketHeader out_header;
348     out_header.version = kCurrentKeyHeaderVersion;
349     out_header.type = static_cast<uint8_t>(static_cast<int>(adb::proto::PairingPacket::PEER_INFO));
350     out_header.payload = htonl(outbuf.size());
351     if (!tls_->WriteFully(
352                 std::string_view(reinterpret_cast<const char*>(&out_header), sizeof(out_header)))) {
353         LOG(ERROR) << "Unable to write PairingPacketHeader";
354         return false;
355     }
356 
357     // Write out the encrypted payload
358     if (!tls_->WriteFully(
359                 std::string_view(reinterpret_cast<const char*>(outbuf.data()), outbuf.size()))) {
360         LOG(ERROR) << "Unable to write encrypted peer info";
361         return false;
362     }
363 
364     // Read in the peer's packet header
365     PairingPacketHeader header;
366     if (!ReadHeader(&header)) {
367         LOG(ERROR) << "Invalid PairingPacketHeader.";
368         return false;
369     }
370 
371     if (!CheckHeaderType(adb::proto::PairingPacket::PEER_INFO, header.type)) {
372         return false;
373     }
374 
375     // Read in the encrypted peer certificate
376     buf = tls_->ReadFully(header.payload);
377     if (buf.empty()) {
378         return false;
379     }
380 
381     // Try to decrypt the certificate
382     outbuf.resize(pairing_auth_safe_decrypted_size(auth_.get(), buf.data(), buf.size()));
383     if (outbuf.empty()) {
384         LOG(ERROR) << "Unsupported payload while decrypting peer info.";
385         return false;
386     }
387 
388     if (!pairing_auth_decrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
389         LOG(ERROR) << "Failed to decrypt";
390         return false;
391     }
392     outbuf.resize(outsize);
393 
394     // The decrypted message should contain the PeerInfo.
395     if (outbuf.size() != sizeof(PeerInfo)) {
396         LOG(ERROR) << "Got size=" << outbuf.size() << "PeerInfo.size=" << sizeof(PeerInfo);
397         return false;
398     }
399 
400     p = outbuf.data();
401     ::memcpy(&their_info_, p, sizeof(PeerInfo));
402     p += sizeof(PeerInfo);
403 
404     return true;
405 }
406 
407 void PairingConnectionCtx::StartWorker() {
408     // Setup the secure transport
409     if (!SetupTlsConnection()) {
410         NotifyResult(nullptr);
411         return;
412     }
413 
414     for (;;) {
415         switch (state_) {
416             case State::ExchangingMsgs:
417                 if (!DoExchangeMsgs()) {
418                     NotifyResult(nullptr);
419                     return;
420                 }
421                 state_ = State::ExchangingPeerInfo;
422                 break;
423             case State::ExchangingPeerInfo:
424                 if (!DoExchangePeerInfo()) {
425                     NotifyResult(nullptr);
426                     return;
427                 }
428                 NotifyResult(&their_info_);
429                 return;
430             case State::Ready:
431             case State::Stopped:
432                 LOG(FATAL) << __func__ << ": Got invalid state";
433                 return;
434         }
435     }
436 }
437 
438 // static
439 PairingAuthPtr PairingConnectionCtx::CreatePairingAuthPtr(Role role, const Data& pswd) {
440     switch (role) {
441         case Role::Client:
442             return PairingAuthPtr(pairing_auth_client_new(pswd.data(), pswd.size()));
443             break;
444         case Role::Server:
445             return PairingAuthPtr(pairing_auth_server_new(pswd.data(), pswd.size()));
446             break;
447     }
448 }
449 
450 static PairingConnectionCtx* CreateConnection(PairingConnectionCtx::Role role, const uint8_t* pswd,
451                                               size_t pswd_len, const PeerInfo* peer_info,
452                                               const uint8_t* x509_cert_pem, size_t x509_size,
453                                               const uint8_t* priv_key_pem, size_t priv_size) {
454     CHECK(pswd);
455     CHECK_GT(pswd_len, 0U);
456     CHECK(x509_cert_pem);
457     CHECK_GT(x509_size, 0U);
458     CHECK(priv_key_pem);
459     CHECK_GT(priv_size, 0U);
460     CHECK(peer_info);
461     std::vector<uint8_t> vec_pswd(pswd, pswd + pswd_len);
462     std::vector<uint8_t> vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size);
463     std::vector<uint8_t> vec_priv_key(priv_key_pem, priv_key_pem + priv_size);
464     return new PairingConnectionCtx(role, vec_pswd, *peer_info, vec_x509_cert, vec_priv_key);
465 }
466 
467 PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len,
468                                                     const PeerInfo* peer_info,
469                                                     const uint8_t* x509_cert_pem, size_t x509_size,
470                                                     const uint8_t* priv_key_pem, size_t priv_size) {
471     return CreateConnection(PairingConnectionCtx::Role::Client, pswd, pswd_len, peer_info,
472                             x509_cert_pem, x509_size, priv_key_pem, priv_size);
473 }
474 
475 PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len,
476                                                     const PeerInfo* peer_info,
477                                                     const uint8_t* x509_cert_pem, size_t x509_size,
478                                                     const uint8_t* priv_key_pem, size_t priv_size) {
479     return CreateConnection(PairingConnectionCtx::Role::Server, pswd, pswd_len, peer_info,
480                             x509_cert_pem, x509_size, priv_key_pem, priv_size);
481 }
482 
483 void pairing_connection_destroy(PairingConnectionCtx* ctx) {
484     CHECK(ctx);
485     delete ctx;
486 }
487 
488 bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb,
489                               void* opaque) {
490     return ctx->Start(fd, cb, opaque);
491 }
492