1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
 * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 #include <arpa/inet.h>
19 #include <errno.h>
20 #include <netdb.h>
21 #include <stdarg.h>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <unistd.h>
25 
26 #include <cutils/sockets.h>
27 #include <android-base/stringprintf.h>
28 #include <private/android_filesystem_config.h>
29 
30 #include <algorithm>
31 #include <chrono>
32 #include <iterator>
33 #include <numeric>
34 #include <thread>
35 
// Log tag for this test binary; defined before <utils/Log.h> is included below,
// presumably so that logging macros such as ALOGI() pick it up.
#define LOG_TAG "netd_test"
// TODO: make this dynamic and stop depending on implementation details.
// Name of the OEM network the fixture creates in SetupOemNetwork(), and the netId the
// tests assume netd assigns to it. NOTE(review): the 30 <-> "oem29" pairing is an
// assumption about netd internals (see the TODO above), not something this file checks.
#define TEST_OEM_NETWORK "oem29"
#define TEST_NETID 30
40 
41 #include "NetdClient.h"
42 
43 #include <gtest/gtest.h>
44 
45 #include <utils/Log.h>
46 
47 #include <testUtil.h>
48 
49 #include "dns_responder.h"
50 #include "resolv_params.h"
51 #include "ResolverStats.h"
52 
53 #include "android/net/INetd.h"
54 #include "binder/IServiceManager.h"
55 
56 using android::base::StringPrintf;
57 using android::base::StringAppendF;
58 using android::net::ResolverStats;
59 
// Returns true when `a` and `b` contain the same elements with the same
// multiplicities, in any order. Stand-in for gmock's UnorderedElementsAreArray,
// which currently cannot be used here.
// TODO: Use UnorderedElementsAreArray, which depends on being able to compile libgmock_host,
// if that is not possible, improve this hacky algorithm, which is O(n**2)
template <class A, class B>
bool UnorderedCompareArray(const A& a, const B& b) {
    if (a.size() != b.size()) {
        return false;
    }
    return std::all_of(std::begin(a), std::end(a), [&a, &b](const auto& elem) {
        // Number of entries in `container` that compare equal to `elem`.
        const auto occurrences = [&elem](const auto& container) {
            return std::count_if(std::begin(container), std::end(container),
                                 [&elem](const auto& other) { return elem == other; });
        };
        return occurrences(a) == occurrences(b);
    });
}
81 
// Response code netd reports for a successful command; it is the only code this
// test relies on. The other codes are listed in
// frameworks/base/services/java/com/android/server/NetworkManagementService.java.
static constexpr int ResponseCodeOK{200};
86 
87 // Returns ResponseCode.
netdCommand(const char * sockname,const char * command)88 int netdCommand(const char* sockname, const char* command) {
89     int sock = socket_local_client(sockname,
90                                    ANDROID_SOCKET_NAMESPACE_RESERVED,
91                                    SOCK_STREAM);
92     if (sock < 0) {
93         perror("Error connecting");
94         return -1;
95     }
96 
97     // FrameworkListener expects the whole command in one read.
98     char buffer[256];
99     int nwritten = snprintf(buffer, sizeof(buffer), "0 %s", command);
100     if (write(sock, buffer, nwritten + 1) < 0) {
101         perror("Error sending netd command");
102         close(sock);
103         return -1;
104     }
105 
106     int nread = read(sock, buffer, sizeof(buffer));
107     if (nread < 0) {
108         perror("Error reading response");
109         close(sock);
110         return -1;
111     }
112     close(sock);
113     return atoi(buffer);
114 }
115 
expectNetdResult(int expected,const char * sockname,const char * format,...)116 bool expectNetdResult(int expected, const char* sockname, const char* format, ...) {
117     char command[256];
118     va_list args;
119     va_start(args, format);
120     vsnprintf(command, sizeof(command), format, args);
121     va_end(args);
122     int result = netdCommand(sockname, command);
123     EXPECT_EQ(expected, result) << command;
124     return (200 <= expected && expected < 300);
125 }
126 
// Owns the addrinfo list produced by getaddrinfo() and frees it with
// freeaddrinfo() on destruction or re-initialization. The type is movable but
// not copyable, because a copy would free the same list twice.
class AddrInfo {
  public:
    AddrInfo() : ai_(nullptr), error_(0) {}

    AddrInfo(const char* node, const char* service, const addrinfo& hints)
        : ai_(nullptr), error_(0) {
        init(node, service, hints);
    }

    AddrInfo(const char* node, const char* service) : ai_(nullptr), error_(0) {
        init(node, service);
    }

    ~AddrInfo() { clear(); }

    AddrInfo(const AddrInfo&) = delete;
    AddrInfo& operator=(const AddrInfo&) = delete;

    AddrInfo(AddrInfo&& other) noexcept : ai_(other.ai_), error_(other.error_) {
        other.ai_ = nullptr;
        other.error_ = 0;
    }

    AddrInfo& operator=(AddrInfo&& other) noexcept {
        // Detach `other` before clearing ourselves, so self-move-assignment keeps the list.
        addrinfo* const incoming = other.ai_;
        const int incoming_error = other.error_;
        other.ai_ = nullptr;
        other.error_ = 0;
        clear();
        ai_ = incoming;
        error_ = incoming_error;
        return *this;
    }

    // Replaces any held result with getaddrinfo(node, service, &hints).
    // Returns getaddrinfo()'s error code (0 on success), also available via error().
    int init(const char* node, const char* service, const addrinfo& hints) {
        clear();
        error_ = getaddrinfo(node, service, &hints, &ai_);
        return error_;
    }

    // Same as above, but without hints.
    int init(const char* node, const char* service) {
        clear();
        error_ = getaddrinfo(node, service, nullptr, &ai_);
        return error_;
    }

    // Frees the held list, if any, and resets the stored error code.
    void clear() {
        if (ai_ != nullptr) {
            freeaddrinfo(ai_);
            ai_ = nullptr;
        }
        error_ = 0;
    }

    // Accessors; operator* must only be used when get() is non-null.
    const addrinfo& operator*() const { return *ai_; }
    const addrinfo* get() const { return ai_; }
    const addrinfo* operator&() const { return ai_; }
    int error() const { return error_; }

  private:
    addrinfo* ai_;
    int error_;
};
170 
171 class ResolverTest : public ::testing::Test {
172 protected:
173     struct Mapping {
174         std::string host;
175         std::string entry;
176         std::string ip4;
177         std::string ip6;
178     };
179 
SetUp()180     virtual void SetUp() {
181         // Ensure resolutions go via proxy.
182         setenv("ANDROID_DNS_MODE", "", 1);
183         uid = getuid();
184         pid = getpid();
185         SetupOemNetwork();
186 
187         // binder setup
188         auto binder = android::defaultServiceManager()->getService(android::String16("netd"));
189         ASSERT_TRUE(binder != nullptr);
190         mNetdSrv = android::interface_cast<android::net::INetd>(binder);
191     }
192 
TearDown()193     virtual void TearDown() {
194         TearDownOemNetwork();
195         netdCommand("netd", "network destroy " TEST_OEM_NETWORK);
196     }
197 
SetupOemNetwork()198     void SetupOemNetwork() {
199         netdCommand("netd", "network destroy " TEST_OEM_NETWORK);
200         if (expectNetdResult(ResponseCodeOK, "netd",
201                              "network create %s", TEST_OEM_NETWORK)) {
202             oemNetId = TEST_NETID;
203         }
204         setNetworkForProcess(oemNetId);
205         ASSERT_EQ((unsigned) oemNetId, getNetworkForProcess());
206     }
207 
SetupMappings(unsigned num_hosts,const std::vector<std::string> & domains,std::vector<Mapping> * mappings) const208     void SetupMappings(unsigned num_hosts, const std::vector<std::string>& domains,
209             std::vector<Mapping>* mappings) const {
210         mappings->resize(num_hosts * domains.size());
211         auto mappings_it = mappings->begin();
212         for (unsigned i = 0 ; i < num_hosts ; ++i) {
213             for (const auto& domain : domains) {
214                 ASSERT_TRUE(mappings_it != mappings->end());
215                 mappings_it->host = StringPrintf("host%u", i);
216                 mappings_it->entry = StringPrintf("%s.%s.", mappings_it->host.c_str(),
217                         domain.c_str());
218                 mappings_it->ip4 = StringPrintf("192.0.2.%u", i%253 + 1);
219                 mappings_it->ip6 = StringPrintf("2001:db8::%x", i%65534 + 1);
220                 ++mappings_it;
221             }
222         }
223     }
224 
SetupDNSServers(unsigned num_servers,const std::vector<Mapping> & mappings,std::vector<std::unique_ptr<test::DNSResponder>> * dns,std::vector<std::string> * servers) const225     void SetupDNSServers(unsigned num_servers, const std::vector<Mapping>& mappings,
226             std::vector<std::unique_ptr<test::DNSResponder>>* dns,
227             std::vector<std::string>* servers) const {
228         ASSERT_TRUE(num_servers != 0 && num_servers < 100);
229         const char* listen_srv = "53";
230         dns->resize(num_servers);
231         servers->resize(num_servers);
232         for (unsigned i = 0 ; i < num_servers ; ++i) {
233             auto& server = (*servers)[i];
234             auto& d = (*dns)[i];
235             server = StringPrintf("127.0.0.%u", i + 100);
236             d = std::make_unique<test::DNSResponder>(server, listen_srv, 250,
237                     ns_rcode::ns_r_servfail, 1.0);
238             ASSERT_TRUE(d.get() != nullptr);
239             for (const auto& mapping : mappings) {
240                 d->addMapping(mapping.entry.c_str(), ns_type::ns_t_a, mapping.ip4.c_str());
241                 d->addMapping(mapping.entry.c_str(), ns_type::ns_t_aaaa, mapping.ip6.c_str());
242             }
243             ASSERT_TRUE(d->startServer());
244         }
245     }
246 
ShutdownDNSServers(std::vector<std::unique_ptr<test::DNSResponder>> * dns) const247     void ShutdownDNSServers(std::vector<std::unique_ptr<test::DNSResponder>>* dns) const {
248         for (const auto& d : *dns) {
249             ASSERT_TRUE(d.get() != nullptr);
250             d->stopServer();
251         }
252         dns->clear();
253     }
254 
TearDownOemNetwork()255     void TearDownOemNetwork() {
256         if (oemNetId != -1) {
257             expectNetdResult(ResponseCodeOK, "netd",
258                              "network destroy %s", TEST_OEM_NETWORK);
259         }
260     }
261 
SetResolversForNetwork(const std::vector<std::string> & servers,const std::vector<std::string> & domains,const std::vector<int> & params)262     bool SetResolversForNetwork(const std::vector<std::string>& servers,
263             const std::vector<std::string>& domains, const std::vector<int>& params) {
264         auto rv = mNetdSrv->setResolverConfiguration(TEST_NETID, servers, domains, params);
265         return rv.isOk();
266     }
267 
SetResolversForNetwork(const std::vector<std::string> & searchDomains,const std::vector<std::string> & servers,const std::string & params)268     bool SetResolversForNetwork(const std::vector<std::string>& searchDomains,
269             const std::vector<std::string>& servers, const std::string& params) {
270         std::string cmd = StringPrintf("resolver setnetdns %d \"", oemNetId);
271         if (!searchDomains.empty()) {
272             cmd += searchDomains[0].c_str();
273             for (size_t i = 1 ; i < searchDomains.size() ; ++i) {
274                 cmd += " ";
275                 cmd += searchDomains[i];
276             }
277         }
278         cmd += "\"";
279 
280         for (const auto& str : servers) {
281             cmd += " ";
282             cmd += str;
283         }
284 
285         if (!params.empty()) {
286             cmd += " --params \"";
287             cmd += params;
288             cmd += "\"";
289         }
290 
291         int rv = netdCommand("netd", cmd.c_str());
292         if (rv != ResponseCodeOK) {
293             return false;
294         }
295         return true;
296     }
297 
GetResolverInfo(std::vector<std::string> * servers,std::vector<std::string> * domains,__res_params * params,std::vector<ResolverStats> * stats)298     bool GetResolverInfo(std::vector<std::string>* servers, std::vector<std::string>* domains,
299             __res_params* params, std::vector<ResolverStats>* stats) {
300         using android::net::INetd;
301         std::vector<int32_t> params32;
302         std::vector<int32_t> stats32;
303         auto rv = mNetdSrv->getResolverInfo(TEST_NETID, servers, domains, &params32, &stats32);
304         if (!rv.isOk() || params32.size() != INetd::RESOLVER_PARAMS_COUNT) {
305             return false;
306         }
307         *params = __res_params {
308             .sample_validity = static_cast<uint16_t>(
309                     params32[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY]),
310             .success_threshold = static_cast<uint8_t>(
311                     params32[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD]),
312             .min_samples = static_cast<uint8_t>(
313                     params32[INetd::RESOLVER_PARAMS_MIN_SAMPLES]),
314             .max_samples = static_cast<uint8_t>(
315                     params32[INetd::RESOLVER_PARAMS_MAX_SAMPLES])
316         };
317         return ResolverStats::decodeAll(stats32, stats);
318     }
319 
ToString(const hostent * he) const320     std::string ToString(const hostent* he) const {
321         if (he == nullptr) return "<null>";
322         char buffer[INET6_ADDRSTRLEN];
323         if (!inet_ntop(he->h_addrtype, he->h_addr_list[0], buffer, sizeof(buffer))) {
324             return "<invalid>";
325         }
326         return buffer;
327     }
328 
ToString(const addrinfo * ai) const329     std::string ToString(const addrinfo* ai) const {
330         if (!ai)
331             return "<null>";
332         for (const auto* aip = ai ; aip != nullptr ; aip = aip->ai_next) {
333             char host[NI_MAXHOST];
334             int rv = getnameinfo(aip->ai_addr, aip->ai_addrlen, host, sizeof(host), nullptr, 0,
335                     NI_NUMERICHOST);
336             if (rv != 0)
337                 return gai_strerror(rv);
338             return host;
339         }
340         return "<invalid>";
341     }
342 
GetNumQueries(const test::DNSResponder & dns,const char * name) const343     size_t GetNumQueries(const test::DNSResponder& dns, const char* name) const {
344         auto queries = dns.queries();
345         size_t found = 0;
346         for (const auto& p : queries) {
347             if (p.first == name) {
348                 ++found;
349             }
350         }
351         return found;
352     }
353 
GetNumQueriesForType(const test::DNSResponder & dns,ns_type type,const char * name) const354     size_t GetNumQueriesForType(const test::DNSResponder& dns, ns_type type,
355             const char* name) const {
356         auto queries = dns.queries();
357         size_t found = 0;
358         for (const auto& p : queries) {
359             if (p.second == type && p.first == name) {
360                 ++found;
361             }
362         }
363         return found;
364     }
365 
RunGetAddrInfoStressTest_Binder(unsigned num_hosts,unsigned num_threads,unsigned num_queries)366     void RunGetAddrInfoStressTest_Binder(unsigned num_hosts, unsigned num_threads,
367             unsigned num_queries) {
368         std::vector<std::string> domains = { "example.com" };
369         std::vector<std::unique_ptr<test::DNSResponder>> dns;
370         std::vector<std::string> servers;
371         std::vector<Mapping> mappings;
372         ASSERT_NO_FATAL_FAILURE(SetupMappings(num_hosts, domains, &mappings));
373         ASSERT_NO_FATAL_FAILURE(SetupDNSServers(MAXNS, mappings, &dns, &servers));
374 
375         ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));
376 
377         auto t0 = std::chrono::steady_clock::now();
378         std::vector<std::thread> threads(num_threads);
379         for (std::thread& thread : threads) {
380            thread = std::thread([this, &servers, &dns, &mappings, num_queries]() {
381                 for (unsigned i = 0 ; i < num_queries ; ++i) {
382                     uint32_t ofs = arc4random_uniform(mappings.size());
383                     ASSERT_TRUE(ofs < mappings.size());
384                     auto& mapping = mappings[i];
385                     addrinfo* result = nullptr;
386                     int rv = getaddrinfo(mapping.host.c_str(), nullptr, nullptr, &result);
387                     EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
388                     if (rv == 0) {
389                         std::string result_str = ToString(result);
390                         EXPECT_TRUE(result_str == mapping.ip4 || result_str == mapping.ip6)
391                             << "result='" << result_str << "', ip4='" << mapping.ip4
392                             << "', ip6='" << mapping.ip6;
393                     }
394                     if (result) {
395                         freeaddrinfo(result);
396                         result = nullptr;
397                     }
398                 }
399             });
400         }
401 
402         for (std::thread& thread : threads) {
403             thread.join();
404         }
405         auto t1 = std::chrono::steady_clock::now();
406         ALOGI("%u hosts, %u threads, %u queries, %Es", num_hosts, num_threads, num_queries,
407                 std::chrono::duration<double>(t1 - t0).count());
408         ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
409     }
410 
411     int pid;
412     int uid;
413     int oemNetId = -1;
414     android::sp<android::net::INetd> mNetdSrv = nullptr;
415     const std::vector<std::string> mDefaultSearchDomains = { "example.com" };
416     // <sample validity in s> <success threshold in percent> <min samples> <max samples>
417     const std::string mDefaultParams = "300 25 8 8";
418     const std::vector<int> mDefaultParams_Binder = { 300, 25, 8, 8 };
419 };
420 
// Looks up the bare name "hello" with the default search domain and expects exactly
// one A query for "hello.example.com." and a single IPv4 answer.
TEST_F(ResolverTest, GetHostByName) {
    const char* const server_addr = "127.0.0.3";
    const char* const server_port = "53";
    const char* const fqdn = "hello.example.com.";

    test::DNSResponder responder(server_addr, server_port, 250, ns_rcode::ns_r_servfail, 1.0);
    responder.addMapping(fqdn, ns_type::ns_t_a, "1.2.3.3");
    ASSERT_TRUE(responder.startServer());

    const std::vector<std::string> nameservers = { server_addr };
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, nameservers, mDefaultParams));

    responder.clearQueries();
    const hostent* he = gethostbyname("hello");
    EXPECT_EQ(1U, GetNumQueriesForType(responder, ns_type::ns_t_a, fqdn));

    ASSERT_FALSE(he == nullptr);
    ASSERT_EQ(4, he->h_length);
    ASSERT_FALSE(he->h_addr_list[0] == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);

    responder.stopServer();
}
441 
// Checks that the INetd::RESOLVER_PARAMS_* offsets are exactly 0..RESOLVER_PARAMS_COUNT-1,
// which GetResolverInfo() relies on when it indexes the params vector with them.
TEST_F(ResolverTest, TestBinderSerialization) {
    using android::net::INetd;
    std::vector<int> offsets = {
        INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY,
        INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD,
        INetd::RESOLVER_PARAMS_MIN_SAMPLES,
        INetd::RESOLVER_PARAMS_MAX_SAMPLES,
    };
    const int count = static_cast<int>(offsets.size());
    EXPECT_EQ(count, INetd::RESOLVER_PARAMS_COUNT);

    std::sort(offsets.begin(), offsets.end());
    int expected_offset = 0;
    for (const int offset : offsets) {
        EXPECT_EQ(offset, expected_offset);
        ++expected_offset;
    }
}
457 
// Configures four test servers through the INetd binder API and resolves one host
// with gethostbyname(). Then checks that getResolverInfo() reports the same servers,
// search domains and parameters, plus one stats entry per server.
TEST_F(ResolverTest, GetHostByName_Binder) {
    using android::net::INetd;

    std::vector<std::string> domains = { "example.com" };
    std::vector<std::unique_ptr<test::DNSResponder>> dns;
    std::vector<std::string> servers;
    std::vector<Mapping> mappings;
    ASSERT_NO_FATAL_FAILURE(SetupMappings(1, domains, &mappings));
    ASSERT_NO_FATAL_FAILURE(SetupDNSServers(4, mappings, &dns, &servers));
    ASSERT_EQ(1U, mappings.size());
    const Mapping& mapping = mappings[0];

    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));

    const hostent* result = gethostbyname(mapping.host.c_str());
    // std::accumulate's running total has the type of its init argument. Seed it with
    // size_t{0} so the per-server counts are not narrowed to int.
    const size_t total_queries = std::accumulate(dns.begin(), dns.end(), size_t{0},
            [this, &mapping](size_t total, const auto& d) {
                return total + GetNumQueriesForType(*d, ns_type::ns_t_a, mapping.entry.c_str());
            });

    EXPECT_LE(1U, total_queries);
    ASSERT_FALSE(result == nullptr);
    ASSERT_EQ(4, result->h_length);
    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
    EXPECT_EQ(mapping.ip4, ToString(result));
    EXPECT_TRUE(result->h_addr_list[1] == nullptr);

    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    __res_params res_params;
    std::vector<ResolverStats> res_stats;
    ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_params, &res_stats));
    EXPECT_EQ(servers.size(), res_servers.size());
    EXPECT_EQ(domains.size(), res_domains.size());
    ASSERT_EQ(INetd::RESOLVER_PARAMS_COUNT, mDefaultParams_Binder.size());
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY],
            res_params.sample_validity);
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD],
            res_params.success_threshold);
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_MIN_SAMPLES], res_params.min_samples);
    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_MAX_SAMPLES], res_params.max_samples);
    EXPECT_EQ(servers.size(), res_stats.size());

    EXPECT_TRUE(UnorderedCompareArray(res_servers, servers));
    EXPECT_TRUE(UnorderedCompareArray(res_domains, domains));

    ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
}
506 
TEST_F(ResolverTest,GetAddrInfo)507 TEST_F(ResolverTest, GetAddrInfo) {
508     addrinfo* result = nullptr;
509 
510     const char* listen_addr = "127.0.0.4";
511     const char* listen_addr2 = "127.0.0.5";
512     const char* listen_srv = "53";
513     const char* host_name = "howdy.example.com.";
514     test::DNSResponder dns(listen_addr, listen_srv, 250,
515                            ns_rcode::ns_r_servfail, 1.0);
516     dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
517     dns.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
518     ASSERT_TRUE(dns.startServer());
519 
520     test::DNSResponder dns2(listen_addr2, listen_srv, 250,
521                             ns_rcode::ns_r_servfail, 1.0);
522     dns2.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
523     dns2.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
524     ASSERT_TRUE(dns2.startServer());
525 
526     for (size_t i = 0 ; i < 1000 ; ++i) {
527         std::vector<std::string> servers = { listen_addr };
528         ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
529         dns.clearQueries();
530         dns2.clearQueries();
531 
532         EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
533         size_t found = GetNumQueries(dns, host_name);
534         EXPECT_LE(1U, found);
535         // Could be A or AAAA
536         std::string result_str = ToString(result);
537         EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
538             << ", result_str='" << result_str << "'";
539         // TODO: Use ScopedAddrinfo or similar once it is available in a common header file.
540         if (result) {
541             freeaddrinfo(result);
542             result = nullptr;
543         }
544 
545         // Verify that the name is cached.
546         size_t old_found = found;
547         EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
548         found = GetNumQueries(dns, host_name);
549         EXPECT_LE(1U, found);
550         EXPECT_EQ(old_found, found);
551         result_str = ToString(result);
552         EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
553             << result_str;
554         if (result) {
555             freeaddrinfo(result);
556             result = nullptr;
557         }
558 
559         // Change the DNS resolver, ensure that queries are no longer cached.
560         servers = { listen_addr2 };
561         ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
562         dns.clearQueries();
563         dns2.clearQueries();
564 
565         EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
566         found = GetNumQueries(dns, host_name);
567         size_t found2 = GetNumQueries(dns2, host_name);
568         EXPECT_EQ(0U, found);
569         EXPECT_LE(1U, found2);
570 
571         // Could be A or AAAA
572         result_str = ToString(result);
573         EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
574             << ", result_str='" << result_str << "'";
575         if (result) {
576             freeaddrinfo(result);
577             result = nullptr;
578         }
579     }
580     dns.stopServer();
581     dns2.stopServer();
582 }
583 
// An AF_INET-only getaddrinfo() for "hola" must send one query for
// "hola.example.com." and return the configured IPv4 address.
TEST_F(ResolverTest, GetAddrInfoV4) {
    const char* const server_addr = "127.0.0.5";
    const char* const server_port = "53";
    const char* const fqdn = "hola.example.com.";

    test::DNSResponder responder(server_addr, server_port, 250, ns_rcode::ns_r_servfail, 1.0);
    responder.addMapping(fqdn, ns_type::ns_t_a, "1.2.3.5");
    ASSERT_TRUE(responder.startServer());
    const std::vector<std::string> nameservers = { server_addr };
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, nameservers, mDefaultParams));

    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;

    addrinfo* result = nullptr;
    EXPECT_EQ(0, getaddrinfo("hola", nullptr, &hints, &result));
    EXPECT_EQ(1U, GetNumQueries(responder, fqdn));
    EXPECT_EQ("1.2.3.5", ToString(result));
    if (result != nullptr) {
        freeaddrinfo(result);
    }
}
608 
// With three search domains configured and the name only existing under the second,
// gethostbyname("nihao") must send exactly one A query for "nihao.example2.com." and
// return its single IPv4 address.
TEST_F(ResolverTest, MultidomainResolution) {
    const std::vector<std::string> search_domains = {
        "example1.com", "example2.com", "example3.com"
    };
    const char* const server_addr = "127.0.0.6";
    const char* const server_port = "53";
    const char* const fqdn = "nihao.example2.com.";

    test::DNSResponder responder(server_addr, server_port, 250, ns_rcode::ns_r_servfail, 1.0);
    responder.addMapping(fqdn, ns_type::ns_t_a, "1.2.3.3");
    ASSERT_TRUE(responder.startServer());
    const std::vector<std::string> nameservers = { server_addr };
    ASSERT_TRUE(SetResolversForNetwork(search_domains, nameservers, mDefaultParams));

    responder.clearQueries();
    const hostent* he = gethostbyname("nihao");
    EXPECT_EQ(1U, GetNumQueriesForType(responder, ns_type::ns_t_a, fqdn));

    ASSERT_FALSE(he == nullptr);
    ASSERT_EQ(4, he->h_length);
    ASSERT_FALSE(he->h_addr_list[0] == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(he));
    EXPECT_TRUE(he->h_addr_list[1] == nullptr);

    responder.stopServer();
}
631 
// dns0 is configured to fail (last constructor argument 0.0) and dns1 to answer
// (1.0). After `sample_count` lookups of non-existent names have been recorded, the
// resolver should stop using dns0: a lookup of "ohayou" is expected to reach only dns1.
// TODO: This approach is implementation-dependent, change once metrics reporting is available.
TEST_F(ResolverTest, GetAddrInfoV6_failing) {
    const char* const failing_addr = "127.0.0.7";
    const char* const healthy_addr = "127.0.0.8";
    const char* const port = "53";
    const char* const fqdn = "ohayou.example.com.";

    test::DNSResponder failing(failing_addr, port, 250, ns_rcode::ns_r_servfail, 0.0);
    test::DNSResponder healthy(healthy_addr, port, 250, ns_rcode::ns_r_servfail, 1.0);
    failing.addMapping(fqdn, ns_type::ns_t_aaaa, "2001:db8::5");
    healthy.addMapping(fqdn, ns_type::ns_t_aaaa, "2001:db8::6");
    ASSERT_TRUE(failing.startServer());
    ASSERT_TRUE(healthy.startServer());
    const std::vector<std::string> nameservers = { failing_addr, healthy_addr };

    // Resolver params: <sample validity in s> <success threshold in percent>
    // <min samples> <max samples>.
    const unsigned sample_validity = 300;
    const int success_threshold = 25;
    const int sample_count = 8;
    const std::string params = StringPrintf("%u %d %d %d", sample_validity, success_threshold,
            sample_count, sample_count);
    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, nameservers, params));

    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET6;
    addrinfo* result = nullptr;

    // Generate samples. The lookups are expected to fail, so their results are ignored.
    for (int n = 0; n < sample_count; ++n) {
        const std::string bogus_name = StringPrintf("nonexistent%d", n);
        getaddrinfo(bogus_name.c_str(), nullptr, &hints, &result);
        if (result) {
            freeaddrinfo(result);
            result = nullptr;
        }
    }

    // From here on, until the samples expire, only the healthy server should be queried.
    failing.clearQueries();
    healthy.clearQueries();
    EXPECT_EQ(0, getaddrinfo("ohayou", nullptr, &hints, &result));
    EXPECT_EQ(0U, GetNumQueries(failing, fqdn));
    EXPECT_EQ(1U, GetNumQueries(healthy, fqdn));
    if (result) {
        freeaddrinfo(result);
    }
}
683 
// Ten threads, each after a random delay of under one second, configure a random
// non-empty subset of three servers and resolve "konbanha" over IPv6. Every lookup
// must succeed while the resolver configuration is changing concurrently.
TEST_F(ResolverTest, GetAddrInfoV6_concurrent) {
    const char* listen_addr0 = "127.0.0.9";
    const char* listen_addr1 = "127.0.0.10";
    const char* listen_addr2 = "127.0.0.11";
    const char* listen_srv = "53";
    const char* host_name = "konbanha.example.com.";
    test::DNSResponder dns0(listen_addr0, listen_srv, 250,
                            ns_rcode::ns_r_servfail, 1.0);
    test::DNSResponder dns1(listen_addr1, listen_srv, 250,
                            ns_rcode::ns_r_servfail, 1.0);
    test::DNSResponder dns2(listen_addr2, listen_srv, 250,
                            ns_rcode::ns_r_servfail, 1.0);
    dns0.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::5");
    dns1.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::6");
    dns2.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::7");
    ASSERT_TRUE(dns0.startServer());
    ASSERT_TRUE(dns1.startServer());
    ASSERT_TRUE(dns2.startServer());
    const std::vector<std::string> servers = { listen_addr0, listen_addr1, listen_addr2 };
    std::vector<std::thread> threads(10);
    for (std::thread& thread : threads) {
        // Capture only what the worker uses; the responders are not touched from threads.
        thread = std::thread([this, &servers]() {
            unsigned delay = arc4random_uniform(1*1000*1000); // < 1s
            usleep(delay);
            std::vector<std::string> serverSubset;
            for (const auto& server : servers) {
                if (arc4random_uniform(2)) {
                    serverSubset.push_back(server);
                }
            }
            if (serverSubset.empty()) serverSubset = servers;
            ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, serverSubset,
                    mDefaultParams));
            addrinfo hints;
            memset(&hints, 0, sizeof(hints));
            hints.ai_family = AF_INET6;
            addrinfo* result = nullptr;
            int rv = getaddrinfo("konbanha", nullptr, &hints, &result);
            EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
            if (result) {
                freeaddrinfo(result);
                result = nullptr;
            }
        });
    }
    for (std::thread& thread : threads) {
        thread.join();
    }
    // Stop the servers explicitly, as the other tests do.
    dns0.stopServer();
    dns1.stopServer();
    dns2.stopServer();
}
733 
// Small stress run: forwards host/thread/query counts of 100 each to
// RunGetAddrInfoStressTest_Binder(), which defines what each count controls.
TEST_F(ResolverTest, GetAddrInfoStressTest_Binder_100) {
    constexpr unsigned kHostCount = 100;
    constexpr unsigned kThreadCount = 100;
    constexpr unsigned kQueryCount = 100;
    ASSERT_NO_FATAL_FAILURE(
            RunGetAddrInfoStressTest_Binder(kHostCount, kThreadCount, kQueryCount));
}
740 
// Large stress run: the same driver as the 100-host variant, but with 100000 hosts
// (threads and queries stay at 100). RunGetAddrInfoStressTest_Binder() defines
// what each count controls.
TEST_F(ResolverTest, GetAddrInfoStressTest_Binder_100000) {
    constexpr unsigned kHostCount = 100000;
    constexpr unsigned kThreadCount = 100;
    constexpr unsigned kQueryCount = 100;
    ASSERT_NO_FATAL_FAILURE(
            RunGetAddrInfoStressTest_Binder(kHostCount, kThreadCount, kQueryCount));
}
747 
// Applies an empty server list and an empty search-domain list together with
// mDefaultParams_Binder, then reads the configuration back with GetResolverInfo().
// Expects no servers and no domains, and expects the sample-validity,
// success-threshold, min-samples and max-samples fields to equal the
// corresponding entries of mDefaultParams_Binder.
TEST_F(ResolverTest, EmptySetup) {
    using android::net::INetd;

    std::vector<std::string> noServers;
    std::vector<std::string> noDomains;
    ASSERT_TRUE(SetResolversForNetwork(noServers, noDomains, mDefaultParams_Binder));

    // Out-parameters populated by GetResolverInfo().
    std::vector<std::string> gotServers;
    std::vector<std::string> gotDomains;
    __res_params gotParams;
    std::vector<ResolverStats> gotStats;
    ASSERT_TRUE(GetResolverInfo(&gotServers, &gotDomains, &gotParams, &gotStats));

    EXPECT_EQ(0U, gotServers.size());
    EXPECT_EQ(0U, gotDomains.size());

    // Fatal size check first so the INetd-indexed reads below stay in bounds.
    ASSERT_EQ(INetd::RESOLVER_PARAMS_COUNT, mDefaultParams_Binder.size());
    const auto& want = mDefaultParams_Binder;
    EXPECT_EQ(want[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY], gotParams.sample_validity);
    EXPECT_EQ(want[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD], gotParams.success_threshold);
    EXPECT_EQ(want[INetd::RESOLVER_PARAMS_MIN_SAMPLES], gotParams.min_samples);
    EXPECT_EQ(want[INetd::RESOLVER_PARAMS_MAX_SAMPLES], gotParams.max_samples);
}
768 
// Verifies that changing the domain search path on its own (same server list)
// takes effect.
//
// The unqualified name "test13" is first resolved with search domain
// "domain1.org" and must hit test13.domain1.org. The search domain is then
// switched to "domain2.org", and the same lookup must hit test13.domain2.org.
TEST_F(ResolverTest, SearchPathChange) {
    addrinfo* result = nullptr;

    const char* listen_addr = "127.0.0.13";
    const char* listen_srv = "53";
    const char* host_name1 = "test13.domain1.org.";
    const char* host_name2 = "test13.domain2.org.";
    test::DNSResponder dns(listen_addr, listen_srv, 250,
                           ns_rcode::ns_r_servfail, 1.0);
    dns.addMapping(host_name1, ns_type::ns_t_aaaa, "2001:db8::13");
    dns.addMapping(host_name2, ns_type::ns_t_aaaa, "2001:db8::1:13");
    ASSERT_TRUE(dns.startServer());
    std::vector<std::string> servers = { listen_addr };
    std::vector<std::string> domains = { "domain1.org" };
    ASSERT_TRUE(SetResolversForNetwork(domains, servers, mDefaultParams));

    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET6;
    EXPECT_EQ(0, getaddrinfo("test13", nullptr, &hints, &result));
    EXPECT_EQ(1U, dns.queries().size());
    EXPECT_EQ(1U, GetNumQueries(dns, host_name1));
    EXPECT_EQ("2001:db8::13", ToString(result));
    if (result) {
        freeaddrinfo(result);
        // The checks above are non-fatal EXPECTs, so execution continues even if
        // the next getaddrinfo() fails. On failure, getaddrinfo() need not write
        // to its result argument. Without this reset, the stale pointer would
        // then be read by ToString() and freed a second time.
        result = nullptr;
    }

    // Test that changing the domain search path on its own works.
    domains = { "domain2.org" };
    ASSERT_TRUE(SetResolversForNetwork(domains, servers, mDefaultParams));
    dns.clearQueries();

    EXPECT_EQ(0, getaddrinfo("test13", nullptr, &hints, &result));
    EXPECT_EQ(1U, dns.queries().size());
    EXPECT_EQ(1U, GetNumQueries(dns, host_name2));
    EXPECT_EQ("2001:db8::1:13", ToString(result));
    if (result) {
        freeaddrinfo(result);
        result = nullptr;
    }
}
805 
// Sets up MAXNS + 1 DNS responders (one mapping under "example.com") and passes
// all of their addresses to SetResolversForNetwork() together with
// mDefaultParams_Binder. Checks that the configuration read back with
// GetResolverInfo() contains exactly MAXNS servers, i.e. the extra server was
// pruned. Shuts the responders down afterwards.
TEST_F(ResolverTest, MaxServerPrune_Binder) {
    std::vector<std::string> domains = { "example.com" };
    std::vector<std::unique_ptr<test::DNSResponder>> dns;
    std::vector<std::string> servers;
    std::vector<Mapping> mappings;
    ASSERT_NO_FATAL_FAILURE(SetupMappings(1, domains, &mappings));
    // One more server than the resolver is allowed to keep.
    ASSERT_NO_FATAL_FAILURE(SetupDNSServers(MAXNS + 1, mappings, &dns, &servers));

    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));

    // Out-parameters populated by GetResolverInfo(); only the server list is checked.
    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    __res_params res_params;
    std::vector<ResolverStats> res_stats;
    ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_params, &res_stats));
    EXPECT_EQ(static_cast<size_t>(MAXNS), res_servers.size());

    ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
}
827