1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "adbwifi/pairing/pairing_server.h"
18 
19 #include <sys/epoll.h>
20 #include <sys/eventfd.h>
21 
22 #include <atomic>
23 #include <deque>
24 #include <iomanip>
25 #include <mutex>
26 #include <sstream>
27 #include <thread>
28 #include <tuple>
29 #include <unordered_map>
30 #include <variant>
31 #include <vector>
32 
33 #include <adbwifi/pairing/pairing_connection.h>
34 #include <android-base/logging.h>
35 #include <android-base/parsenetaddress.h>
36 #include <android-base/thread_annotations.h>
37 #include <android-base/unique_fd.h>
38 #include <cutils/sockets.h>
39 
40 namespace adbwifi {
41 namespace pairing {
42 
43 using android::base::ScopedLockAssertion;
44 using android::base::unique_fd;
45 
46 namespace {
47 
// The implementation has two background threads running: one to handle and
// accept any new pairing connection requests (socket accept), and the other to
// handle connection events (connection started, connection finished).
51 class PairingServerImpl : public PairingServer {
52   public:
53     virtual ~PairingServerImpl();
54 
55     // All parameters must be non-empty.
56     explicit PairingServerImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert,
57                                const Data& priv_key, int port);
58 
59     // Starts the pairing server. This call is non-blocking. Upon completion,
60     // if the pairing was successful, then |cb| will be called with the PublicKeyHeader
61     // containing the info of the trusted peer. Otherwise, |cb| will be
62     // called with an empty value. Start can only be called once in the lifetime
63     // of this object.
64     //
65     // Returns true if PairingServer was successfully started. Otherwise,
66     // returns false.
67     virtual bool start(PairingConnection::ResultCallback cb, void* opaque) override;
68 
69   private:
70     // Setup the server socket to accept incoming connections
71     bool setupServer();
72     // Force stop the server thread.
73     void stopServer();
74 
75     // handles a new pairing client connection
76     bool handleNewClientConnection(int fd) EXCLUDES(conn_mutex_);
77 
78     // ======== connection events thread =============
79     std::mutex conn_mutex_;
80     std::condition_variable conn_cv_;
81 
82     using FdVal = int;
83     using ConnectionPtr = std::unique_ptr<PairingConnection>;
84     using NewConnectionEvent = std::tuple<unique_fd, ConnectionPtr>;
85     // <fd, PeerInfo.name, PeerInfo.guid, certificate>
86     using ConnectionFinishedEvent = std::tuple<FdVal, std::optional<std::string>,
87                                                std::optional<std::string>, std::optional<Data>>;
88     using ConnectionEvent = std::variant<NewConnectionEvent, ConnectionFinishedEvent>;
89     // Queue for connections to write into. We have a separate queue to read
90     // from, in order to minimize the time the server thread is blocked.
91     std::deque<ConnectionEvent> conn_write_queue_ GUARDED_BY(conn_mutex_);
92     std::deque<ConnectionEvent> conn_read_queue_;
93     // Map of fds to their PairingConnections currently running.
94     std::unordered_map<FdVal, ConnectionPtr> connections_;
95 
96     // Two threads launched when starting the pairing server:
97     // 1) A server thread that waits for incoming client connections, and
98     // 2) A connection events thread that synchonizes events from all of the
99     //    clients, since each PairingConnection is running in it's own thread.
100     void startConnectionEventsThread();
101     void startServerThread();
102 
103     std::thread conn_events_thread_;
104     void connectionEventsWorker();
105     std::thread server_thread_;
106     void serverWorker();
107     bool is_terminate_ GUARDED_BY(conn_mutex_) = false;
108 
109     enum class State {
110         Ready,
111         Running,
112         Stopped,
113     };
114     State state_ = State::Ready;
115     Data pswd_;
116     PeerInfo peer_info_;
117     Data cert_;
118     Data priv_key_;
119     int port_ = -1;
120 
121     PairingConnection::ResultCallback cb_;
122     void* opaque_ = nullptr;
123     bool got_valid_pairing_ = false;
124 
125     static const int kEpollConstSocket = 0;
126     // Used to break the server thread from epoll_wait
127     static const int kEpollConstEventFd = 1;
128     unique_fd epoll_fd_;
129     unique_fd server_fd_;
130     unique_fd event_fd_;
131 };  // PairingServerImpl
132 
PairingServerImpl(const Data & pswd,const PeerInfo & peer_info,const Data & cert,const Data & priv_key,int port)133 PairingServerImpl::PairingServerImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert,
134                                      const Data& priv_key, int port)
135     : pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key), port_(port) {
136     CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty() && port_ > 0);
137     CHECK('\0' == peer_info.name[kPeerNameLength - 1] &&
138           '\0' == peer_info.guid[kPeerGuidLength - 1] && strlen(peer_info.name) > 0 &&
139           strlen(peer_info.guid) > 0);
140 }
141 
~PairingServerImpl()142 PairingServerImpl::~PairingServerImpl() {
143     // Since these connections have references to us, let's make sure they
144     // destruct before us.
145     if (server_thread_.joinable()) {
146         stopServer();
147         server_thread_.join();
148     }
149 
150     {
151         std::lock_guard<std::mutex> lock(conn_mutex_);
152         is_terminate_ = true;
153     }
154     conn_cv_.notify_one();
155     if (conn_events_thread_.joinable()) {
156         conn_events_thread_.join();
157     }
158 
159     // Notify the cb_ if it hasn't already.
160     if (!got_valid_pairing_ && cb_ != nullptr) {
161         cb_(nullptr, nullptr, opaque_);
162     }
163 }
164 
start(PairingConnection::ResultCallback cb,void * opaque)165 bool PairingServerImpl::start(PairingConnection::ResultCallback cb, void* opaque) {
166     cb_ = cb;
167     opaque_ = opaque;
168 
169     if (state_ != State::Ready) {
170         LOG(ERROR) << "PairingServer already running or stopped";
171         return false;
172     }
173 
174     if (!setupServer()) {
175         LOG(ERROR) << "Unable to start PairingServer";
176         state_ = State::Stopped;
177         return false;
178     }
179 
180     state_ = State::Running;
181     return true;
182 }
183 
stopServer()184 void PairingServerImpl::stopServer() {
185     if (event_fd_.get() == -1) {
186         return;
187     }
188     uint64_t value = 1;
189     ssize_t rc = write(event_fd_.get(), &value, sizeof(value));
190     if (rc == -1) {
191         // This can happen if the server didn't start.
192         PLOG(ERROR) << "write to eventfd failed";
193     } else if (rc != sizeof(value)) {
194         LOG(FATAL) << "write to event returned short (" << rc << ")";
195     }
196 }
197 
setupServer()198 bool PairingServerImpl::setupServer() {
199     epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
200     if (epoll_fd_ == -1) {
201         PLOG(ERROR) << "failed to create epoll fd";
202         return false;
203     }
204 
205     event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
206     if (event_fd_ == -1) {
207         PLOG(ERROR) << "failed to create eventfd";
208         return false;
209     }
210 
211     server_fd_.reset(socket_inaddr_any_server(port_, SOCK_STREAM));
212     if (server_fd_.get() == -1) {
213         PLOG(ERROR) << "Failed to start pairing connection server";
214         return false;
215     }
216 
217     startConnectionEventsThread();
218     startServerThread();
219     return true;
220 }
221 
startServerThread()222 void PairingServerImpl::startServerThread() {
223     server_thread_ = std::thread([this]() { serverWorker(); });
224 }
225 
startConnectionEventsThread()226 void PairingServerImpl::startConnectionEventsThread() {
227     conn_events_thread_ = std::thread([this]() { connectionEventsWorker(); });
228 }
229 
serverWorker()230 void PairingServerImpl::serverWorker() {
231     {
232         struct epoll_event event;
233         event.events = EPOLLIN;
234         event.data.u64 = kEpollConstSocket;
235         CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, server_fd_.get(), &event));
236     }
237 
238     {
239         struct epoll_event event;
240         event.events = EPOLLIN;
241         event.data.u64 = kEpollConstEventFd;
242         CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event));
243     }
244 
245     while (true) {
246         struct epoll_event events[2];
247         int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 2, -1));
248         if (rc == -1) {
249             PLOG(ERROR) << "epoll_wait failed";
250             return;
251         } else if (rc == 0) {
252             LOG(ERROR) << "epoll_wait returned 0";
253             return;
254         }
255 
256         for (int i = 0; i < rc; ++i) {
257             struct epoll_event& event = events[i];
258             switch (event.data.u64) {
259                 case kEpollConstSocket:
260                     handleNewClientConnection(server_fd_.get());
261                     break;
262                 case kEpollConstEventFd:
263                     uint64_t dummy;
264                     int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy)));
265                     if (rc != sizeof(dummy)) {
266                         PLOG(FATAL) << "failed to read from eventfd (rc=" << rc << ")";
267                     }
268                     return;
269             }
270         }
271     }
272 }
273 
connectionEventsWorker()274 void PairingServerImpl::connectionEventsWorker() {
275     for (;;) {
276         // Transfer the write queue to the read queue.
277         {
278             std::unique_lock<std::mutex> lock(conn_mutex_);
279             ScopedLockAssertion assume_locked(conn_mutex_);
280 
281             if (is_terminate_) {
282                 // We check |is_terminate_| twice because condition_variable's
283                 // notify() only wakes up a thread if it is in the wait state
284                 // prior to notify(). Furthermore, we aren't holding the mutex
285                 // when processing the events in |conn_read_queue_|.
286                 return;
287             }
288             if (conn_write_queue_.empty()) {
289                 // We need to wait for new events, or the termination signal.
290                 conn_cv_.wait(lock, [this]() REQUIRES(conn_mutex_) {
291                     return (is_terminate_ || !conn_write_queue_.empty());
292                 });
293             }
294             if (is_terminate_) {
295                 // We're done.
296                 return;
297             }
298             // Move all events into the read queue.
299             conn_read_queue_ = std::move(conn_write_queue_);
300             conn_write_queue_.clear();
301         }
302 
303         // Process all events in the read queue.
304         while (conn_read_queue_.size() > 0) {
305             auto& event = conn_read_queue_.front();
306             if (auto* p = std::get_if<NewConnectionEvent>(&event)) {
307                 // Ignore if we are already at the max number of connections
308                 if (connections_.size() >= internal::kMaxConnections) {
309                     conn_read_queue_.pop_front();
310                     continue;
311                 }
312                 auto [ufd, connection] = std::move(*p);
313                 int fd = ufd.release();
314                 bool started = connection->start(
315                         fd,
316                         [fd](const PeerInfo* peer_info, const Data* cert, void* opaque) {
317                             auto* p = reinterpret_cast<PairingServerImpl*>(opaque);
318 
319                             ConnectionFinishedEvent event;
320                             if (peer_info != nullptr && cert != nullptr) {
321                                 event = std::make_tuple(fd, std::string(peer_info->name),
322                                                         std::string(peer_info->guid), Data(*cert));
323                             } else {
324                                 event = std::make_tuple(fd, std::nullopt, std::nullopt,
325                                                         std::nullopt);
326                             }
327                             {
328                                 std::lock_guard<std::mutex> lock(p->conn_mutex_);
329                                 p->conn_write_queue_.push_back(std::move(event));
330                             }
331                             p->conn_cv_.notify_one();
332                         },
333                         this);
334                 if (!started) {
335                     LOG(ERROR) << "PairingServer unable to start a PairingConnection fd=" << fd;
336                     ufd.reset(fd);
337                 } else {
338                     connections_[fd] = std::move(connection);
339                 }
340             } else if (auto* p = std::get_if<ConnectionFinishedEvent>(&event)) {
341                 auto [fd, name, guid, cert] = std::move(*p);
342                 if (name.has_value() && guid.has_value() && cert.has_value() && !name->empty() &&
343                     !guid->empty() && !cert->empty()) {
344                     // Valid pairing. Let's shutdown the server and close any
345                     // pairing connections in progress.
346                     stopServer();
347                     connections_.clear();
348 
349                     CHECK_LE(name->size(), kPeerNameLength);
350                     CHECK_LE(guid->size(), kPeerGuidLength);
351                     PeerInfo info = {};
352                     strncpy(info.name, name->data(), name->size());
353                     strncpy(info.guid, guid->data(), guid->size());
354 
355                     cb_(&info, &*cert, opaque_);
356 
357                     got_valid_pairing_ = true;
358                     return;
359                 }
360                 // Invalid pairing. Close the invalid connection.
361                 if (connections_.find(fd) != connections_.end()) {
362                     connections_.erase(fd);
363                 }
364             }
365             conn_read_queue_.pop_front();
366         }
367     }
368 }
369 
handleNewClientConnection(int fd)370 bool PairingServerImpl::handleNewClientConnection(int fd) {
371     unique_fd ufd(TEMP_FAILURE_RETRY(accept4(fd, nullptr, nullptr, SOCK_CLOEXEC)));
372     if (ufd == -1) {
373         PLOG(WARNING) << "adb_socket_accept failed fd=" << fd;
374         return false;
375     }
376     auto connection = PairingConnection::create(PairingConnection::Role::Server, pswd_, peer_info_,
377                                                 cert_, priv_key_);
378     if (connection == nullptr) {
379         LOG(ERROR) << "PairingServer unable to create a PairingConnection fd=" << fd;
380         return false;
381     }
382     // send the new connection to the connection thread for further processing
383     NewConnectionEvent event = std::make_tuple(std::move(ufd), std::move(connection));
384     {
385         std::lock_guard<std::mutex> lock(conn_mutex_);
386         conn_write_queue_.push_back(std::move(event));
387     }
388     conn_cv_.notify_one();
389 
390     return true;
391 }
392 
393 }  // namespace
394 
395 // static
create(const Data & pswd,const PeerInfo & peer_info,const Data & cert,const Data & priv_key,int port)396 std::unique_ptr<PairingServer> PairingServer::create(const Data& pswd, const PeerInfo& peer_info,
397                                                      const Data& cert, const Data& priv_key,
398                                                      int port) {
399     if (pswd.empty() || cert.empty() || priv_key.empty() || port <= 0) {
400         return nullptr;
401     }
402     // Make sure peer_info has a non-empty, null-terminated string for guid and
403     // name.
404     if ('\0' != peer_info.name[kPeerNameLength - 1] ||
405         '\0' != peer_info.guid[kPeerGuidLength - 1] || strlen(peer_info.name) == 0 ||
406         strlen(peer_info.guid) == 0) {
407         LOG(ERROR) << "The GUID/short name fields are empty or not null-terminated";
408         return nullptr;
409     }
410 
411     if (port != kDefaultPairingPort) {
412         LOG(WARNING) << "Starting server with non-default pairing port=" << port;
413     }
414 
415     return std::unique_ptr<PairingServer>(
416             new PairingServerImpl(pswd, peer_info, cert, priv_key, port));
417 }
418 
419 }  // namespace pairing
420 }  // namespace adbwifi
421