1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "resolv"
18
19 #include "DnsTlsTransport.h"
20
21 #include <span>
22
23 #include <android-base/format.h>
24 #include <android-base/logging.h>
25 #include <android-base/result.h>
26 #include <arpa/inet.h>
27 #include <arpa/nameser.h>
28 #include <netdutils/Stopwatch.h>
29 #include <netdutils/ThreadUtil.h>
30 #include <private/android_filesystem_config.h> // AID_DNS
31 #include <sys/poll.h>
32
33 #include "DnsTlsSocketFactory.h"
34 #include "Experiments.h"
35 #include "IDnsTlsSocketFactory.h"
36 #include "resolv_private.h"
37 #include "util.h"
38
39 using android::netdutils::setThreadName;
40
41 namespace android {
42 namespace net {
43
44 namespace {
45
46 // Make a DNS query for the hostname "<random>-dnsotls-ds.metric.gstatic.com".
makeDnsQuery()47 std::vector<uint8_t> makeDnsQuery() {
48 static const char kDnsSafeChars[] =
49 "abcdefhijklmnopqrstuvwxyz"
50 "ABCDEFHIJKLMNOPQRSTUVWXYZ"
51 "0123456789";
52 const auto c = [](uint8_t rnd) -> uint8_t {
53 return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
54 };
55 uint8_t rnd[8];
56 arc4random_buf(rnd, std::size(rnd));
57
58 return std::vector<uint8_t>{
59 rnd[6], rnd[7], // [0-1] query ID
60 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
61 0, 1, // [4-5] QDCOUNT (number of queries)
62 0, 0, // [6-7] ANCOUNT (number of answers)
63 0, 0, // [8-9] NSCOUNT (number of name server records)
64 0, 0, // [10-11] ARCOUNT (number of additional records)
65 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n',
66 's', 'o', 't', 'l', 's', '-', 'd', 's', 6, 'm',
67 'e', 't', 'r', 'i', 'c', 7, 'g', 's', 't', 'a',
68 't', 'i', 'c', 3, 'c', 'o', 'm',
69 0, // null terminator of FQDN (root TLD)
70 0, ns_t_aaaa, // QTYPE
71 0, ns_c_in // QCLASS
72 };
73 }
74
checkDnsResponse(const std::span<const uint8_t> answer)75 base::Result<void> checkDnsResponse(const std::span<const uint8_t> answer) {
76 if (answer.size() < NS_HFIXEDSZ) {
77 return Errorf("short response: {}", answer.size());
78 }
79
80 const int qdcount = (answer[4] << 8) | answer[5];
81 if (qdcount != 1) {
82 return Errorf("reply query count != 1: {}", qdcount);
83 }
84
85 const int ancount = (answer[6] << 8) | answer[7];
86 LOG(DEBUG) << "answer count: " << ancount;
87
88 // TODO: Further validate the response contents (check for valid AAAA record, ...).
89 // Note that currently, integration tests rely on this function accepting a
90 // response with zero records.
91
92 return {};
93 }
94
95 // Sends |query| to the given server, and returns the DNS response.
sendUdpQuery(netdutils::IPAddress ip,uint32_t mark,std::span<const uint8_t> query)96 base::Result<void> sendUdpQuery(netdutils::IPAddress ip, uint32_t mark,
97 std::span<const uint8_t> query) {
98 const sockaddr_storage ss = netdutils::IPSockAddr(ip, 53);
99 const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss);
100 const int nsaplen = sockaddrSize(nsap);
101 const int sockType = SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC;
102 android::base::unique_fd fd{socket(nsap->sa_family, sockType, 0)};
103 if (fd < 0) {
104 return ErrnoErrorf("socket failed");
105 }
106
107 resolv_tag_socket(fd.get(), AID_DNS, NET_CONTEXT_INVALID_PID);
108 if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) < 0) {
109 return ErrnoErrorf("setsockopt failed");
110 }
111
112 if (connect(fd.get(), nsap, (socklen_t)nsaplen) < 0) {
113 return ErrnoErrorf("connect failed");
114 }
115
116 if (send(fd, query.data(), query.size(), 0) != static_cast<ptrdiff_t>(query.size())) {
117 return ErrnoErrorf("send failed");
118 }
119
120 const int timeoutMs = 3000;
121 while (true) {
122 pollfd fds = {.fd = fd, .events = POLLIN};
123
124 const int n = TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
125 if (n == 0) {
126 return Errorf("poll timed out");
127 }
128 if (n < 0) {
129 return ErrnoErrorf("poll failed");
130 }
131 if (fds.revents & (POLLIN | POLLERR)) {
132 std::vector<uint8_t> buf(MAXPACKET);
133 const int resplen = recv(fd, buf.data(), buf.size(), 0);
134
135 if (resplen < 0) {
136 return ErrnoErrorf("recvfrom failed");
137 }
138
139 buf.resize(resplen);
140 if (auto result = checkDnsResponse(buf); !result.ok()) {
141 return Errorf("checkDnsResponse failed: {}", result.error().message());
142 }
143
144 return {};
145 }
146 }
147 }
148
149 } // namespace
150
query(const netdutils::Slice query)151 std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
152 std::lock_guard guard(mLock);
153
154 auto record = mQueries.recordQuery(query);
155 if (!record) {
156 return std::async(std::launch::deferred, []{
157 return (Result) { .code = Response::internal_error };
158 });
159 }
160
161 if (!mSocket) {
162 LOG(DEBUG) << "No socket for query. Opening socket and sending.";
163 doConnect();
164 } else {
165 sendQuery(record->query);
166 }
167
168 return std::move(record->result);
169 }
170
getConnectCounter() const171 int DnsTlsTransport::getConnectCounter() const {
172 std::lock_guard guard(mLock);
173 return mConnectCounter;
174 }
175
sendQuery(const DnsTlsQueryMap::Query & q)176 bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query& q) {
177 // Strip off the ID number and send the new ID instead.
178 const bool sent = mSocket->query(q.newId, netdutils::drop(netdutils::makeSlice(q.query), 2));
179 if (sent) {
180 mQueries.markTried(q.newId);
181 }
182 return sent;
183 }
184
doConnect()185 void DnsTlsTransport::doConnect() {
186 LOG(DEBUG) << "Constructing new socket";
187 mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);
188
189 bool success = true;
190 if (mSocket.get() == nullptr || !mSocket->startHandshake()) {
191 success = false;
192 }
193 mConnectCounter++;
194
195 if (success) {
196 auto queries = mQueries.getAll();
197 LOG(DEBUG) << "Initialization succeeded. Reissuing " << queries.size() << " queries.";
198 for(auto& q : queries) {
199 if (!sendQuery(q)) {
200 break;
201 }
202 }
203 } else {
204 LOG(DEBUG) << "Initialization failed.";
205 mSocket.reset();
206 LOG(DEBUG) << "Failing all pending queries.";
207 mQueries.clear();
208 }
209 }
210
onResponse(std::vector<uint8_t> response)211 void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
212 mQueries.onResponse(std::move(response));
213 }
214
onClosed()215 void DnsTlsTransport::onClosed() {
216 std::lock_guard guard(mLock);
217 if (mClosing) {
218 return;
219 }
220 // Move remaining operations to a new thread.
221 // This is necessary because
222 // 1. onClosed is currently running on a thread that blocks mSocket's destructor
223 // 2. doReconnect will call that destructor
224 if (mReconnectThread) {
225 // Complete cleanup of a previous reconnect thread, if present.
226 mReconnectThread->join();
227 // Joining a thread that is trying to acquire mLock, while holding mLock,
228 // looks like it risks a deadlock. However, a deadlock will not occur because
229 // once onClosed is called, it cannot be called again until after doReconnect
230 // acquires mLock.
231 }
232 mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this));
233 }
234
doReconnect()235 void DnsTlsTransport::doReconnect() {
236 std::lock_guard guard(mLock);
237 setThreadName(fmt::format("TlsReconn_{}", mMark & 0xffff));
238 if (mClosing) {
239 return;
240 }
241 mQueries.cleanup();
242 if (!mQueries.empty()) {
243 LOG(DEBUG) << "Fast reconnect to retry remaining queries";
244 doConnect();
245 } else {
246 LOG(DEBUG) << "No pending queries. Going idle.";
247 mSocket.reset();
248 }
249 }
250
~DnsTlsTransport()251 DnsTlsTransport::~DnsTlsTransport() {
252 LOG(DEBUG) << "Destructor";
253 {
254 std::lock_guard guard(mLock);
255 LOG(DEBUG) << "Locked destruction procedure";
256 mQueries.clear();
257 mClosing = true;
258 }
259 // It's possible that a reconnect thread was spawned and waiting for mLock.
260 // It's safe for that thread to run now because mClosing is true (and mQueries is empty),
261 // but we need to wait for it to finish before allowing destruction to proceed.
262 if (mReconnectThread) {
263 LOG(DEBUG) << "Waiting for reconnect thread to terminate";
264 mReconnectThread->join();
265 mReconnectThread.reset();
266 }
267 // Ensure that the socket is destroyed, and can clean up its callback threads,
268 // before any of this object's fields become invalid.
269 mSocket.reset();
270 LOG(DEBUG) << "Destructor completed";
271 }
272
273 // static
274 // TODO: Use this function to preheat the session cache.
275 // That may require moving it to DnsTlsDispatcher.
validate(const DnsTlsServer & server,uint32_t mark)276 bool DnsTlsTransport::validate(const DnsTlsServer& server, uint32_t mark) {
277 LOG(DEBUG) << "Beginning validation with mark " << std::hex << mark;
278
279 const std::vector<uint8_t> query = makeDnsQuery();
280 DnsTlsSocketFactory factory;
281 DnsTlsTransport transport(server, mark, &factory);
282
283 // Send the initial query to warm up the connection.
284 auto r = transport.query(netdutils::makeSlice(query)).get();
285 if (r.code != Response::success) {
286 LOG(WARNING) << "query failed";
287 return false;
288 }
289
290 if (auto result = checkDnsResponse(r.response); !result.ok()) {
291 LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
292 return false;
293 }
294
295 // If this validation is not for opportunistic mode, or the flags are not properly set,
296 // the validation is done. If not, the validation will compare DoT probe latency and
297 // UDP probe latency, and it will pass if:
298 // dot_probe_latency < latencyFactor * udp_probe_latency + latencyOffsetMs
299 //
300 // For instance, with latencyFactor = 3 and latencyOffsetMs = 10, if UDP probe latency is 5 ms,
301 // DoT probe latency must less than 25 ms.
302 const bool isAtLeastR = getApiLevel() >= 30;
303 int latencyFactor = Experiments::getInstance()->getFlag("dot_validation_latency_factor",
304 (isAtLeastR ? 3 : -1));
305 int latencyOffsetMs = Experiments::getInstance()->getFlag("dot_validation_latency_offset_ms",
306 (isAtLeastR ? 100 : -1));
307 const bool shouldCompareUdpLatency =
308 server.name.empty() &&
309 (latencyFactor >= 0 && latencyOffsetMs >= 0 && latencyFactor + latencyOffsetMs != 0);
310 if (!shouldCompareUdpLatency) {
311 return true;
312 }
313
314 LOG(INFO) << fmt::format("Use flags: latencyFactor={}, latencyOffsetMs={}", latencyFactor,
315 latencyOffsetMs);
316
317 int64_t udpProbeTimeUs = 0;
318 bool udpProbeGotAnswer = false;
319 std::thread udpProbeThread([&] {
320 // Can issue another probe if the first one fails or is lost.
321 for (int i = 1; i < 3; i++) {
322 netdutils::Stopwatch stopwatch;
323 auto result = sendUdpQuery(server.addr().ip(), mark, query);
324 udpProbeTimeUs = stopwatch.timeTakenUs();
325 udpProbeGotAnswer = result.ok();
326 LOG(INFO) << fmt::format("UDP probe for {} {}, took {:.3f}ms", server.toIpString(),
327 (udpProbeGotAnswer ? "succeeded" : "failed"),
328 udpProbeTimeUs / 1000.0);
329
330 if (udpProbeGotAnswer) {
331 break;
332 }
333 LOG(WARNING) << "sendUdpQuery attempt " << i << " failed: " << result.error().message();
334 }
335 });
336
337 int64_t dotProbeTimeUs = 0;
338 bool dotProbeGotAnswer = false;
339 std::thread dotProbeThread([&] {
340 netdutils::Stopwatch stopwatch;
341 auto r = transport.query(netdutils::makeSlice(query)).get();
342 dotProbeTimeUs = stopwatch.timeTakenUs();
343
344 if (r.code != Response::success) {
345 LOG(WARNING) << "query failed";
346 } else {
347 if (auto result = checkDnsResponse(r.response); !result.ok()) {
348 LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
349 } else {
350 dotProbeGotAnswer = true;
351 }
352 }
353
354 LOG(INFO) << fmt::format("DoT probe for {} {}, took {:.3f}ms", server.toIpString(),
355 (dotProbeGotAnswer ? "succeeded" : "failed"),
356 dotProbeTimeUs / 1000.0);
357 });
358
359 // TODO: If DoT probe thread finishes before UDP probe thread and dotProbeGotAnswer is false,
360 // actively cancel UDP probe thread.
361 dotProbeThread.join();
362 udpProbeThread.join();
363
364 if (!dotProbeGotAnswer) return false;
365 if (!udpProbeGotAnswer) return true;
366 return dotProbeTimeUs < (latencyFactor * udpProbeTimeUs + latencyOffsetMs * 1000);
367 }
368
369 } // end of namespace net
370 } // end of namespace android
371