1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "dns_tls_frontend.h" 18 19 #include <arpa/inet.h> 20 #include <netdb.h> 21 #include <openssl/err.h> 22 #include <openssl/evp.h> 23 #include <openssl/ssl.h> 24 #include <openssl/x509.h> 25 #include <sys/eventfd.h> 26 #include <sys/poll.h> 27 #include <sys/socket.h> 28 #include <sys/types.h> 29 #include <unistd.h> 30 31 #define LOG_TAG "DnsTlsFrontend" 32 #include <android-base/logging.h> 33 #include <netdutils/InternetAddresses.h> 34 #include <netdutils/SocketOption.h> 35 #include "dns_responder.h" 36 #include "dns_tls_certificate.h" 37 38 using android::netdutils::enableSockopt; 39 using android::netdutils::ScopedAddrinfo; 40 41 namespace { 42 static bssl::UniquePtr<X509> stringToX509Certs(const char* certs) { 43 bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(certs, strlen(certs))); 44 return bssl::UniquePtr<X509>(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); 45 } 46 47 // Convert a string buffer containing an RSA Private Key into an OpenSSL RSA struct. 48 static bssl::UniquePtr<RSA> stringToRSAPrivateKey(const char* key) { 49 bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(key, strlen(key))); 50 return bssl::UniquePtr<RSA>(PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr)); 51 } 52 53 std::string addr2str(const sockaddr* sa, socklen_t sa_len) { 54 char host_str[NI_MAXHOST] = {0}; 55 int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0, NI_NUMERICHOST); 56 if (rv == 0) return std::string(host_str); 57 return std::string(); 58 } 59 60 } // namespace 61 62 namespace test { 63 64 bool DnsTlsFrontend::startServer() { 65 OpenSSL_add_ssl_algorithms(); 66 67 // reset queries_ to 0 every time startServer called 68 // which would help us easy to check queries_ via calling waitForQueries 69 queries_ = 0; 70 71 ctx_.reset(SSL_CTX_new(TLS_server_method())); 72 if (!ctx_) { 73 LOG(ERROR) << "SSL context creation failed"; 74 return false; 75 } 76 77 SSL_CTX_set_ecdh_auto(ctx_.get(), 1); 78 79 bssl::UniquePtr<X509> ca_certs(stringToX509Certs(kCertificate)); 80 if (!ca_certs) { 81 LOG(ERROR) << "StringToX509Certs failed"; 82 return false; 83 } 84 85 if (SSL_CTX_use_certificate(ctx_.get(), ca_certs.get()) <= 0) { 86 LOG(ERROR) << "SSL_CTX_use_certificate failed"; 87 return false; 88 } 89 90 bssl::UniquePtr<RSA> private_key(stringToRSAPrivateKey(kPrivatekey)); 91 if (SSL_CTX_use_RSAPrivateKey(ctx_.get(), private_key.get()) <= 0) { 92 LOG(ERROR) << "Error loading client RSA Private Key data."; 93 return false; 94 } 95 96 // Set up TCP server socket for clients. 97 addrinfo frontend_ai_hints{ 98 .ai_flags = AI_PASSIVE, 99 .ai_family = AF_UNSPEC, 100 .ai_socktype = SOCK_STREAM, 101 }; 102 addrinfo* frontend_ai_res = nullptr; 103 int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &frontend_ai_hints, 104 &frontend_ai_res); 105 ScopedAddrinfo frontend_ai_res_cleanup(frontend_ai_res); 106 if (rv) { 107 LOG(ERROR) << "frontend getaddrinfo(" << listen_address_.c_str() << ", " 108 << listen_service_.c_str() << ") failed: " << gai_strerror(rv); 109 return false; 110 } 111 112 for (const addrinfo* ai = frontend_ai_res; ai; ai = ai->ai_next) { 113 android::base::unique_fd s(socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol)); 114 if (s.get() < 0) { 115 PLOG(INFO) << "ignore creating socket failed " << s.get(); 116 continue; 117 } 118 enableSockopt(s.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError(); 119 std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen); 120 if (bind(s.get(), ai->ai_addr, ai->ai_addrlen)) { 121 PLOG(INFO) << "failed to bind TCP " << host_str.c_str() << ":" 122 << listen_service_.c_str(); 123 continue; 124 } 125 LOG(INFO) << "bound to TCP " << host_str.c_str() << ":" << listen_service_.c_str(); 126 socket_ = std::move(s); 127 break; 128 } 129 130 if (listen(socket_.get(), 1) < 0) { 131 PLOG(INFO) << "failed to listen socket " << socket_.get(); 132 return false; 133 } 134 135 // Set up UDP client socket to backend. 136 addrinfo backend_ai_hints{.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM}; 137 addrinfo* backend_ai_res = nullptr; 138 rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(), &backend_ai_hints, 139 &backend_ai_res); 140 ScopedAddrinfo backend_ai_res_cleanup(backend_ai_res); 141 if (rv) { 142 LOG(ERROR) << "backend getaddrinfo(" << listen_address_.c_str() << ", " 143 << listen_service_.c_str() << ") failed: " << gai_strerror(rv); 144 return false; 145 } 146 backend_socket_.reset(socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype, 147 backend_ai_res->ai_protocol)); 148 if (backend_socket_.get() < 0) { 149 PLOG(INFO) << "backend socket " << backend_socket_.get() << " creation failed"; 150 return false; 151 } 152 153 // connect() always fails in the test DnsTlsSocketTest.SlowDestructor because of 154 // no backend server. Don't check it. 155 static_cast<void>( 156 connect(backend_socket_.get(), backend_ai_res->ai_addr, backend_ai_res->ai_addrlen)); 157 158 // Set up eventfd socket. 159 event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); 160 if (event_fd_.get() == -1) { 161 PLOG(INFO) << "failed to create eventfd " << event_fd_.get(); 162 return false; 163 } 164 165 { 166 std::lock_guard lock(update_mutex_); 167 handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this); 168 } 169 LOG(INFO) << "server started successfully"; 170 return true; 171 } 172 173 void DnsTlsFrontend::requestHandler() { 174 LOG(DEBUG) << "Request handler started"; 175 enum { EVENT_FD = 0, LISTEN_FD = 1 }; 176 pollfd fds[2] = {{.fd = event_fd_.get(), .events = POLLIN}, 177 {.fd = socket_.get(), .events = POLLIN}}; 178 android::base::unique_fd clientFd; 179 180 while (true) { 181 int poll_code = poll(fds, std::size(fds), -1); 182 if (poll_code <= 0) { 183 PLOG(WARNING) << "Poll failed with error " << poll_code; 184 break; 185 } 186 187 if (fds[EVENT_FD].revents & (POLLIN | POLLERR)) { 188 handleEventFd(); 189 break; 190 } 191 if (fds[LISTEN_FD].revents & (POLLIN | POLLERR)) { 192 sockaddr_storage addr; 193 socklen_t len = sizeof(addr); 194 195 LOG(DEBUG) << "Trying to accept a client"; 196 android::base::unique_fd client( 197 accept4(socket_.get(), reinterpret_cast<sockaddr*>(&addr), &len, SOCK_CLOEXEC)); 198 if (client.get() < 0) { 199 // Stop 200 PLOG(INFO) << "failed to accept client socket " << client.get(); 201 break; 202 } 203 204 accept_connection_count_++; 205 if (hangOnHandshake_) { 206 LOG(DEBUG) << "TEST ONLY: unresponsive to SSL handshake"; 207 208 // The previous fd already stored in clientFd will be closed automatically. 209 clientFd = std::move(client); 210 continue; 211 } 212 213 bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get())); 214 SSL_set_fd(ssl.get(), client.get()); 215 216 LOG(DEBUG) << "Doing SSL handshake"; 217 if (SSL_accept(ssl.get()) <= 0) { 218 LOG(INFO) << "SSL negotiation failure"; 219 } else { 220 LOG(DEBUG) << "SSL handshake complete"; 221 // Increment queries_ as late as possible, because it represents 222 // a query that is fully processed, and the response returned to the 223 // client, including cleanup actions. 224 queries_ += handleRequests(ssl.get(), client.get()); 225 } 226 227 if (passiveClose_) { 228 LOG(DEBUG) << "hold the current connection until next connection request"; 229 clientFd = std::move(client); 230 } 231 } 232 } 233 LOG(DEBUG) << "Ending loop"; 234 } 235 236 int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) { 237 int queryCounts = 0; 238 std::vector<uint8_t> reply; 239 bool isDotProbe = false; 240 pollfd fds = {.fd = clientFd, .events = POLLIN}; 241 again: 242 do { 243 uint8_t queryHeader[2]; 244 if (SSL_read(ssl, &queryHeader, 2) != 2) { 245 LOG(INFO) << "Not enough header bytes"; 246 return queryCounts; 247 } 248 const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1]; 249 uint8_t query[qlen]; 250 size_t qbytes = 0; 251 while (qbytes < qlen) { 252 int ret = SSL_read(ssl, query + qbytes, qlen - qbytes); 253 if (ret <= 0) { 254 LOG(INFO) << "Error while reading query"; 255 return queryCounts; 256 } 257 qbytes += ret; 258 } 259 int sent = send(backend_socket_.get(), query, qlen, 0); 260 if (sent != qlen) { 261 LOG(INFO) << "Failed to send query"; 262 return queryCounts; 263 } 264 265 if (!isDotProbe) { 266 DNSHeader dnsHdr; 267 dnsHdr.read((char*)query, (char*)query + qlen); 268 for (const auto& question : dnsHdr.questions) { 269 if (question.qname.name.find("dnsotls-ds.metric.gstatic.com") != 270 std::string::npos) { 271 isDotProbe = true; 272 break; 273 } 274 } 275 } 276 277 const int max_size = 4096; 278 uint8_t recv_buffer[max_size]; 279 int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0); 280 if (rlen <= 0) { 281 LOG(INFO) << "Failed to receive response"; 282 return queryCounts; 283 } 284 uint8_t responseHeader[2]; 285 responseHeader[0] = rlen >> 8; 286 responseHeader[1] = rlen; 287 reply.insert(reply.end(), responseHeader, responseHeader + 2); 288 reply.insert(reply.end(), recv_buffer, recv_buffer + rlen); 289 290 ++queryCounts; 291 if (queryCounts >= delayQueries_) { 292 break; 293 } 294 } while (poll(&fds, 1, delayQueriesTimeout_) > 0); 295 296 if (queryCounts < delayQueries_) { 297 LOG(WARNING) << "Expect " << delayQueries_ << " queries, but actually received " 298 << queryCounts << " queries"; 299 } 300 301 const int replyLen = reply.size(); 302 LOG(DEBUG) << "Sending " << queryCounts << "queries at once, byte = " << replyLen; 303 if (SSL_write(ssl, reply.data(), replyLen) != replyLen) { 304 LOG(WARNING) << "Failed to write response body"; 305 } 306 307 // Poll again because the same DoT probe might be sent again. 308 if (isDotProbe && queryCounts == 1) { 309 int n = poll(&fds, 1, 50); 310 if (n > 0 && fds.revents & POLLIN) { 311 goto again; 312 } 313 } 314 315 LOG(DEBUG) << __func__ << " return: " << queryCounts; 316 return queryCounts; 317 } 318 319 bool DnsTlsFrontend::stopServer() { 320 std::lock_guard lock(update_mutex_); 321 if (!running()) { 322 LOG(INFO) << "server not running"; 323 return false; 324 } 325 326 LOG(INFO) << "stopping frontend"; 327 if (!sendToEventFd()) { 328 return false; 329 } 330 handler_thread_.join(); 331 socket_.reset(); 332 backend_socket_.reset(); 333 event_fd_.reset(); 334 ctx_.reset(); 335 LOG(INFO) << "frontend stopped successfully"; 336 return true; 337 } 338 339 // TODO: use a condition variable instead of polling 340 // TODO: also clear queries_ to eliminate potential race conditions 341 bool DnsTlsFrontend::waitForQueries(int expected_count) const { 342 constexpr int intervalMs = 20; 343 constexpr int timeoutMs = 5000; 344 int limit = timeoutMs / intervalMs; 345 for (int count = 0; count <= limit; ++count) { 346 bool done = queries_ >= expected_count; 347 // Always sleep at least one more interval after we are done, to wait for 348 // any immediate post-query actions that the client may take (such as 349 // marking this server as reachable during validation). 350 usleep(intervalMs * 1000); 351 if (done) { 352 // For ensuring that calls have sufficient headroom for slow machines 353 LOG(DEBUG) << "Query arrived in " << count << "/" << limit << " of allotted time"; 354 return true; 355 } 356 } 357 return false; 358 } 359 360 bool DnsTlsFrontend::sendToEventFd() { 361 const uint64_t data = 1; 362 if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) { 363 PLOG(INFO) << "failed to write eventfd, rt=" << rt; 364 return false; 365 } 366 return true; 367 } 368 369 void DnsTlsFrontend::handleEventFd() { 370 int64_t data; 371 if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) { 372 PLOG(INFO) << "ignore reading eventfd failed, rt=" << rt; 373 } 374 } 375 376 } // namespace test 377