1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "resolv" 18 19 #include "DnsTlsTransport.h" 20 21 #include <span> 22 23 #include <android-base/format.h> 24 #include <android-base/logging.h> 25 #include <android-base/result.h> 26 #include <android-base/stringprintf.h> 27 #include <arpa/inet.h> 28 #include <arpa/nameser.h> 29 #include <netdutils/Stopwatch.h> 30 #include <netdutils/ThreadUtil.h> 31 #include <private/android_filesystem_config.h> // AID_DNS 32 #include <sys/poll.h> 33 34 #include "DnsTlsSocketFactory.h" 35 #include "Experiments.h" 36 #include "IDnsTlsSocketFactory.h" 37 #include "resolv_private.h" 38 #include "util.h" 39 40 using android::base::StringPrintf; 41 using android::netdutils::setThreadName; 42 43 namespace android { 44 namespace net { 45 46 namespace { 47 48 // Make a DNS query for the hostname "<random>-dnsotls-ds.metric.gstatic.com". 49 std::vector<uint8_t> makeDnsQuery() { 50 static const char kDnsSafeChars[] = 51 "abcdefhijklmnopqrstuvwxyz" 52 "ABCDEFHIJKLMNOPQRSTUVWXYZ" 53 "0123456789"; 54 const auto c = [](uint8_t rnd) -> uint8_t { 55 return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))]; 56 }; 57 uint8_t rnd[8]; 58 arc4random_buf(rnd, std::size(rnd)); 59 60 return std::vector<uint8_t>{ 61 rnd[6], rnd[7], // [0-1] query ID 62 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD). 63 0, 1, // [4-5] QDCOUNT (number of queries) 64 0, 0, // [6-7] ANCOUNT (number of answers) 65 0, 0, // [8-9] NSCOUNT (number of name server records) 66 0, 0, // [10-11] ARCOUNT (number of additional records) 67 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n', 68 's', 'o', 't', 'l', 's', '-', 'd', 's', 6, 'm', 69 'e', 't', 'r', 'i', 'c', 7, 'g', 's', 't', 'a', 70 't', 'i', 'c', 3, 'c', 'o', 'm', 71 0, // null terminator of FQDN (root TLD) 72 0, ns_t_aaaa, // QTYPE 73 0, ns_c_in // QCLASS 74 }; 75 } 76 77 base::Result<void> checkDnsResponse(const std::span<const uint8_t> answer) { 78 if (answer.size() < NS_HFIXEDSZ) { 79 return Errorf("short response: {}", answer.size()); 80 } 81 82 const int qdcount = (answer[4] << 8) | answer[5]; 83 if (qdcount != 1) { 84 return Errorf("reply query count != 1: {}", qdcount); 85 } 86 87 const int ancount = (answer[6] << 8) | answer[7]; 88 LOG(DEBUG) << "answer count: " << ancount; 89 90 // TODO: Further validate the response contents (check for valid AAAA record, ...). 91 // Note that currently, integration tests rely on this function accepting a 92 // response with zero records. 93 94 return {}; 95 } 96 97 // Sends |query| to the given server, and returns the DNS response. 98 base::Result<void> sendUdpQuery(netdutils::IPAddress ip, uint32_t mark, 99 std::span<const uint8_t> query) { 100 const sockaddr_storage ss = netdutils::IPSockAddr(ip, 53); 101 const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss); 102 const int nsaplen = sockaddrSize(nsap); 103 const int sockType = SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC; 104 android::base::unique_fd fd{socket(nsap->sa_family, sockType, 0)}; 105 if (fd < 0) { 106 return ErrnoErrorf("socket failed"); 107 } 108 109 resolv_tag_socket(fd.get(), AID_DNS, NET_CONTEXT_INVALID_PID); 110 if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) < 0) { 111 return ErrnoErrorf("setsockopt failed"); 112 } 113 114 if (connect(fd.get(), nsap, (socklen_t)nsaplen) < 0) { 115 return ErrnoErrorf("connect failed"); 116 } 117 118 if (send(fd, query.data(), query.size(), 0) != query.size()) { 119 return ErrnoErrorf("send failed"); 120 } 121 122 const int timeoutMs = 3000; 123 while (true) { 124 pollfd fds = {.fd = fd, .events = POLLIN}; 125 126 const int n = TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs)); 127 if (n == 0) { 128 return Errorf("poll timed out"); 129 } 130 if (n < 0) { 131 return ErrnoErrorf("poll failed"); 132 } 133 if (fds.revents & (POLLIN | POLLERR)) { 134 std::vector<uint8_t> buf(MAXPACKET); 135 const int resplen = recv(fd, buf.data(), buf.size(), 0); 136 137 if (resplen < 0) { 138 return ErrnoErrorf("recvfrom failed"); 139 } 140 141 buf.resize(resplen); 142 if (auto result = checkDnsResponse(buf); !result.ok()) { 143 return Errorf("checkDnsResponse failed: {}", result.error().message()); 144 } 145 146 return {}; 147 } 148 } 149 } 150 151 } // namespace 152 153 std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) { 154 std::lock_guard guard(mLock); 155 156 auto record = mQueries.recordQuery(query); 157 if (!record) { 158 return std::async(std::launch::deferred, []{ 159 return (Result) { .code = Response::internal_error }; 160 }); 161 } 162 163 if (!mSocket) { 164 LOG(DEBUG) << "No socket for query. Opening socket and sending."; 165 doConnect(); 166 } else { 167 sendQuery(record->query); 168 } 169 170 return std::move(record->result); 171 } 172 173 int DnsTlsTransport::getConnectCounter() const { 174 std::lock_guard guard(mLock); 175 return mConnectCounter; 176 } 177 178 bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query& q) { 179 // Strip off the ID number and send the new ID instead. 180 const bool sent = mSocket->query(q.newId, netdutils::drop(netdutils::makeSlice(q.query), 2)); 181 if (sent) { 182 mQueries.markTried(q.newId); 183 } 184 return sent; 185 } 186 187 void DnsTlsTransport::doConnect() { 188 LOG(DEBUG) << "Constructing new socket"; 189 mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache); 190 191 bool success = true; 192 if (mSocket.get() == nullptr || !mSocket->startHandshake()) { 193 success = false; 194 } 195 mConnectCounter++; 196 197 if (success) { 198 auto queries = mQueries.getAll(); 199 LOG(DEBUG) << "Initialization succeeded. Reissuing " << queries.size() << " queries."; 200 for(auto& q : queries) { 201 if (!sendQuery(q)) { 202 break; 203 } 204 } 205 } else { 206 LOG(DEBUG) << "Initialization failed."; 207 mSocket.reset(); 208 LOG(DEBUG) << "Failing all pending queries."; 209 mQueries.clear(); 210 } 211 } 212 213 void DnsTlsTransport::onResponse(std::vector<uint8_t> response) { 214 mQueries.onResponse(std::move(response)); 215 } 216 217 void DnsTlsTransport::onClosed() { 218 std::lock_guard guard(mLock); 219 if (mClosing) { 220 return; 221 } 222 // Move remaining operations to a new thread. 223 // This is necessary because 224 // 1. onClosed is currently running on a thread that blocks mSocket's destructor 225 // 2. doReconnect will call that destructor 226 if (mReconnectThread) { 227 // Complete cleanup of a previous reconnect thread, if present. 228 mReconnectThread->join(); 229 // Joining a thread that is trying to acquire mLock, while holding mLock, 230 // looks like it risks a deadlock. However, a deadlock will not occur because 231 // once onClosed is called, it cannot be called again until after doReconnect 232 // acquires mLock. 233 } 234 mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this)); 235 } 236 237 void DnsTlsTransport::doReconnect() { 238 std::lock_guard guard(mLock); 239 setThreadName(StringPrintf("TlsReconn_%u", mMark & 0xffff).c_str()); 240 if (mClosing) { 241 return; 242 } 243 mQueries.cleanup(); 244 if (!mQueries.empty()) { 245 LOG(DEBUG) << "Fast reconnect to retry remaining queries"; 246 doConnect(); 247 } else { 248 LOG(DEBUG) << "No pending queries. Going idle."; 249 mSocket.reset(); 250 } 251 } 252 253 DnsTlsTransport::~DnsTlsTransport() { 254 LOG(DEBUG) << "Destructor"; 255 { 256 std::lock_guard guard(mLock); 257 LOG(DEBUG) << "Locked destruction procedure"; 258 mQueries.clear(); 259 mClosing = true; 260 } 261 // It's possible that a reconnect thread was spawned and waiting for mLock. 262 // It's safe for that thread to run now because mClosing is true (and mQueries is empty), 263 // but we need to wait for it to finish before allowing destruction to proceed. 264 if (mReconnectThread) { 265 LOG(DEBUG) << "Waiting for reconnect thread to terminate"; 266 mReconnectThread->join(); 267 mReconnectThread.reset(); 268 } 269 // Ensure that the socket is destroyed, and can clean up its callback threads, 270 // before any of this object's fields become invalid. 271 mSocket.reset(); 272 LOG(DEBUG) << "Destructor completed"; 273 } 274 275 // static 276 // TODO: Use this function to preheat the session cache. 277 // That may require moving it to DnsTlsDispatcher. 278 bool DnsTlsTransport::validate(const DnsTlsServer& server, uint32_t mark) { 279 LOG(DEBUG) << "Beginning validation with mark " << std::hex << mark; 280 281 const std::vector<uint8_t> query = makeDnsQuery(); 282 DnsTlsSocketFactory factory; 283 DnsTlsTransport transport(server, mark, &factory); 284 285 // Send the initial query to warm up the connection. 286 auto r = transport.query(netdutils::makeSlice(query)).get(); 287 if (r.code != Response::success) { 288 LOG(WARNING) << "query failed"; 289 return false; 290 } 291 292 if (auto result = checkDnsResponse(r.response); !result.ok()) { 293 LOG(WARNING) << "checkDnsResponse failed: " << result.error().message(); 294 return false; 295 } 296 297 // If this validation is not for opportunistic mode, or the flags are not properly set, 298 // the validation is done. If not, the validation will compare DoT probe latency and 299 // UDP probe latency, and it will pass if: 300 // dot_probe_latency < latencyFactor * udp_probe_latency + latencyOffsetMs 301 // 302 // For instance, with latencyFactor = 3 and latencyOffsetMs = 10, if UDP probe latency is 5 ms, 303 // DoT probe latency must less than 25 ms. 304 const bool versionHigherThanAndroidR = getApiLevel() >= 30; 305 int latencyFactor = Experiments::getInstance()->getFlag("dot_validation_latency_factor", 306 (versionHigherThanAndroidR ? 3 : -1)); 307 int latencyOffsetMs = Experiments::getInstance()->getFlag( 308 "dot_validation_latency_offset_ms", (versionHigherThanAndroidR ? 100 : -1)); 309 const bool shouldCompareUdpLatency = 310 server.name.empty() && 311 (latencyFactor >= 0 && latencyOffsetMs >= 0 && latencyFactor + latencyOffsetMs != 0); 312 if (!shouldCompareUdpLatency) { 313 return true; 314 } 315 316 LOG(INFO) << fmt::format("Use flags: latencyFactor={}, latencyOffsetMs={}", latencyFactor, 317 latencyOffsetMs); 318 319 int64_t udpProbeTimeUs = 0; 320 bool udpProbeGotAnswer = false; 321 std::thread udpProbeThread([&] { 322 // Can issue another probe if the first one fails or is lost. 323 for (int i = 1; i < 3; i++) { 324 netdutils::Stopwatch stopwatch; 325 auto result = sendUdpQuery(server.addr().ip(), mark, query); 326 udpProbeTimeUs = stopwatch.timeTakenUs(); 327 udpProbeGotAnswer = result.ok(); 328 LOG(INFO) << fmt::format("UDP probe for {} {}, took {:.3f}ms", server.toIpString(), 329 (udpProbeGotAnswer ? "succeeded" : "failed"), 330 udpProbeTimeUs / 1000.0); 331 332 if (udpProbeGotAnswer) { 333 break; 334 } 335 LOG(WARNING) << "sendUdpQuery attempt " << i << " failed: " << result.error().message(); 336 } 337 }); 338 339 int64_t dotProbeTimeUs = 0; 340 bool dotProbeGotAnswer = false; 341 std::thread dotProbeThread([&] { 342 netdutils::Stopwatch stopwatch; 343 auto r = transport.query(netdutils::makeSlice(query)).get(); 344 dotProbeTimeUs = stopwatch.timeTakenUs(); 345 346 if (r.code != Response::success) { 347 LOG(WARNING) << "query failed"; 348 } else { 349 if (auto result = checkDnsResponse(r.response); !result.ok()) { 350 LOG(WARNING) << "checkDnsResponse failed: " << result.error().message(); 351 } else { 352 dotProbeGotAnswer = true; 353 } 354 } 355 356 LOG(INFO) << fmt::format("DoT probe for {} {}, took {:.3f}ms", server.toIpString(), 357 (dotProbeGotAnswer ? "succeeded" : "failed"), 358 dotProbeTimeUs / 1000.0); 359 }); 360 361 // TODO: If DoT probe thread finishes before UDP probe thread and dotProbeGotAnswer is false, 362 // actively cancel UDP probe thread. 363 dotProbeThread.join(); 364 udpProbeThread.join(); 365 366 if (!dotProbeGotAnswer) return false; 367 if (!udpProbeGotAnswer) return true; 368 return dotProbeTimeUs < (latencyFactor * udpProbeTimeUs + latencyOffsetMs * 1000); 369 } 370 371 } // end of namespace net 372 } // end of namespace android 373