1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 #define LOG_TAG "resolv"
19 
20 #include "DnsStats.h"
21 
22 #include <android-base/logging.h>
23 #include <android-base/stringprintf.h>
24 
25 namespace android::net {
26 
27 using base::StringPrintf;
28 using netdutils::DumpWriter;
29 using netdutils::IPAddress;
30 using netdutils::IPSockAddr;
31 using netdutils::ScopedIndent;
32 using std::chrono::duration_cast;
33 using std::chrono::microseconds;
34 using std::chrono::milliseconds;
35 using std::chrono::seconds;
36 
37 namespace {
38 
39 static constexpr IPAddress INVALID_IPADDRESS = IPAddress();
40 
41 std::string rcodeToName(int rcode) {
42     // clang-format off
43     switch (rcode) {
44         case NS_R_NO_ERROR: return "NOERROR";
45         case NS_R_FORMERR: return "FORMERR";
46         case NS_R_SERVFAIL: return "SERVFAIL";
47         case NS_R_NXDOMAIN: return "NXDOMAIN";
48         case NS_R_NOTIMPL: return "NOTIMP";
49         case NS_R_REFUSED: return "REFUSED";
50         case NS_R_YXDOMAIN: return "YXDOMAIN";
51         case NS_R_YXRRSET: return "YXRRSET";
52         case NS_R_NXRRSET: return "NXRRSET";
53         case NS_R_NOTAUTH: return "NOTAUTH";
54         case NS_R_NOTZONE: return "NOTZONE";
55         case NS_R_INTERNAL_ERROR: return "INTERNAL_ERROR";
56         case NS_R_TIMEOUT: return "TIMEOUT";
57         default: return StringPrintf("UNKNOWN(%d)", rcode);
58     }
59     // clang-format on
60 }
61 
62 bool ensureNoInvalidIp(const std::vector<IPSockAddr>& servers) {
63     for (const auto& server : servers) {
64         if (server.ip() == INVALID_IPADDRESS || server.port() == 0) {
65             LOG(WARNING) << "Invalid server: " << server;
66             return false;
67         }
68     }
69     return true;
70 }
71 
72 }  // namespace
73 
74 // The comparison ignores the last update time.
75 bool StatsData::operator==(const StatsData& o) const {
76     return std::tie(serverSockAddr, total, rcodeCounts, latencyUs) ==
77            std::tie(o.serverSockAddr, o.total, o.rcodeCounts, o.latencyUs);
78 }
79 
80 int StatsData::averageLatencyMs() const {
81     return (total == 0) ? 0 : duration_cast<milliseconds>(latencyUs).count() / total;
82 }
83 
84 std::string StatsData::toString() const {
85     if (total == 0) return StringPrintf("%s <no data>", serverSockAddr.ip().toString().c_str());
86 
87     const auto now = std::chrono::steady_clock::now();
88     const int lastUpdateSec = duration_cast<seconds>(now - lastUpdate).count();
89     std::string buf;
90     for (const auto& [rcode, counts] : rcodeCounts) {
91         if (counts != 0) {
92             buf += StringPrintf("%s:%d ", rcodeToName(rcode).c_str(), counts);
93         }
94     }
95     return StringPrintf("%s (%d, %dms, [%s], %ds)", serverSockAddr.ip().toString().c_str(), total,
96                         averageLatencyMs(), buf.c_str(), lastUpdateSec);
97 }
98 
99 StatsRecords::StatsRecords(const IPSockAddr& ipSockAddr, size_t size)
100     : mCapacity(size), mStatsData(ipSockAddr) {}
101 
102 void StatsRecords::push(const Record& record) {
103     updateStatsData(record, true);
104     mRecords.push_back(record);
105 
106     if (mRecords.size() > mCapacity) {
107         updateStatsData(mRecords.front(), false);
108         mRecords.pop_front();
109     }
110 
111     // Update the quality factors.
112     mSkippedCount = 0;
113 
114     // Because failures due to no permission can't prove that the quality of DNS server is bad,
115     // skip the penalty update. The average latency, however, has been updated. For short-latency
116     // servers, it will be fine. For long-latency servers, their average latency will be
117     // decreased but the latency-based algorithm will adjust their average latency back to the
118     // right range after few attempts when network is not restricted.
119     // The check is synced from isNetworkRestricted() in res_send.cpp.
120     if (record.linux_errno != EPERM) {
121         updatePenalty(record);
122     }
123 }
124 
125 void StatsRecords::updateStatsData(const Record& record, const bool add) {
126     const int rcode = record.rcode;
127     if (add) {
128         mStatsData.total += 1;
129         mStatsData.rcodeCounts[rcode] += 1;
130         mStatsData.latencyUs += record.latencyUs;
131     } else {
132         mStatsData.total -= 1;
133         mStatsData.rcodeCounts[rcode] -= 1;
134         mStatsData.latencyUs -= record.latencyUs;
135     }
136     mStatsData.lastUpdate = std::chrono::steady_clock::now();
137 }
138 
139 void StatsRecords::updatePenalty(const Record& record) {
140     switch (record.rcode) {
141         case NS_R_NO_ERROR:
142         case NS_R_NXDOMAIN:
143         case NS_R_NOTAUTH:
144             mPenalty = 0;
145             return;
146         default:
147             // NS_R_TIMEOUT and NS_R_INTERNAL_ERROR are in this case.
148             if (mPenalty == 0) {
149                 mPenalty = 100;
150             } else {
151                 // The evaluated quality drops more quickly when continuous failures happen.
152                 mPenalty = std::min(mPenalty * 2, kMaxQuality);
153             }
154             return;
155     }
156 }
157 
158 double StatsRecords::score() const {
159     const int avgRtt = mStatsData.averageLatencyMs();
160 
161     // Set the lower bound to -1 in case of "avgRtt + mPenalty < mSkippedCount"
162     //   1) when the server doesn't have any stats yet.
163     //   2) when the sorting has been disabled while it was enabled before.
164     int quality = std::clamp(avgRtt + mPenalty - mSkippedCount, -1, kMaxQuality);
165 
166     // Normalization.
167     return static_cast<double>(kMaxQuality - quality) * 100 / kMaxQuality;
168 }
169 
170 void StatsRecords::incrementSkippedCount() {
171     mSkippedCount = std::min(mSkippedCount + 1, kMaxQuality);
172 }
173 
174 bool DnsStats::setServers(const std::vector<netdutils::IPSockAddr>& servers, Protocol protocol) {
175     if (!ensureNoInvalidIp(servers)) return false;
176 
177     ServerStatsMap& statsMap = mStats[protocol];
178     for (const auto& server : servers) {
179         statsMap.try_emplace(server, StatsRecords(server, kLogSize));
180     }
181 
182     // Clean up the map to eliminate the nodes not belonging to the given list of servers.
183     const auto cleanup = [&](ServerStatsMap* statsMap) {
184         ServerStatsMap tmp;
185         for (const auto& server : servers) {
186             if (statsMap->find(server) != statsMap->end()) {
187                 tmp.insert(statsMap->extract(server));
188             }
189         }
190         statsMap->swap(tmp);
191     };
192 
193     cleanup(&statsMap);
194 
195     return true;
196 }
197 
198 bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& record) {
199     if (ipSockAddr.ip() == INVALID_IPADDRESS) return false;
200 
201     bool added = false;
202     for (auto& [serverSockAddr, statsRecords] : mStats[record.protocol()]) {
203         if (serverSockAddr == ipSockAddr) {
204             const StatsRecords::Record rec = {
205                     .rcode = record.rcode(),
206                     .linux_errno = record.linux_errno(),
207                     .latencyUs = microseconds(record.latency_micros()),
208             };
209             statsRecords.push(rec);
210             added = true;
211         } else {
212             statsRecords.incrementSkippedCount();
213         }
214     }
215 
216     return added;
217 }
218 
219 std::vector<IPSockAddr> DnsStats::getSortedServers(Protocol protocol) const {
220     // DoT unsupported. The handshake overhead is expensive, and the connection will hang for a
221     // while. Need to figure out if it is worth doing for DoT servers.
222     if (protocol == PROTO_DOT) return {};
223 
224     auto it = mStats.find(protocol);
225     if (it == mStats.end()) return {};
226 
227     // Sorting on insertion in decreasing order.
228     std::multimap<double, IPSockAddr, std::greater<double>> sortedData;
229     for (const auto& [ip, statsRecords] : it->second) {
230         sortedData.insert({statsRecords.score(), ip});
231     }
232 
233     std::vector<IPSockAddr> ret;
234     ret.reserve(sortedData.size());
235     for (auto& [_, v] : sortedData) {
236         ret.push_back(v);  // IPSockAddr is trivially-copyable.
237     }
238 
239     return ret;
240 }
241 
242 std::optional<microseconds> DnsStats::getAverageLatencyUs(Protocol protocol) const {
243     const auto stats = getStats(protocol);
244 
245     int count = 0;
246     microseconds sum;
247     for (const auto& v : stats) {
248         count += v.total;
249         sum += v.latencyUs;
250     }
251 
252     if (count == 0) return std::nullopt;
253     return sum / count;
254 }
255 
256 std::vector<StatsData> DnsStats::getStats(Protocol protocol) const {
257     std::vector<StatsData> ret;
258 
259     if (mStats.find(protocol) != mStats.end()) {
260         for (const auto& [_, statsRecords] : mStats.at(protocol)) {
261             ret.push_back(statsRecords.getStatsData());
262         }
263     }
264     return ret;
265 }
266 
267 void DnsStats::dump(DumpWriter& dw) {
268     const auto dumpStatsMap = [&](ServerStatsMap& statsMap) {
269         ScopedIndent indentLog(dw);
270         if (statsMap.size() == 0) {
271             dw.println("<no server>");
272             return;
273         }
274         for (const auto& [_, statsRecords] : statsMap) {
275             const StatsData& data = statsRecords.getStatsData();
276             std::string str = data.toString();
277             str += StringPrintf(" score{%.1f}", statsRecords.score());
278             dw.println("%s", str.c_str());
279         }
280     };
281 
282     dw.println("Server statistics: (total, RTT avg, {rcode:counts}, last update)");
283     ScopedIndent indentStats(dw);
284 
285     dw.println("over UDP");
286     dumpStatsMap(mStats[PROTO_UDP]);
287 
288     dw.println("over TLS");
289     dumpStatsMap(mStats[PROTO_DOT]);
290 
291     dw.println("over TCP");
292     dumpStatsMap(mStats[PROTO_TCP]);
293 }
294 
295 }  // namespace android::net
296