1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * 16 */ 17 18 #define LOG_TAG "resolv" 19 20 #include "DnsStats.h" 21 22 #include <android-base/logging.h> 23 #include <android-base/stringprintf.h> 24 25 namespace android::net { 26 27 using base::StringPrintf; 28 using netdutils::DumpWriter; 29 using netdutils::IPAddress; 30 using netdutils::IPSockAddr; 31 using netdutils::ScopedIndent; 32 using std::chrono::duration_cast; 33 using std::chrono::microseconds; 34 using std::chrono::milliseconds; 35 using std::chrono::seconds; 36 37 namespace { 38 39 static constexpr IPAddress INVALID_IPADDRESS = IPAddress(); 40 41 std::string rcodeToName(int rcode) { 42 // clang-format off 43 switch (rcode) { 44 case NS_R_NO_ERROR: return "NOERROR"; 45 case NS_R_FORMERR: return "FORMERR"; 46 case NS_R_SERVFAIL: return "SERVFAIL"; 47 case NS_R_NXDOMAIN: return "NXDOMAIN"; 48 case NS_R_NOTIMPL: return "NOTIMP"; 49 case NS_R_REFUSED: return "REFUSED"; 50 case NS_R_YXDOMAIN: return "YXDOMAIN"; 51 case NS_R_YXRRSET: return "YXRRSET"; 52 case NS_R_NXRRSET: return "NXRRSET"; 53 case NS_R_NOTAUTH: return "NOTAUTH"; 54 case NS_R_NOTZONE: return "NOTZONE"; 55 case NS_R_INTERNAL_ERROR: return "INTERNAL_ERROR"; 56 case NS_R_TIMEOUT: return "TIMEOUT"; 57 default: return StringPrintf("UNKNOWN(%d)", rcode); 58 } 59 // clang-format on 60 } 61 62 bool ensureNoInvalidIp(const std::vector<IPSockAddr>& servers) { 63 for (const auto& server : servers) { 64 if (server.ip() == INVALID_IPADDRESS || server.port() == 0) { 65 LOG(WARNING) << "Invalid server: " << server; 66 return false; 67 } 68 } 69 return true; 70 } 71 72 } // namespace 73 74 // The comparison ignores the last update time. 75 bool StatsData::operator==(const StatsData& o) const { 76 return std::tie(serverSockAddr, total, rcodeCounts, latencyUs) == 77 std::tie(o.serverSockAddr, o.total, o.rcodeCounts, o.latencyUs); 78 } 79 80 int StatsData::averageLatencyMs() const { 81 return (total == 0) ? 0 : duration_cast<milliseconds>(latencyUs).count() / total; 82 } 83 84 std::string StatsData::toString() const { 85 if (total == 0) return StringPrintf("%s <no data>", serverSockAddr.ip().toString().c_str()); 86 87 const auto now = std::chrono::steady_clock::now(); 88 const int lastUpdateSec = duration_cast<seconds>(now - lastUpdate).count(); 89 std::string buf; 90 for (const auto& [rcode, counts] : rcodeCounts) { 91 if (counts != 0) { 92 buf += StringPrintf("%s:%d ", rcodeToName(rcode).c_str(), counts); 93 } 94 } 95 return StringPrintf("%s (%d, %dms, [%s], %ds)", serverSockAddr.ip().toString().c_str(), total, 96 averageLatencyMs(), buf.c_str(), lastUpdateSec); 97 } 98 99 StatsRecords::StatsRecords(const IPSockAddr& ipSockAddr, size_t size) 100 : mCapacity(size), mStatsData(ipSockAddr) {} 101 102 void StatsRecords::push(const Record& record) { 103 updateStatsData(record, true); 104 mRecords.push_back(record); 105 106 if (mRecords.size() > mCapacity) { 107 updateStatsData(mRecords.front(), false); 108 mRecords.pop_front(); 109 } 110 111 // Update the quality factors. 112 mSkippedCount = 0; 113 114 // Because failures due to no permission can't prove that the quality of DNS server is bad, 115 // skip the penalty update. The average latency, however, has been updated. For short-latency 116 // servers, it will be fine. For long-latency servers, their average latency will be 117 // decreased but the latency-based algorithm will adjust their average latency back to the 118 // right range after few attempts when network is not restricted. 119 // The check is synced from isNetworkRestricted() in res_send.cpp. 120 if (record.linux_errno != EPERM) { 121 updatePenalty(record); 122 } 123 } 124 125 void StatsRecords::updateStatsData(const Record& record, const bool add) { 126 const int rcode = record.rcode; 127 if (add) { 128 mStatsData.total += 1; 129 mStatsData.rcodeCounts[rcode] += 1; 130 mStatsData.latencyUs += record.latencyUs; 131 } else { 132 mStatsData.total -= 1; 133 mStatsData.rcodeCounts[rcode] -= 1; 134 mStatsData.latencyUs -= record.latencyUs; 135 } 136 mStatsData.lastUpdate = std::chrono::steady_clock::now(); 137 } 138 139 void StatsRecords::updatePenalty(const Record& record) { 140 switch (record.rcode) { 141 case NS_R_NO_ERROR: 142 case NS_R_NXDOMAIN: 143 case NS_R_NOTAUTH: 144 mPenalty = 0; 145 return; 146 default: 147 // NS_R_TIMEOUT and NS_R_INTERNAL_ERROR are in this case. 148 if (mPenalty == 0) { 149 mPenalty = 100; 150 } else { 151 // The evaluated quality drops more quickly when continuous failures happen. 152 mPenalty = std::min(mPenalty * 2, kMaxQuality); 153 } 154 return; 155 } 156 } 157 158 double StatsRecords::score() const { 159 const int avgRtt = mStatsData.averageLatencyMs(); 160 161 // Set the lower bound to -1 in case of "avgRtt + mPenalty < mSkippedCount" 162 // 1) when the server doesn't have any stats yet. 163 // 2) when the sorting has been disabled while it was enabled before. 164 int quality = std::clamp(avgRtt + mPenalty - mSkippedCount, -1, kMaxQuality); 165 166 // Normalization. 167 return static_cast<double>(kMaxQuality - quality) * 100 / kMaxQuality; 168 } 169 170 void StatsRecords::incrementSkippedCount() { 171 mSkippedCount = std::min(mSkippedCount + 1, kMaxQuality); 172 } 173 174 bool DnsStats::setServers(const std::vector<netdutils::IPSockAddr>& servers, Protocol protocol) { 175 if (!ensureNoInvalidIp(servers)) return false; 176 177 ServerStatsMap& statsMap = mStats[protocol]; 178 for (const auto& server : servers) { 179 statsMap.try_emplace(server, StatsRecords(server, kLogSize)); 180 } 181 182 // Clean up the map to eliminate the nodes not belonging to the given list of servers. 183 const auto cleanup = [&](ServerStatsMap* statsMap) { 184 ServerStatsMap tmp; 185 for (const auto& server : servers) { 186 if (statsMap->find(server) != statsMap->end()) { 187 tmp.insert(statsMap->extract(server)); 188 } 189 } 190 statsMap->swap(tmp); 191 }; 192 193 cleanup(&statsMap); 194 195 return true; 196 } 197 198 bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& record) { 199 if (ipSockAddr.ip() == INVALID_IPADDRESS) return false; 200 201 bool added = false; 202 for (auto& [serverSockAddr, statsRecords] : mStats[record.protocol()]) { 203 if (serverSockAddr == ipSockAddr) { 204 const StatsRecords::Record rec = { 205 .rcode = record.rcode(), 206 .linux_errno = record.linux_errno(), 207 .latencyUs = microseconds(record.latency_micros()), 208 }; 209 statsRecords.push(rec); 210 added = true; 211 } else { 212 statsRecords.incrementSkippedCount(); 213 } 214 } 215 216 return added; 217 } 218 219 std::vector<IPSockAddr> DnsStats::getSortedServers(Protocol protocol) const { 220 // DoT unsupported. The handshake overhead is expensive, and the connection will hang for a 221 // while. Need to figure out if it is worth doing for DoT servers. 222 if (protocol == PROTO_DOT) return {}; 223 224 auto it = mStats.find(protocol); 225 if (it == mStats.end()) return {}; 226 227 // Sorting on insertion in decreasing order. 228 std::multimap<double, IPSockAddr, std::greater<double>> sortedData; 229 for (const auto& [ip, statsRecords] : it->second) { 230 sortedData.insert({statsRecords.score(), ip}); 231 } 232 233 std::vector<IPSockAddr> ret; 234 ret.reserve(sortedData.size()); 235 for (auto& [_, v] : sortedData) { 236 ret.push_back(v); // IPSockAddr is trivially-copyable. 237 } 238 239 return ret; 240 } 241 242 std::optional<microseconds> DnsStats::getAverageLatencyUs(Protocol protocol) const { 243 const auto stats = getStats(protocol); 244 245 int count = 0; 246 microseconds sum; 247 for (const auto& v : stats) { 248 count += v.total; 249 sum += v.latencyUs; 250 } 251 252 if (count == 0) return std::nullopt; 253 return sum / count; 254 } 255 256 std::vector<StatsData> DnsStats::getStats(Protocol protocol) const { 257 std::vector<StatsData> ret; 258 259 if (mStats.find(protocol) != mStats.end()) { 260 for (const auto& [_, statsRecords] : mStats.at(protocol)) { 261 ret.push_back(statsRecords.getStatsData()); 262 } 263 } 264 return ret; 265 } 266 267 void DnsStats::dump(DumpWriter& dw) { 268 const auto dumpStatsMap = [&](ServerStatsMap& statsMap) { 269 ScopedIndent indentLog(dw); 270 if (statsMap.size() == 0) { 271 dw.println("<no server>"); 272 return; 273 } 274 for (const auto& [_, statsRecords] : statsMap) { 275 const StatsData& data = statsRecords.getStatsData(); 276 std::string str = data.toString(); 277 str += StringPrintf(" score{%.1f}", statsRecords.score()); 278 dw.println("%s", str.c_str()); 279 } 280 }; 281 282 dw.println("Server statistics: (total, RTT avg, {rcode:counts}, last update)"); 283 ScopedIndent indentStats(dw); 284 285 dw.println("over UDP"); 286 dumpStatsMap(mStats[PROTO_UDP]); 287 288 dw.println("over TLS"); 289 dumpStatsMap(mStats[PROTO_DOT]); 290 291 dw.println("over TCP"); 292 dumpStatsMap(mStats[PROTO_TCP]); 293 } 294 295 } // namespace android::net 296