1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "resolv"
18 
19 #include "PrivateDnsConfiguration.h"
20 
21 #include <android-base/format.h>
22 #include <android-base/logging.h>
23 #include <android-base/stringprintf.h>
24 #include <netdutils/ThreadUtil.h>
25 #include <sys/socket.h>
26 
27 #include "DnsTlsTransport.h"
28 #include "ResolverEventReporter.h"
29 #include "netd_resolv/resolv.h"
30 #include "util.h"
31 
32 using aidl::android::net::resolv::aidl::IDnsResolverUnsolicitedEventListener;
33 using aidl::android::net::resolv::aidl::PrivateDnsValidationEventParcel;
34 using android::base::StringPrintf;
35 using android::netdutils::setThreadName;
36 using std::chrono::milliseconds;
37 
38 namespace android {
39 namespace net {
40 
41 bool parseServer(const char* server, sockaddr_storage* parsed) {
42     addrinfo hints = {
43             .ai_flags = AI_NUMERICHOST | AI_NUMERICSERV,
44             .ai_family = AF_UNSPEC,
45     };
46     addrinfo* res;
47 
48     int err = getaddrinfo(server, "853", &hints, &res);
49     if (err != 0) {
50         LOG(WARNING) << "Failed to parse server address (" << server << "): " << gai_strerror(err);
51         return false;
52     }
53 
54     memcpy(parsed, res->ai_addr, res->ai_addrlen);
55     freeaddrinfo(res);
56     return true;
57 }
58 
59 int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
60                                  const std::vector<std::string>& servers, const std::string& name,
61                                  const std::string& caCert) {
62     LOG(DEBUG) << "PrivateDnsConfiguration::set(" << netId << ", 0x" << std::hex << mark << std::dec
63                << ", " << servers.size() << ", " << name << ")";
64 
65     // Parse the list of servers that has been passed in
66     PrivateDnsTracker tmp;
67     for (const auto& s : servers) {
68         sockaddr_storage parsed;
69         if (!parseServer(s.c_str(), &parsed)) {
70             return -EINVAL;
71         }
72         auto server = std::make_unique<DnsTlsServer>(parsed);
73         server->name = name;
74         server->certificate = caCert;
75         server->mark = mark;
76         tmp[ServerIdentity(*server)] = std::move(server);
77     }
78 
79     std::lock_guard guard(mPrivateDnsLock);
80     if (!name.empty()) {
81         mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
82     } else if (!tmp.empty()) {
83         mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
84     } else {
85         mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
86         mPrivateDnsTransports.erase(netId);
87         // TODO: signal validation threads to stop.
88         return 0;
89     }
90 
91     // Create the tracker if it was not present
92     auto& tracker = mPrivateDnsTransports[netId];
93 
94     // Add the servers if not contained in tracker.
95     for (auto& [identity, server] : tmp) {
96         if (tracker.find(identity) == tracker.end()) {
97             tracker[identity] = std::move(server);
98         }
99     }
100 
101     for (auto& [identity, server] : tracker) {
102         const bool active = tmp.find(identity) != tmp.end();
103         server->setActive(active);
104 
105         // For simplicity, deem the validation result of inactive servers as unreliable.
106         if (!server->active() && server->validationState() == Validation::success) {
107             updateServerState(identity, Validation::success_but_expired, netId);
108         }
109 
110         if (needsValidation(*server)) {
111             updateServerState(identity, Validation::in_process, netId);
112             startValidation(identity, netId, false);
113         }
114     }
115 
116     return 0;
117 }
118 
119 PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const {
120     PrivateDnsStatus status{PrivateDnsMode::OFF, {}};
121     std::lock_guard guard(mPrivateDnsLock);
122 
123     const auto mode = mPrivateDnsModes.find(netId);
124     if (mode == mPrivateDnsModes.end()) return status;
125     status.mode = mode->second;
126 
127     const auto netPair = mPrivateDnsTransports.find(netId);
128     if (netPair != mPrivateDnsTransports.end()) {
129         for (const auto& [_, server] : netPair->second) {
130             if (server->isDot() && server->active()) {
131                 DnsTlsServer& dotServer = *static_cast<DnsTlsServer*>(server.get());
132                 status.serversMap.emplace(dotServer, server->validationState());
133             }
134             // TODO: also add DoH server to the map.
135         }
136     }
137 
138     return status;
139 }
140 
141 void PrivateDnsConfiguration::clear(unsigned netId) {
142     LOG(DEBUG) << "PrivateDnsConfiguration::clear(" << netId << ")";
143     std::lock_guard guard(mPrivateDnsLock);
144     mPrivateDnsModes.erase(netId);
145     mPrivateDnsTransports.erase(netId);
146 }
147 
148 base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
149                                                               const ServerIdentity& identity,
150                                                               uint32_t mark) {
151     std::lock_guard guard(mPrivateDnsLock);
152 
153     // Running revalidation requires to mark the server as in_process, which means the server
154     // won't be used until the validation passes. It's necessary and safe to run revalidation
155     // when in private DNS opportunistic mode, because there's a fallback mechanics even if
156     // all of the private DNS servers are in in_process state.
157     if (auto it = mPrivateDnsModes.find(netId); it == mPrivateDnsModes.end()) {
158         return Errorf("NetId not found in mPrivateDnsModes");
159     } else if (it->second != PrivateDnsMode::OPPORTUNISTIC) {
160         return Errorf("Private DNS setting is not opportunistic mode");
161     }
162 
163     auto result = getPrivateDnsLocked(identity, netId);
164     if (!result.ok()) {
165         return result.error();
166     }
167 
168     const IPrivateDnsServer* server = result.value();
169 
170     if (!server->active()) return Errorf("Server is not active");
171 
172     if (server->validationState() != Validation::success) {
173         return Errorf("Server validation state mismatched");
174     }
175 
176     // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
177     // This is to protect validation from running on unexpected marks.
178     // Validation should be associated with a mark gotten by system permission.
179     if (server->validationMark() != mark) return Errorf("Socket mark mismatched");
180 
181     updateServerState(identity, Validation::in_process, netId);
182     startValidation(identity, netId, true);
183     return {};
184 }
185 
// Starts a detached background thread that probes the DoT server identified by |identity| on
// |netId|.
//
// The thread keeps probing until recordPrivateDnsValidation() reports that no further attempt is
// needed, or until the backoff schedule has no more timeouts. If the server is not tracked, no
// thread is started. |isRevalidation| is forwarded to recordPrivateDnsValidation(), where it
// influences whether a failed attempt is retried.
//
// Callers in this file (set() and requestValidation()) hold mPrivateDnsLock when calling this,
// which the getPrivateDnsLocked() call below relies on.
//
// NOTE(review): the thread captures |this| and is detached, so this object is assumed to outlive
// every validation thread — confirm (e.g. process-lifetime instance).
void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, unsigned netId,
                                              bool isRevalidation) {
    // This ensures that the thread sends a probe at least once, even if
    // the server is removed before the thread starts running.
    // TODO: consider moving these code to the thread.
    const auto result = getPrivateDnsLocked(identity, netId);
    if (!result.ok()) return;
    // Copy the server by value so the probe does not depend on the tracked object, which may be
    // removed while the thread runs. The cast assumes a DnsTlsServer, the only kind set() inserts.
    DnsTlsServer server = *static_cast<const DnsTlsServer*>(result.value());

    std::thread validate_thread([this, identity, server, netId, isRevalidation] {
        setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());

        // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
        //
        // Start with a 1 minute delay and backoff to once per hour.
        // NOTE(review): the actual delays come from mBackoffBuilder, whose configuration is not
        // visible here — confirm these figures against it.
        //
        // Assumptions:
        //     [1] Each TLS validation is ~10KB of certs+handshake+payload.
        //     [2] Network typically provision clients with <=4 nameservers.
        //     [3] Average month has 30 days.
        //
        // Each validation pass in a given hour is ~1.2MB of data. And 24
        // such validation passes per day is about ~30MB per month, in the
        // worst case. Otherwise, this will cost ~600 SYNs per month
        // (6 SYNs per ip, 4 ips per validation pass, 24 passes per day).
        auto backoff = mBackoffBuilder.build();

        while (true) {
            // ::validate() is a blocking call that performs network operations.
            // It can take milliseconds to minutes, up to the SYN retry limit.
            // mPrivateDnsLock is not held here; recordPrivateDnsValidation() takes it itself.
            LOG(WARNING) << "Validating DnsTlsServer " << server.toIpString() << " with mark 0x"
                         << std::hex << server.validationMark();
            const bool success = DnsTlsTransport::validate(server, server.validationMark());
            LOG(WARNING) << "validateDnsTlsServer returned " << success << " for "
                         << server.toIpString();

            // true means another attempt should be made (see recordPrivateDnsValidation()).
            const bool needs_reeval =
                    this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);

            if (!needs_reeval) {
                break;
            }

            // Wait for the next backoff interval; give up once the schedule is exhausted.
            if (backoff.hasNextTimeout()) {
                // TODO: make the thread able to receive signals to shutdown early.
                std::this_thread::sleep_for(backoff.getNextTimeout());
            } else {
                break;
            }
        }
    });
    validate_thread.detach();
}
239 
240 void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const ServerIdentity& identity,
241                                                             unsigned netId, bool success) {
242     LOG(DEBUG) << "Sending validation " << (success ? "success" : "failure") << " event on netId "
243                << netId << " for " << identity.sockaddr.ip().toString() << " with hostname {"
244                << identity.provider << "}";
245     // Send a validation event to NetdEventListenerService.
246     const auto& listeners = ResolverEventReporter::getInstance().getListeners();
247     if (listeners.empty()) {
248         LOG(ERROR)
249                 << "Validation event not sent since no INetdEventListener receiver is available.";
250     }
251     for (const auto& it : listeners) {
252         it->onPrivateDnsValidationEvent(netId, identity.sockaddr.ip().toString(), identity.provider,
253                                         success);
254     }
255 
256     // Send a validation event to unsolicited event listeners.
257     const auto& unsolEventListeners = ResolverEventReporter::getInstance().getUnsolEventListeners();
258     const PrivateDnsValidationEventParcel validationEvent = {
259             .netId = static_cast<int32_t>(netId),
260             .ipAddress = identity.sockaddr.ip().toString(),
261             .hostname = identity.provider,
262             .validation = success ? IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_SUCCESS
263                                   : IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_FAILURE,
264     };
265     for (const auto& it : unsolEventListeners) {
266         it->onPrivateDnsValidationEvent(validationEvent);
267     }
268 }
269 
// Records the result of one validation attempt and decides whether the calling validation thread
// should try again.
//
// Takes mPrivateDnsLock. Returns true (NEEDS_REEVALUATION) when another attempt should be made
// and false (DONT_REEVALUATE) otherwise. A failed attempt is retried only in STRICT mode, or in
// OPPORTUNISTIC mode when |isRevalidation| is true.
//
// If the network, its mode, or the server itself is gone, or the server is no longer active, the
// attempt is treated as a failure and never retried. Listeners receive a validation event and
// the server's state is updated to:
//   - success, on success;
//   - in_process, on a failure that will be retried;
//   - fail, on any other failure.
bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& identity,
                                                         unsigned netId, bool success,
                                                         bool isRevalidation) {
    constexpr bool NEEDS_REEVALUATION = true;
    constexpr bool DONT_REEVALUATE = false;

    std::lock_guard guard(mPrivateDnsLock);

    // The network was cleared (or set to OFF) while the probe was in flight.
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        LOG(WARNING) << "netId " << netId << " was erased during private DNS validation";
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return DONT_REEVALUATE;
    }

    const auto mode = mPrivateDnsModes.find(netId);
    if (mode == mPrivateDnsModes.end()) {
        LOG(WARNING) << "netId " << netId << " has no private DNS validation mode";
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return DONT_REEVALUATE;
    }

    // Retry by default. Stop on success, when private DNS is OFF, or after a failed first-time
    // validation in opportunistic mode.
    bool reevaluationStatus = NEEDS_REEVALUATION;
    if (success) {
        reevaluationStatus = DONT_REEVALUATE;
    } else if (mode->second == PrivateDnsMode::OFF) {
        reevaluationStatus = DONT_REEVALUATE;
    } else if (mode->second == PrivateDnsMode::OPPORTUNISTIC && !isRevalidation) {
        reevaluationStatus = DONT_REEVALUATE;
    }

    // A result for a server that is no longer tracked, or no longer active, is reported as a
    // failure and never retried.
    auto& tracker = netPair->second;
    auto serverPair = tracker.find(identity);
    if (serverPair == tracker.end()) {
        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                     << " was removed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!serverPair->second->active()) {
        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                     << " was removed from the configuration";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    }

    // Send private dns validation result to listeners.
    sendPrivateDnsValidationEvent(identity, netId, success);

    if (success) {
        updateServerState(identity, Validation::success, netId);
    } else {
        // Validation failure is expected if a user is on a captive portal.
        // TODO: Trigger a second validation attempt after captive portal login
        // succeeds.
        // A server that will be retried stays in_process (not used until it passes).
        const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
                                                                       : Validation::fail;
        updateServerState(identity, result, netId);
    }
    LOG(WARNING) << "Validation " << (success ? "success" : "failed");

    return reevaluationStatus;
}
332 
333 void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
334                                                 uint32_t netId) {
335     const auto result = getPrivateDnsLocked(identity, netId);
336     if (!result.ok()) {
337         notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
338         return;
339     }
340 
341     auto* server = result.value();
342 
343     server->setValidationState(state);
344     notifyValidationStateUpdate(identity.sockaddr, state, netId);
345 
346     RecordEntry record(netId, identity, state);
347     mPrivateDnsLog.push(std::move(record));
348 }
349 
350 bool PrivateDnsConfiguration::needsValidation(const IPrivateDnsServer& server) const {
351     // The server is not expected to be used on the network.
352     if (!server.active()) return false;
353 
354     // The server is newly added.
355     if (server.validationState() == Validation::unknown_server) return true;
356 
357     // The server has failed at least one validation attempt. Give it another try.
358     if (server.validationState() == Validation::fail) return true;
359 
360     // The previous validation result might be unreliable.
361     if (server.validationState() == Validation::success_but_expired) return true;
362 
363     return false;
364 }
365 
366 base::Result<IPrivateDnsServer*> PrivateDnsConfiguration::getPrivateDns(
367         const ServerIdentity& identity, unsigned netId) {
368     std::lock_guard guard(mPrivateDnsLock);
369     return getPrivateDnsLocked(identity, netId);
370 }
371 
372 base::Result<IPrivateDnsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
373         const ServerIdentity& identity, unsigned netId) {
374     auto netPair = mPrivateDnsTransports.find(netId);
375     if (netPair == mPrivateDnsTransports.end()) {
376         return Errorf("Failed to get private DNS: netId {} not found", netId);
377     }
378 
379     auto iter = netPair->second.find(identity);
380     if (iter == netPair->second.end()) {
381         return Errorf("Failed to get private DNS: server {{{}/{}}} not found", identity.sockaddr,
382                       identity.provider);
383     }
384 
385     return iter->second.get();
386 }
387 
388 void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
389     std::lock_guard guard(mPrivateDnsLock);
390     mObserver = observer;
391 }
392 
393 void PrivateDnsConfiguration::notifyValidationStateUpdate(const netdutils::IPSockAddr& sockaddr,
394                                                           Validation validation,
395                                                           uint32_t netId) const {
396     if (mObserver) {
397         mObserver->onValidationStateUpdate(sockaddr.ip().toString(), validation, netId);
398     }
399 }
400 
401 void PrivateDnsConfiguration::dump(netdutils::DumpWriter& dw) const {
402     dw.println("PrivateDnsLog:");
403     netdutils::ScopedIndent indentStats(dw);
404 
405     for (const auto& record : mPrivateDnsLog.copy()) {
406         dw.println(fmt::format(
407                 "{} - netId={} PrivateDns={{{}/{}}} state={}", timestampToString(record.timestamp),
408                 record.netId, record.serverIdentity.sockaddr.toString(),
409                 record.serverIdentity.provider, validationStatusToString(record.state)));
410     }
411     dw.blankline();
412 }
413 
414 }  // namespace net
415 }  // namespace android
416