1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "dns_tls_frontend.h"
18 
19 #include <arpa/inet.h>
20 #include <netdb.h>
21 #include <openssl/err.h>
22 #include <openssl/evp.h>
23 #include <openssl/ssl.h>
24 #include <openssl/x509.h>
25 #include <sys/eventfd.h>
26 #include <sys/poll.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30 
31 #define LOG_TAG "DnsTlsFrontend"
32 #include <android-base/logging.h>
33 #include <netdutils/InternetAddresses.h>
34 #include <netdutils/SocketOption.h>
35 #include "dns_responder.h"
36 #include "dns_tls_certificate.h"
37 
38 using android::netdutils::enableSockopt;
39 using android::netdutils::ScopedAddrinfo;
40 
41 namespace {
stringToX509Certs(const char * certs)42 static bssl::UniquePtr<X509> stringToX509Certs(const char* certs) {
43     bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(certs, strlen(certs)));
44     return bssl::UniquePtr<X509>(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
45 }
46 
47 // Convert a string buffer containing an RSA Private Key into an OpenSSL RSA struct.
stringToRSAPrivateKey(const char * key)48 static bssl::UniquePtr<RSA> stringToRSAPrivateKey(const char* key) {
49     bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(key, strlen(key)));
50     return bssl::UniquePtr<RSA>(PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr));
51 }
52 
// Formats the socket address |sa| (|sa_len| bytes long) as a numeric host string
// (NI_NUMERICHOST, e.g. "127.0.0.1" or "::1"). Returns an empty string if
// getnameinfo() reports an error.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numericHost[NI_MAXHOST] = {};
    const int gaiError = getnameinfo(sa, sa_len, numericHost, sizeof(numericHost),
                                     /*serv=*/nullptr, /*servlen=*/0, NI_NUMERICHOST);
    return gaiError == 0 ? std::string(numericHost) : std::string();
}
59 
60 }  // namespace
61 
62 namespace test {
63 
// Starts the frontend:
//   1. creates the TLS server context from kCertificate / kPrivatekey;
//   2. binds and listens on a TCP socket for listen_address_:listen_service_;
//   3. opens a UDP socket towards backend_address_:backend_service_ and connect()s it
//      (the connect() result is deliberately ignored, see below);
//   4. creates the non-blocking eventfd used to stop the handler;
//   5. spawns the requestHandler() thread.
// Returns true on success. On failure the reason is logged and false is returned;
// members initialized before the failing step are left as they are.
bool DnsTlsFrontend::startServer() {
    OpenSSL_add_ssl_algorithms();

    ctx_.reset(SSL_CTX_new(TLS_server_method()));
    if (!ctx_) {
        LOG(ERROR) << "SSL context creation failed";
        return false;
    }

    SSL_CTX_set_ecdh_auto(ctx_.get(), 1);

    bssl::UniquePtr<X509> ca_certs(stringToX509Certs(kCertificate));
    if (!ca_certs) {
        LOG(ERROR) << "StringToX509Certs failed";
        return false;
    }

    if (SSL_CTX_use_certificate(ctx_.get(), ca_certs.get()) <= 0) {
        LOG(ERROR) << "SSL_CTX_use_certificate failed";
        return false;
    }

    // Check the parsed key explicitly so a PEM parse failure is reported as such rather
    // than as a generic SSL_CTX_use_RSAPrivateKey() failure.
    bssl::UniquePtr<RSA> private_key(stringToRSAPrivateKey(kPrivatekey));
    if (!private_key) {
        LOG(ERROR) << "StringToRSAPrivateKey failed";
        return false;
    }
    if (SSL_CTX_use_RSAPrivateKey(ctx_.get(), private_key.get()) <= 0) {
        LOG(ERROR) << "Error loading server RSA Private Key data.";
        return false;
    }

    // Set up TCP server socket for clients.
    addrinfo frontend_ai_hints{
            .ai_flags = AI_PASSIVE,
            .ai_family = AF_UNSPEC,
            .ai_socktype = SOCK_STREAM,
    };
    addrinfo* frontend_ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &frontend_ai_hints,
                         &frontend_ai_res);
    ScopedAddrinfo frontend_ai_res_cleanup(frontend_ai_res);
    if (rv) {
        LOG(ERROR) << "frontend getaddrinfo(" << listen_address_.c_str() << ", "
                   << listen_service_.c_str() << ") failed: " << gai_strerror(rv);
        return false;
    }

    // Bind the first resolved address that works. Collect it in a local fd so that a
    // total failure is detected here instead of surfacing as listen() on an invalid
    // (or stale) socket_.
    android::base::unique_fd listen_socket;
    for (const addrinfo* ai = frontend_ai_res; ai; ai = ai->ai_next) {
        android::base::unique_fd s(socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol));
        if (s.get() < 0) {
            PLOG(INFO) << "ignore creating socket failed " << s.get();
            continue;
        }
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        if (bind(s.get(), ai->ai_addr, ai->ai_addrlen)) {
            PLOG(INFO) << "failed to bind TCP " << host_str.c_str() << ":"
                       << listen_service_.c_str();
            continue;
        }
        LOG(INFO) << "bound to TCP " << host_str.c_str() << ":" << listen_service_.c_str();
        listen_socket = std::move(s);
        break;
    }
    if (listen_socket.get() < 0) {
        LOG(ERROR) << "failed to bind any address for " << listen_address_.c_str() << ":"
                   << listen_service_.c_str();
        return false;
    }
    socket_ = std::move(listen_socket);

    if (listen(socket_.get(), 1) < 0) {
        PLOG(INFO) << "failed to listen socket " << socket_.get();
        return false;
    }

    // Set up UDP client socket to backend.
    addrinfo backend_ai_hints{.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM};
    addrinfo* backend_ai_res = nullptr;
    rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(), &backend_ai_hints,
                     &backend_ai_res);
    ScopedAddrinfo backend_ai_res_cleanup(backend_ai_res);
    if (rv) {
        // Report the backend endpoint that actually failed to resolve.
        LOG(ERROR) << "backend getaddrinfo(" << backend_address_.c_str() << ", "
                   << backend_service_.c_str() << ") failed: " << gai_strerror(rv);
        return false;
    }
    backend_socket_.reset(socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
                                 backend_ai_res->ai_protocol));
    if (backend_socket_.get() < 0) {
        PLOG(INFO) << "backend socket " << backend_socket_.get() << " creation failed";
        return false;
    }

    // connect() always fails in the test DnsTlsSocketTest.SlowDestructor because of
    // no backend server. Don't check it.
    static_cast<void>(
            connect(backend_socket_.get(), backend_ai_res->ai_addr, backend_ai_res->ai_addrlen));

    // Set up eventfd socket.
    event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    if (event_fd_.get() == -1) {
        PLOG(INFO) << "failed to create eventfd " << event_fd_.get();
        return false;
    }

    {
        std::lock_guard lock(update_mutex_);
        handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
    }
    LOG(INFO) << "server started successfully";
    return true;
}
168 
requestHandler()169 void DnsTlsFrontend::requestHandler() {
170     LOG(DEBUG) << "Request handler started";
171     enum { EVENT_FD = 0, LISTEN_FD = 1 };
172     pollfd fds[2] = {{.fd = event_fd_.get(), .events = POLLIN},
173                      {.fd = socket_.get(), .events = POLLIN}};
174     android::base::unique_fd clientFd;
175 
176     while (true) {
177         int poll_code = poll(fds, std::size(fds), -1);
178         if (poll_code <= 0) {
179             PLOG(WARNING) << "Poll failed with error " << poll_code;
180             break;
181         }
182 
183         if (fds[EVENT_FD].revents & (POLLIN | POLLERR)) {
184             handleEventFd();
185             break;
186         }
187         if (fds[LISTEN_FD].revents & (POLLIN | POLLERR)) {
188             sockaddr_storage addr;
189             socklen_t len = sizeof(addr);
190 
191             LOG(DEBUG) << "Trying to accept a client";
192             android::base::unique_fd client(
193                     accept4(socket_.get(), reinterpret_cast<sockaddr*>(&addr), &len, SOCK_CLOEXEC));
194             if (client.get() < 0) {
195                 // Stop
196                 PLOG(INFO) << "failed to accept client socket " << client.get();
197                 break;
198             }
199 
200             accept_connection_count_++;
201             if (hangOnHandshake_) {
202                 LOG(DEBUG) << "TEST ONLY: unresponsive to SSL handshake";
203 
204                 // The previous fd already stored in clientFd will be closed automatically.
205                 clientFd = std::move(client);
206                 continue;
207             }
208 
209             bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
210             SSL_set_fd(ssl.get(), client.get());
211 
212             LOG(DEBUG) << "Doing SSL handshake";
213             if (SSL_accept(ssl.get()) <= 0) {
214                 LOG(INFO) << "SSL negotiation failure";
215             } else {
216                 LOG(DEBUG) << "SSL handshake complete";
217                 handleRequests(ssl.get(), client.get());
218             }
219 
220             if (passiveClose_) {
221                 LOG(DEBUG) << "hold the current connection until next connection request";
222                 clientFd = std::move(client);
223             }
224         }
225     }
226     LOG(DEBUG) << "Ending loop";
227 }
228 
handleRequests(SSL * ssl,int clientFd)229 void DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
230     int queryCounts = 0;
231     std::vector<uint8_t> reply;
232     bool isDotProbe = false;
233     pollfd fds = {.fd = clientFd, .events = POLLIN};
234 again:
235     do {
236         uint8_t queryHeader[2];
237         if (SSL_read(ssl, &queryHeader, 2) != 2) {
238             LOG(INFO) << "Not enough header bytes";
239             return;
240         }
241         const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
242         uint8_t query[qlen];
243         size_t qbytes = 0;
244         while (qbytes < qlen) {
245             int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
246             if (ret <= 0) {
247                 LOG(INFO) << "Error while reading query";
248                 return;
249             }
250             qbytes += ret;
251         }
252         int sent = send(backend_socket_.get(), query, qlen, 0);
253         if (sent != qlen) {
254             LOG(INFO) << "Failed to send query";
255             return;
256         }
257 
258         if (!isDotProbe) {
259             DNSHeader dnsHdr;
260             dnsHdr.read((char*)query, (char*)query + qlen);
261             for (const auto& question : dnsHdr.questions) {
262                 if (question.qname.name.find("dnsotls-ds.metric.gstatic.com") !=
263                     std::string::npos) {
264                     isDotProbe = true;
265                     break;
266                 }
267             }
268         }
269 
270         const int max_size = 4096;
271         uint8_t recv_buffer[max_size];
272         int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
273         if (rlen <= 0) {
274             LOG(INFO) << "Failed to receive response";
275             return;
276         }
277         uint8_t responseHeader[2];
278         responseHeader[0] = rlen >> 8;
279         responseHeader[1] = rlen;
280         reply.insert(reply.end(), responseHeader, responseHeader + 2);
281         reply.insert(reply.end(), recv_buffer, recv_buffer + rlen);
282 
283         // Increment queries_ before the answers are sent. This makes sure that a test always
284         // reads the updated value returned from queries() after receiving the DNS answers.
285         ++queries_;
286         ++queryCounts;
287         if (queryCounts >= delayQueries_) {
288             break;
289         }
290     } while (poll(&fds, 1, delayQueriesTimeout_) > 0);
291 
292     if (queryCounts < delayQueries_) {
293         LOG(WARNING) << "Expect " << delayQueries_ << " queries, but actually received "
294                      << queryCounts << " queries";
295     }
296 
297     const int replyLen = reply.size();
298     LOG(DEBUG) << "Sending " << queryCounts << " queries at once, byte = " << replyLen;
299     if (SSL_write(ssl, reply.data(), replyLen) != replyLen) {
300         LOG(WARNING) << "Failed to write response body";
301     }
302 
303     // Poll again because the same DoT probe might be sent again.
304     if (isDotProbe && queryCounts == 1) {
305         const int timeoutMs = 500;
306         int n = poll(&fds, 1, timeoutMs);
307         if (n > 0 && fds.revents & POLLIN) {
308             goto again;
309         }
310         LOG(WARNING) << "Did not receive the second DoT probe within " << timeoutMs << "ms";
311     }
312 
313     LOG(DEBUG) << __func__ << " return: " << queryCounts;
314     return;
315 }
316 
stopServer()317 bool DnsTlsFrontend::stopServer() {
318     std::lock_guard lock(update_mutex_);
319     if (!running()) {
320         LOG(INFO) << "server not running";
321         return false;
322     }
323 
324     LOG(INFO) << "stopping frontend";
325     if (!sendToEventFd()) {
326         return false;
327     }
328     handler_thread_.join();
329     socket_.reset();
330     backend_socket_.reset();
331     event_fd_.reset();
332     ctx_.reset();
333     LOG(INFO) << "frontend stopped successfully";
334     return true;
335 }
336 
337 // TODO: use a condition variable instead of polling
338 // TODO: also clear queries_ to eliminate potential race conditions
waitForQueries(int expected_count) const339 bool DnsTlsFrontend::waitForQueries(int expected_count) const {
340     constexpr int intervalMs = 20;
341     constexpr int timeoutMs = 5000;
342     int limit = timeoutMs / intervalMs;
343     for (int count = 0; count <= limit; ++count) {
344         bool done = queries_ >= expected_count;
345         // Always sleep at least one more interval after we are done, to wait for
346         // any immediate post-query actions that the client may take (such as
347         // marking this server as reachable during validation).
348         usleep(intervalMs * 1000);
349         if (done) {
350             // For ensuring that calls have sufficient headroom for slow machines
351             LOG(DEBUG) << "Query arrived in " << count << "/" << limit << " of allotted time";
352             return true;
353         }
354     }
355     return false;
356 }
357 
sendToEventFd()358 bool DnsTlsFrontend::sendToEventFd() {
359     const uint64_t data = 1;
360     if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
361         PLOG(INFO) << "failed to write eventfd, rt=" << rt;
362         return false;
363     }
364     return true;
365 }
366 
handleEventFd()367 void DnsTlsFrontend::handleEventFd() {
368     int64_t data;
369     if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
370         PLOG(INFO) << "ignore reading eventfd failed, rt=" << rt;
371     }
372 }
373 
374 }  // namespace test
375