1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "resolv" 18 19 #include "DnsTlsDispatcher.h" 20 21 #include <netdutils/Stopwatch.h> 22 23 #include "DnsTlsSocketFactory.h" 24 #include "Experiments.h" 25 #include "PrivateDnsConfiguration.h" 26 #include "resolv_cache.h" 27 #include "resolv_private.h" 28 #include "stats.pb.h" 29 30 #include <android-base/logging.h> 31 32 namespace android { 33 namespace net { 34 35 using android::netdutils::IPSockAddr; 36 using android::netdutils::Stopwatch; 37 using netdutils::Slice; 38 39 // static 40 std::mutex DnsTlsDispatcher::sLock; 41 42 DnsTlsDispatcher::DnsTlsDispatcher() { 43 mFactory.reset(new DnsTlsSocketFactory()); 44 } 45 46 DnsTlsDispatcher& DnsTlsDispatcher::getInstance() { 47 static DnsTlsDispatcher instance; 48 return instance; 49 } 50 51 std::list<DnsTlsServer> DnsTlsDispatcher::getOrderedAndUsableServerList( 52 const std::list<DnsTlsServer>& tlsServers, unsigned netId, unsigned mark) { 53 // Our preferred DnsTlsServer order is: 54 // 1) reuse existing IPv6 connections 55 // 2) reuse existing IPv4 connections 56 // 3) establish new IPv6 connections 57 // 4) establish new IPv4 connections 58 std::list<DnsTlsServer> existing6; 59 std::list<DnsTlsServer> existing4; 60 std::list<DnsTlsServer> new6; 61 std::list<DnsTlsServer> new4; 62 63 // Pull out any servers for which we might have existing connections and 64 // place them at the from the list of servers to try. 65 { 66 std::lock_guard guard(sLock); 67 68 for (const auto& tlsServer : tlsServers) { 69 const Key key = std::make_pair(mark, tlsServer); 70 if (const Transport* xport = getTransport(key); xport != nullptr) { 71 // DoT revalidation specific feature. 72 if (!xport->usable()) { 73 // Don't use this xport. It will be removed after timeout 74 // (IDLE_TIMEOUT minutes). 75 LOG(DEBUG) << "Skip using DoT server " << tlsServer.toIpString() << " on " 76 << netId; 77 continue; 78 } 79 80 switch (tlsServer.ss.ss_family) { 81 case AF_INET: 82 existing4.push_back(tlsServer); 83 break; 84 case AF_INET6: 85 existing6.push_back(tlsServer); 86 break; 87 } 88 } else { 89 switch (tlsServer.ss.ss_family) { 90 case AF_INET: 91 new4.push_back(tlsServer); 92 break; 93 case AF_INET6: 94 new6.push_back(tlsServer); 95 break; 96 } 97 } 98 } 99 } 100 101 auto& out = existing6; 102 out.splice(out.cend(), existing4); 103 out.splice(out.cend(), new6); 104 out.splice(out.cend(), new4); 105 return out; 106 } 107 108 DnsTlsTransport::Response DnsTlsDispatcher::query(const std::list<DnsTlsServer>& tlsServers, 109 res_state statp, const Slice query, 110 const Slice ans, int* resplen) { 111 const std::list<DnsTlsServer> servers( 112 getOrderedAndUsableServerList(tlsServers, statp->netid, statp->_mark)); 113 114 if (servers.empty()) LOG(WARNING) << "No usable DnsTlsServers"; 115 116 DnsTlsTransport::Response code = DnsTlsTransport::Response::internal_error; 117 int serverCount = 0; 118 for (const auto& server : servers) { 119 DnsQueryEvent* dnsQueryEvent = 120 statp->event->mutable_dns_query_events()->add_dns_query_event(); 121 122 bool connectTriggered = false; 123 Stopwatch queryStopwatch; 124 code = this->query(server, statp->netid, statp->_mark, query, ans, resplen, 125 &connectTriggered); 126 127 dnsQueryEvent->set_latency_micros(saturate_cast<int32_t>(queryStopwatch.timeTakenUs())); 128 dnsQueryEvent->set_dns_server_index(serverCount++); 129 dnsQueryEvent->set_ip_version(ipFamilyToIPVersion(server.ss.ss_family)); 130 dnsQueryEvent->set_protocol(PROTO_DOT); 131 dnsQueryEvent->set_type(getQueryType(query.base(), query.size())); 132 dnsQueryEvent->set_connected(connectTriggered); 133 134 switch (code) { 135 // These response codes are valid responses and not expected to 136 // change if another server is queried. 137 case DnsTlsTransport::Response::success: 138 dnsQueryEvent->set_rcode( 139 static_cast<NsRcode>(reinterpret_cast<HEADER*>(ans.base())->rcode)); 140 resolv_stats_add(statp->netid, IPSockAddr::toIPSockAddr(server.ss), dnsQueryEvent); 141 return code; 142 case DnsTlsTransport::Response::limit_error: 143 dnsQueryEvent->set_rcode(NS_R_INTERNAL_ERROR); 144 resolv_stats_add(statp->netid, IPSockAddr::toIPSockAddr(server.ss), dnsQueryEvent); 145 return code; 146 // These response codes might differ when trying other servers, so 147 // keep iterating to see if we can get a different (better) result. 148 case DnsTlsTransport::Response::network_error: 149 // Sync from res_tls_send in res_send.cpp 150 dnsQueryEvent->set_rcode(NS_R_TIMEOUT); 151 resolv_stats_add(statp->netid, IPSockAddr::toIPSockAddr(server.ss), dnsQueryEvent); 152 break; 153 case DnsTlsTransport::Response::internal_error: 154 dnsQueryEvent->set_rcode(NS_R_INTERNAL_ERROR); 155 resolv_stats_add(statp->netid, IPSockAddr::toIPSockAddr(server.ss), dnsQueryEvent); 156 break; 157 // No "default" statement. 158 } 159 } 160 161 return code; 162 } 163 164 DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, unsigned netId, 165 unsigned mark, const Slice query, const Slice ans, 166 int* resplen, bool* connectTriggered) { 167 // TODO: This can cause the resolver to create multiple connections to the same DoT server 168 // merely due to different mark, such as the bit explicitlySelected unset. 169 // See if we can save them and just create one connection for one DoT server. 170 const Key key = std::make_pair(mark, server); 171 Transport* xport; 172 { 173 std::lock_guard guard(sLock); 174 if (xport = getTransport(key); xport == nullptr) { 175 xport = addTransport(server, mark, netId); 176 } 177 ++xport->useCount; 178 } 179 180 // Don't call this function and hold sLock at the same time because of the following reason: 181 // TLS handshake requires a lock which is also needed by this function, if the handshake gets 182 // stuck, this function also gets blocked. 183 const int connectCounter = xport->transport.getConnectCounter(); 184 185 const auto& result = queryInternal(*xport, query); 186 *connectTriggered = (xport->transport.getConnectCounter() > connectCounter); 187 188 DnsTlsTransport::Response code = result.code; 189 if (code == DnsTlsTransport::Response::success) { 190 if (result.response.size() > ans.size()) { 191 LOG(DEBUG) << "Response too large: " << result.response.size() << " > " << ans.size(); 192 code = DnsTlsTransport::Response::limit_error; 193 } else { 194 LOG(DEBUG) << "Got response successfully"; 195 *resplen = result.response.size(); 196 netdutils::copy(ans, netdutils::makeSlice(result.response)); 197 } 198 } else { 199 LOG(DEBUG) << "Query failed: " << (unsigned int)code; 200 } 201 202 auto now = std::chrono::steady_clock::now(); 203 { 204 std::lock_guard guard(sLock); 205 --xport->useCount; 206 xport->lastUsed = now; 207 208 // DoT revalidation specific feature. 209 if (xport->checkRevalidationNecessary(code)) { 210 // Even if the revalidation passes, it doesn't guarantee that DoT queries 211 // to the xport can stop failing because revalidation creates a new connection 212 // to probe while the xport still uses an existing connection. So far, there isn't 213 // a feasible way to force the xport to disconnect the connection. If the case 214 // happens, the xport will be marked as unusable and DoT queries won't be sent to 215 // it anymore. Eventually, after IDLE_TIMEOUT, the xport will be destroyed, and 216 // a new xport will be created. 217 const auto result = PrivateDnsConfiguration::getInstance().requestValidation( 218 netId, PrivateDnsConfiguration::ServerIdentity{server}, mark); 219 LOG(WARNING) << "Requested validation for " << server.toIpString() << " with mark 0x" 220 << std::hex << mark << ", " 221 << (result.ok() ? "succeeded" : "failed: " + result.error().message()); 222 } 223 224 cleanup(now); 225 } 226 return code; 227 } 228 229 void DnsTlsDispatcher::forceCleanup(unsigned netId) { 230 std::lock_guard guard(sLock); 231 forceCleanupLocked(netId); 232 } 233 234 DnsTlsTransport::Result DnsTlsDispatcher::queryInternal(Transport& xport, 235 const netdutils::Slice query) { 236 LOG(DEBUG) << "Sending query of length " << query.size(); 237 238 // If dot_async_handshake is not set, the call might block in some cases; otherwise, 239 // the call should return very soon. 240 auto res = xport.transport.query(query); 241 LOG(DEBUG) << "Awaiting response"; 242 243 if (xport.timeout().count() == -1) { 244 // Infinite timeout. 245 return res.get(); 246 } 247 248 const auto status = res.wait_for(xport.timeout()); 249 if (status == std::future_status::timeout) { 250 // TODO(b/186613628): notify the Transport to remove this query. 251 LOG(WARNING) << "DoT query timed out after " << xport.timeout().count() << " ms"; 252 return DnsTlsTransport::Result{ 253 .code = DnsTlsTransport::Response::network_error, 254 .response = {}, 255 }; 256 } 257 258 return res.get(); 259 } 260 261 // This timeout effectively controls how long to keep SSL session tickets. 262 static constexpr std::chrono::minutes IDLE_TIMEOUT(5); 263 void DnsTlsDispatcher::cleanup(std::chrono::time_point<std::chrono::steady_clock> now) { 264 // To avoid scanning mStore after every query, return early if a cleanup has been 265 // performed recently. 266 if (now - mLastCleanup < IDLE_TIMEOUT) { 267 return; 268 } 269 for (auto it = mStore.begin(); it != mStore.end();) { 270 auto& s = it->second; 271 if (s->useCount == 0 && now - s->lastUsed > IDLE_TIMEOUT) { 272 it = mStore.erase(it); 273 } else { 274 ++it; 275 } 276 } 277 mLastCleanup = now; 278 } 279 280 // TODO: unify forceCleanupLocked() and cleanup(). 281 void DnsTlsDispatcher::forceCleanupLocked(unsigned netId) { 282 for (auto it = mStore.begin(); it != mStore.end();) { 283 auto& s = it->second; 284 if (s->useCount == 0 && s->mNetId == netId) { 285 it = mStore.erase(it); 286 } else { 287 ++it; 288 } 289 } 290 } 291 292 DnsTlsDispatcher::Transport* DnsTlsDispatcher::addTransport(const DnsTlsServer& server, 293 unsigned mark, unsigned netId) { 294 const Key key = std::make_pair(mark, server); 295 Transport* ret = getTransport(key); 296 if (ret != nullptr) return ret; 297 298 const Experiments* const instance = Experiments::getInstance(); 299 int triggerThr = 300 instance->getFlag("dot_revalidation_threshold", Transport::kDotRevalidationThreshold); 301 int unusableThr = instance->getFlag("dot_xport_unusable_threshold", 302 Transport::kDotXportUnusableThreshold); 303 int queryTimeout = instance->getFlag("dot_query_timeout_ms", Transport::kDotQueryTimeoutMs); 304 305 // Check and adjust the parameters if they are improperly set. 306 bool revalidationEnabled = false; 307 const bool isForOpportunisticMode = server.name.empty(); 308 if (triggerThr > 0 && unusableThr > 0 && isForOpportunisticMode) { 309 revalidationEnabled = true; 310 } else { 311 triggerThr = -1; 312 unusableThr = -1; 313 } 314 if (queryTimeout < 0) { 315 queryTimeout = -1; 316 } else if (queryTimeout < 1000) { 317 queryTimeout = 1000; 318 } 319 320 ret = new Transport(server, mark, netId, mFactory.get(), revalidationEnabled, triggerThr, 321 unusableThr, queryTimeout); 322 LOG(DEBUG) << "Transport is initialized with { " << triggerThr << ", " << unusableThr << ", " 323 << queryTimeout << "ms }" 324 << " for server { " << server.toIpString() << "/" << server.name << " }"; 325 326 mStore[key].reset(ret); 327 328 return ret; 329 } 330 331 DnsTlsDispatcher::Transport* DnsTlsDispatcher::getTransport(const Key& key) { 332 auto it = mStore.find(key); 333 return (it == mStore.end() ? nullptr : it->second.get()); 334 } 335 336 bool DnsTlsDispatcher::Transport::checkRevalidationNecessary(DnsTlsTransport::Response code) { 337 if (!revalidationEnabled) return false; 338 339 if (code == DnsTlsTransport::Response::network_error) { 340 continuousfailureCount++; 341 } else { 342 continuousfailureCount = 0; 343 } 344 345 // triggerThreshold must be greater than 0 because the value of revalidationEnabled is true. 346 if (usable() && continuousfailureCount == triggerThreshold) { 347 return true; 348 } 349 return false; 350 } 351 352 bool DnsTlsDispatcher::Transport::usable() const { 353 if (!revalidationEnabled) return true; 354 355 return continuousfailureCount < unusableThreshold; 356 } 357 358 } // end of namespace net 359 } // end of namespace android 360