1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "resolv"
18 
19 #include "DnsTlsTransport.h"
20 
#include <memory>
#include <span>

#include <android-base/format.h>
#include <android-base/logging.h>
#include <android-base/result.h>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <netdutils/Stopwatch.h>
#include <netdutils/ThreadUtil.h>
#include <private/android_filesystem_config.h>  // AID_DNS
#include <sys/poll.h>

#include "DnsTlsSocketFactory.h"
#include "Experiments.h"
#include "IDnsTlsSocketFactory.h"
#include "resolv_private.h"
#include "util.h"
38 
39 using android::netdutils::setThreadName;
40 
41 namespace android {
42 namespace net {
43 
44 namespace {
45 
46 // Make a DNS query for the hostname "<random>-dnsotls-ds.metric.gstatic.com".
makeDnsQuery()47 std::vector<uint8_t> makeDnsQuery() {
48     static const char kDnsSafeChars[] =
49             "abcdefhijklmnopqrstuvwxyz"
50             "ABCDEFHIJKLMNOPQRSTUVWXYZ"
51             "0123456789";
52     const auto c = [](uint8_t rnd) -> uint8_t {
53         return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
54     };
55     uint8_t rnd[8];
56     arc4random_buf(rnd, std::size(rnd));
57 
58     return std::vector<uint8_t>{
59             rnd[6], rnd[7],  // [0-1]   query ID
60             1,      0,       // [2-3]   flags; query[2] = 1 for recursion desired (RD).
61             0,      1,       // [4-5]   QDCOUNT (number of queries)
62             0,      0,       // [6-7]   ANCOUNT (number of answers)
63             0,      0,       // [8-9]   NSCOUNT (number of name server records)
64             0,      0,       // [10-11] ARCOUNT (number of additional records)
65             17,     c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n',
66             's',    'o',       't',       'l',       's',       '-',       'd',       's', 6,   'm',
67             'e',    't',       'r',       'i',       'c',       7,         'g',       's', 't', 'a',
68             't',    'i',       'c',       3,         'c',       'o',       'm',
69             0,                  // null terminator of FQDN (root TLD)
70             0,      ns_t_aaaa,  // QTYPE
71             0,      ns_c_in     // QCLASS
72     };
73 }
74 
checkDnsResponse(const std::span<const uint8_t> answer)75 base::Result<void> checkDnsResponse(const std::span<const uint8_t> answer) {
76     if (answer.size() < NS_HFIXEDSZ) {
77         return Errorf("short response: {}", answer.size());
78     }
79 
80     const int qdcount = (answer[4] << 8) | answer[5];
81     if (qdcount != 1) {
82         return Errorf("reply query count != 1: {}", qdcount);
83     }
84 
85     const int ancount = (answer[6] << 8) | answer[7];
86     LOG(DEBUG) << "answer count: " << ancount;
87 
88     // TODO: Further validate the response contents (check for valid AAAA record, ...).
89     // Note that currently, integration tests rely on this function accepting a
90     // response with zero records.
91 
92     return {};
93 }
94 
95 // Sends |query| to the given server, and returns the DNS response.
sendUdpQuery(netdutils::IPAddress ip,uint32_t mark,std::span<const uint8_t> query)96 base::Result<void> sendUdpQuery(netdutils::IPAddress ip, uint32_t mark,
97                                 std::span<const uint8_t> query) {
98     const sockaddr_storage ss = netdutils::IPSockAddr(ip, 53);
99     const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss);
100     const int nsaplen = sockaddrSize(nsap);
101     const int sockType = SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC;
102     android::base::unique_fd fd{socket(nsap->sa_family, sockType, 0)};
103     if (fd < 0) {
104         return ErrnoErrorf("socket failed");
105     }
106 
107     resolv_tag_socket(fd.get(), AID_DNS, NET_CONTEXT_INVALID_PID);
108     if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) < 0) {
109         return ErrnoErrorf("setsockopt failed");
110     }
111 
112     if (connect(fd.get(), nsap, (socklen_t)nsaplen) < 0) {
113         return ErrnoErrorf("connect failed");
114     }
115 
116     if (send(fd, query.data(), query.size(), 0) != static_cast<ptrdiff_t>(query.size())) {
117         return ErrnoErrorf("send failed");
118     }
119 
120     const int timeoutMs = 3000;
121     while (true) {
122         pollfd fds = {.fd = fd, .events = POLLIN};
123 
124         const int n = TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
125         if (n == 0) {
126             return Errorf("poll timed out");
127         }
128         if (n < 0) {
129             return ErrnoErrorf("poll failed");
130         }
131         if (fds.revents & (POLLIN | POLLERR)) {
132             std::vector<uint8_t> buf(MAXPACKET);
133             const int resplen = recv(fd, buf.data(), buf.size(), 0);
134 
135             if (resplen < 0) {
136                 return ErrnoErrorf("recvfrom failed");
137             }
138 
139             buf.resize(resplen);
140             if (auto result = checkDnsResponse(buf); !result.ok()) {
141                 return Errorf("checkDnsResponse failed: {}", result.error().message());
142             }
143 
144             return {};
145         }
146     }
147 }
148 
149 }  // namespace
150 
query(const netdutils::Slice query)151 std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
152     std::lock_guard guard(mLock);
153 
154     auto record = mQueries.recordQuery(query);
155     if (!record) {
156         return std::async(std::launch::deferred, []{
157             return (Result) { .code = Response::internal_error };
158         });
159     }
160 
161     if (!mSocket) {
162         LOG(DEBUG) << "No socket for query.  Opening socket and sending.";
163         doConnect();
164     } else {
165         sendQuery(record->query);
166     }
167 
168     return std::move(record->result);
169 }
170 
getConnectCounter() const171 int DnsTlsTransport::getConnectCounter() const {
172     std::lock_guard guard(mLock);
173     return mConnectCounter;
174 }
175 
sendQuery(const DnsTlsQueryMap::Query & q)176 bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query& q) {
177     // Strip off the ID number and send the new ID instead.
178     const bool sent = mSocket->query(q.newId, netdutils::drop(netdutils::makeSlice(q.query), 2));
179     if (sent) {
180         mQueries.markTried(q.newId);
181     }
182     return sent;
183 }
184 
doConnect()185 void DnsTlsTransport::doConnect() {
186     LOG(DEBUG) << "Constructing new socket";
187     mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);
188 
189     bool success = true;
190     if (mSocket.get() == nullptr || !mSocket->startHandshake()) {
191         success = false;
192     }
193     mConnectCounter++;
194 
195     if (success) {
196         auto queries = mQueries.getAll();
197         LOG(DEBUG) << "Initialization succeeded.  Reissuing " << queries.size() << " queries.";
198         for(auto& q : queries) {
199             if (!sendQuery(q)) {
200                 break;
201             }
202         }
203     } else {
204         LOG(DEBUG) << "Initialization failed.";
205         mSocket.reset();
206         LOG(DEBUG) << "Failing all pending queries.";
207         mQueries.clear();
208     }
209 }
210 
onResponse(std::vector<uint8_t> response)211 void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
212     mQueries.onResponse(std::move(response));
213 }
214 
onClosed()215 void DnsTlsTransport::onClosed() {
216     std::lock_guard guard(mLock);
217     if (mClosing) {
218         return;
219     }
220     // Move remaining operations to a new thread.
221     // This is necessary because
222     // 1. onClosed is currently running on a thread that blocks mSocket's destructor
223     // 2. doReconnect will call that destructor
224     if (mReconnectThread) {
225         // Complete cleanup of a previous reconnect thread, if present.
226         mReconnectThread->join();
227         // Joining a thread that is trying to acquire mLock, while holding mLock,
228         // looks like it risks a deadlock.  However, a deadlock will not occur because
229         // once onClosed is called, it cannot be called again until after doReconnect
230         // acquires mLock.
231     }
232     mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this));
233 }
234 
doReconnect()235 void DnsTlsTransport::doReconnect() {
236     std::lock_guard guard(mLock);
237     setThreadName(fmt::format("TlsReconn_{}", mMark & 0xffff));
238     if (mClosing) {
239         return;
240     }
241     mQueries.cleanup();
242     if (!mQueries.empty()) {
243         LOG(DEBUG) << "Fast reconnect to retry remaining queries";
244         doConnect();
245     } else {
246         LOG(DEBUG) << "No pending queries.  Going idle.";
247         mSocket.reset();
248     }
249 }
250 
~DnsTlsTransport()251 DnsTlsTransport::~DnsTlsTransport() {
252     LOG(DEBUG) << "Destructor";
253     {
254         std::lock_guard guard(mLock);
255         LOG(DEBUG) << "Locked destruction procedure";
256         mQueries.clear();
257         mClosing = true;
258     }
259     // It's possible that a reconnect thread was spawned and waiting for mLock.
260     // It's safe for that thread to run now because mClosing is true (and mQueries is empty),
261     // but we need to wait for it to finish before allowing destruction to proceed.
262     if (mReconnectThread) {
263         LOG(DEBUG) << "Waiting for reconnect thread to terminate";
264         mReconnectThread->join();
265         mReconnectThread.reset();
266     }
267     // Ensure that the socket is destroyed, and can clean up its callback threads,
268     // before any of this object's fields become invalid.
269     mSocket.reset();
270     LOG(DEBUG) << "Destructor completed";
271 }
272 
273 // static
274 // TODO: Use this function to preheat the session cache.
275 // That may require moving it to DnsTlsDispatcher.
validate(const DnsTlsServer & server,uint32_t mark)276 bool DnsTlsTransport::validate(const DnsTlsServer& server, uint32_t mark) {
277     LOG(DEBUG) << "Beginning validation with mark " << std::hex << mark;
278 
279     const std::vector<uint8_t> query = makeDnsQuery();
280     DnsTlsSocketFactory factory;
281     DnsTlsTransport transport(server, mark, &factory);
282 
283     // Send the initial query to warm up the connection.
284     auto r = transport.query(netdutils::makeSlice(query)).get();
285     if (r.code != Response::success) {
286         LOG(WARNING) << "query failed";
287         return false;
288     }
289 
290     if (auto result = checkDnsResponse(r.response); !result.ok()) {
291         LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
292         return false;
293     }
294 
295     // If this validation is not for opportunistic mode, or the flags are not properly set,
296     // the validation is done. If not, the validation will compare DoT probe latency and
297     // UDP probe latency, and it will pass if:
298     //   dot_probe_latency < latencyFactor * udp_probe_latency + latencyOffsetMs
299     //
300     // For instance, with latencyFactor = 3 and latencyOffsetMs = 10, if UDP probe latency is 5 ms,
301     // DoT probe latency must less than 25 ms.
302     const bool isAtLeastR = getApiLevel() >= 30;
303     int latencyFactor = Experiments::getInstance()->getFlag("dot_validation_latency_factor",
304                                                             (isAtLeastR ? 3 : -1));
305     int latencyOffsetMs = Experiments::getInstance()->getFlag("dot_validation_latency_offset_ms",
306                                                               (isAtLeastR ? 100 : -1));
307     const bool shouldCompareUdpLatency =
308             server.name.empty() &&
309             (latencyFactor >= 0 && latencyOffsetMs >= 0 && latencyFactor + latencyOffsetMs != 0);
310     if (!shouldCompareUdpLatency) {
311         return true;
312     }
313 
314     LOG(INFO) << fmt::format("Use flags: latencyFactor={}, latencyOffsetMs={}", latencyFactor,
315                              latencyOffsetMs);
316 
317     int64_t udpProbeTimeUs = 0;
318     bool udpProbeGotAnswer = false;
319     std::thread udpProbeThread([&] {
320         // Can issue another probe if the first one fails or is lost.
321         for (int i = 1; i < 3; i++) {
322             netdutils::Stopwatch stopwatch;
323             auto result = sendUdpQuery(server.addr().ip(), mark, query);
324             udpProbeTimeUs = stopwatch.timeTakenUs();
325             udpProbeGotAnswer = result.ok();
326             LOG(INFO) << fmt::format("UDP probe for {} {}, took {:.3f}ms", server.toIpString(),
327                                      (udpProbeGotAnswer ? "succeeded" : "failed"),
328                                      udpProbeTimeUs / 1000.0);
329 
330             if (udpProbeGotAnswer) {
331                 break;
332             }
333             LOG(WARNING) << "sendUdpQuery attempt " << i << " failed: " << result.error().message();
334         }
335     });
336 
337     int64_t dotProbeTimeUs = 0;
338     bool dotProbeGotAnswer = false;
339     std::thread dotProbeThread([&] {
340         netdutils::Stopwatch stopwatch;
341         auto r = transport.query(netdutils::makeSlice(query)).get();
342         dotProbeTimeUs = stopwatch.timeTakenUs();
343 
344         if (r.code != Response::success) {
345             LOG(WARNING) << "query failed";
346         } else {
347             if (auto result = checkDnsResponse(r.response); !result.ok()) {
348                 LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
349             } else {
350                 dotProbeGotAnswer = true;
351             }
352         }
353 
354         LOG(INFO) << fmt::format("DoT probe for {} {}, took {:.3f}ms", server.toIpString(),
355                                  (dotProbeGotAnswer ? "succeeded" : "failed"),
356                                  dotProbeTimeUs / 1000.0);
357     });
358 
359     // TODO: If DoT probe thread finishes before UDP probe thread and dotProbeGotAnswer is false,
360     // actively cancel UDP probe thread.
361     dotProbeThread.join();
362     udpProbeThread.join();
363 
364     if (!dotProbeGotAnswer) return false;
365     if (!udpProbeGotAnswer) return true;
366     return dotProbeTimeUs < (latencyFactor * udpProbeTimeUs + latencyOffsetMs * 1000);
367 }
368 
369 }  // end of namespace net
370 }  // end of namespace android
371