1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19 
20 #include "PrivateDnsConfiguration.h"
21 #include "tests/dns_responder/dns_responder.h"
22 #include "tests/dns_responder/dns_tls_frontend.h"
23 #include "tests/resolv_test_utils.h"
24 
25 namespace android::net {
26 
27 using namespace std::chrono_literals;
28 
29 class PrivateDnsConfigurationTest : public ::testing::Test {
30   public:
31     using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;
32 
33     static void SetUpTestSuite() {
34         // stopServer() will be called in their destructor.
35         ASSERT_TRUE(tls1.startServer());
36         ASSERT_TRUE(tls2.startServer());
37         ASSERT_TRUE(backend.startServer());
38         ASSERT_TRUE(backend1ForUdpProbe.startServer());
39         ASSERT_TRUE(backend2ForUdpProbe.startServer());
40     }
41 
42     void SetUp() {
43         mPdc.setObserver(&mObserver);
44         mPdc.mBackoffBuilder.withInitialRetransmissionTime(std::chrono::seconds(1))
45                 .withMaximumRetransmissionTime(std::chrono::seconds(1));
46 
47         // The default and sole action when the observer is notified of onValidationStateUpdate.
48         // Don't override the action. In other words, don't use WillOnce() or WillRepeatedly()
49         // when mObserver.onValidationStateUpdate is expected to be called, like:
50         //
51         //   EXPECT_CALL(mObserver, onValidationStateUpdate).WillOnce(Return());
52         //
53         // This is to ensure that tests can monitor how many validation threads are running. Tests
54         // must wait until every validation thread finishes.
55         ON_CALL(mObserver, onValidationStateUpdate)
56                 .WillByDefault([&](const std::string& server, Validation validation, uint32_t) {
57                     if (validation == Validation::in_process) {
58                         std::lock_guard guard(mObserver.lock);
59                         auto it = mObserver.serverStateMap.find(server);
60                         if (it == mObserver.serverStateMap.end() ||
61                             it->second != Validation::in_process) {
62                             // Increment runningThreads only when receive the first in_process
63                             // notification. The rest of the continuous in_process notifications
64                             // are due to probe retry which runs on the same thread.
65                             // TODO: consider adding onValidationThreadStart() and
66                             // onValidationThreadEnd() callbacks.
67                             mObserver.runningThreads++;
68                         }
69                     } else if (validation == Validation::success ||
70                                validation == Validation::fail) {
71                         mObserver.runningThreads--;
72                     }
73                     std::lock_guard guard(mObserver.lock);
74                     mObserver.serverStateMap[server] = validation;
75                 });
76     }
77 
78   protected:
79     class MockObserver : public PrivateDnsValidationObserver {
80       public:
81         MOCK_METHOD(void, onValidationStateUpdate,
82                     (const std::string& serverIp, Validation validation, uint32_t netId),
83                     (override));
84 
85         std::map<std::string, Validation> getServerStateMap() const {
86             std::lock_guard guard(lock);
87             return serverStateMap;
88         }
89 
90         void removeFromServerStateMap(const std::string& server) {
91             std::lock_guard guard(lock);
92             if (const auto it = serverStateMap.find(server); it != serverStateMap.end())
93                 serverStateMap.erase(it);
94         }
95 
96         // The current number of validation threads running.
97         std::atomic<int> runningThreads = 0;
98 
99         mutable std::mutex lock;
100         std::map<std::string, Validation> serverStateMap GUARDED_BY(lock);
101     };
102 
103     void expectPrivateDnsStatus(PrivateDnsMode mode) {
104         // Use PollForCondition because mObserver is notified asynchronously.
105         EXPECT_TRUE(PollForCondition([&]() { return checkPrivateDnsStatus(mode); }));
106     }
107 
108     bool checkPrivateDnsStatus(PrivateDnsMode mode) {
109         const PrivateDnsStatus status = mPdc.getStatus(kNetId);
110         if (status.mode != mode) return false;
111 
112         std::map<std::string, Validation> serverStateMap;
113         for (const auto& [server, validation] : status.serversMap) {
114             serverStateMap[ToString(&server.ss)] = validation;
115         }
116         return (serverStateMap == mObserver.getServerStateMap());
117     }
118 
119     bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) {
120         return mPdc.getPrivateDns(identity, netId).ok();
121     }
122 
123     static constexpr uint32_t kNetId = 30;
124     static constexpr uint32_t kMark = 30;
125     static constexpr char kBackend[] = "127.0.2.1";
126     static constexpr char kServer1[] = "127.0.2.2";
127     static constexpr char kServer2[] = "127.0.2.3";
128 
129     MockObserver mObserver;
130     PrivateDnsConfiguration mPdc;
131 
132     // TODO: Because incorrect CAs result in validation failed in strict mode, have
133     // PrivateDnsConfiguration run mocked code rather than DnsTlsTransport::validate().
134     inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"};
135     inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"};
136     inline static test::DNSResponder backend{kBackend, "53"};
137     inline static test::DNSResponder backend1ForUdpProbe{kServer1, "53"};
138     inline static test::DNSResponder backend2ForUdpProbe{kServer2, "53"};
139 };
140 
141 TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) {
142     testing::InSequence seq;
143     EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
144     EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
145 
146     EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
147     expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
148 
149     ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
150 }
151 
152 TEST_F(PrivateDnsConfigurationTest, ValidationFail_Opportunistic) {
153     ASSERT_TRUE(backend.stopServer());
154 
155     testing::InSequence seq;
156     EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
157     EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
158 
159     EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
160     expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
161 
162     // Strictly wait for all of the validation finish; otherwise, the test can crash somehow.
163     ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
164     ASSERT_TRUE(backend.startServer());
165 }
166 
// Verifies that requestValidation() re-runs validation for an already validated server in
// opportunistic mode, and that a failed probe is retried until the DNS backend comes back.
TEST_F(PrivateDnsConfigurationTest, Revalidation_Opportunistic) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));

    // Step 1: Set up and wait for validation complete.
    testing::InSequence seq;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));

    // Step 2: Simulate the DNS is temporarily broken, and then request a validation.
    // Expect the validation to run as follows:
    //   1. DnsResolver notifies of Validation::in_process when the validation is about to run.
    //   2. The first probing fails. DnsResolver notifies of Validation::in_process.
    //   3. One second later, the second probing begins and succeeds. DnsResolver notifies of
    //      Validation::success.
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId))
            .Times(2);
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    // Bring |backend| back after one second, in parallel with the revalidation below, so the
    // first probe fails and a retry (spaced by the one-second backoff set in SetUp()) succeeds.
    // NOTE(review): the return values of startServer()/stopServer() are not checked here, and
    // the test relies on wall-clock timing.
    std::thread t([] {
        std::this_thread::sleep_for(1000ms);
        backend.startServer();
    });
    backend.stopServer();
    EXPECT_TRUE(mPdc.requestValidation(kNetId, ServerIdentity(server), kMark).ok());

    t.join();
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
}
200 
// Verifies three things while validations are held in progress:
//   - repeated set() calls for servers that are already being validated do not start
//     duplicate validations;
//   - invalid arguments leave the status unchanged;
//   - each pending validation completes once |backend| answers again.
TEST_F(PrivateDnsConfigurationTest, ValidationBlock) {
    // Defer |backend|'s responses so that validations stay in_process until
    // setDeferredResp(false) below (presumed semantics of setDeferredResp — confirm in
    // dns_responder.h).
    // NOTE(review): a failing ASSERT_* in this test returns with deferred responses still
    // enabled on the shared |backend|.
    backend.setDeferredResp(true);

    // onValidationStateUpdate() is called in sequence.
    {
        testing::InSequence seq;
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 1; }));
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Replacing kServer1 with kServer2 starts a second validation while the first one is
        // still blocked.
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer2}, {}, {}), 0);
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 2; }));
        // kServer1 is no longer configured, so getStatus() does not report it; drop it from the
        // observer's view so that expectPrivateDnsStatus() compares like with like.
        mObserver.removeFromServerStateMap(kServer1);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // No duplicate validation as long as not in OFF mode; otherwise, an unexpected
        // onValidationStateUpdate() will be caught.
        EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
        EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1, kServer2}, {}, {}), 0);
        EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer2}, {}, {}), 0);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // The status keeps unchanged if pass invalid arguments.
        EXPECT_EQ(mPdc.set(kNetId, kMark, {"invalid_addr"}, {}, {}), -EINVAL);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    }

    // The update for |kServer1| will be Validation::fail because |kServer1| is not an expected
    // server for the network.
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::success, kNetId));
    // Release the deferred responses so that both blocked validations can complete.
    backend.setDeferredResp(false);

    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));

    // kServer1 is not a present server and thus should not be available from
    // PrivateDnsConfiguration::getStatus().
    mObserver.removeFromServerStateMap(kServer1);

    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
}
244 
// Verifies that a validation still in progress reports Validation::fail when private DNS is
// turned off (set() with no servers) or the network is cleared before the validation completes.
TEST_F(PrivateDnsConfigurationTest, Validation_NetworkDestroyedOrOffMode) {
    for (const std::string_view config : {"OFF", "NETWORK_DESTROYED"}) {
        SCOPED_TRACE(config);
        // Keep the validation blocked in_process until deferred responses are released below.
        backend.setDeferredResp(true);

        testing::InSequence seq;
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 1; }));
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Either turn private DNS off or forget the network while the validation is blocked.
        if (config == "OFF") {
            EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}, {}), 0);
        } else if (config == "NETWORK_DESTROYED") {
            mPdc.clear(kNetId);
        }

        // Once released, the blocked validation is expected to report fail (presumably because
        // kServer1 is no longer configured for the network).
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
        backend.setDeferredResp(false);

        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
        // In OFF mode getStatus() is expected to list no servers; drop kServer1 from the
        // observer's view to match.
        mObserver.removeFromServerStateMap(kServer1);
        expectPrivateDnsStatus(PrivateDnsMode::OFF);
    }
}
270 
271 TEST_F(PrivateDnsConfigurationTest, NoValidation) {
272     // If onValidationStateUpdate() is called, the test will fail with uninteresting mock
273     // function calls in the end of the test.
274 
275     const auto expectStatus = [&]() {
276         const PrivateDnsStatus status = mPdc.getStatus(kNetId);
277         EXPECT_EQ(status.mode, PrivateDnsMode::OFF);
278         EXPECT_THAT(status.serversMap, testing::IsEmpty());
279     };
280 
281     EXPECT_EQ(mPdc.set(kNetId, kMark, {"invalid_addr"}, {}, {}), -EINVAL);
282     expectStatus();
283 
284     EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}, {}), 0);
285     expectStatus();
286 }
287 
288 TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
289     DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
290     server.name = "dns.example.com";
291 
292     // Different socket address.
293     DnsTlsServer other = server;
294     EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
295     other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353);
296     EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
297     other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853);
298     EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
299 
300     // Different provider hostname.
301     other = server;
302     EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
303     other.name = "other.example.com";
304     EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
305     other.name = "";
306     EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
307 }
308 
// Verifies requestValidation() for a server whose most recent validation succeeded, is still in
// progress, or failed. Only the succeeded case may start a new validation. Repeated requests,
// and requests with a wrong mark or netId, are always rejected.
TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const ServerIdentity identity(server);

    testing::InSequence seq;

    for (const std::string_view config : {"SUCCESS", "IN_PROGRESS", "FAIL"}) {
        SCOPED_TRACE(config);

        // Drive kServer1 into the state named by |config|.
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        if (config == "SUCCESS") {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
        } else if (config == "IN_PROGRESS") {
            backend.setDeferredResp(true);
        } else {
            // config = "FAIL"
            ASSERT_TRUE(backend.stopServer());
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
        }
        EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Wait until the validation state is transitioned.
        const int runningThreads = (config == "IN_PROGRESS") ? 1 : 0;
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == runningThreads; }));

        if (config == "SUCCESS") {
            EXPECT_CALL(mObserver,
                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_TRUE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        } else if (config == "IN_PROGRESS") {
            // The blocked validation is expected to succeed once deferred responses are released
            // in the reset step below.
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        } else if (config == "FAIL") {
            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        }

        // Resending the same request or requesting nonexistent servers are denied.
        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, identity, kMark).ok());

        // Reset the test state.
        // NOTE(review): startServer() is called even when |backend| is already running (the
        // SUCCESS and IN_PROGRESS iterations); its return value is deliberately ignored.
        backend.setDeferredResp(false);
        backend.startServer();

        // Ensure the status of mObserver is synced.
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
        mPdc.clear(kNetId);
    }
}
363 
364 TEST_F(PrivateDnsConfigurationTest, GetPrivateDns) {
365     const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
366     const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));
367 
368     EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
369     EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));
370 
371     // Suppress the warning.
372     EXPECT_CALL(mObserver, onValidationStateUpdate).Times(2);
373 
374     EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
375     expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
376 
377     EXPECT_TRUE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
378     EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));
379     EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId + 1));
380 
381     ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
382 }
383 
384 // TODO: add ValidationFail_Strict test.
385 
386 }  // namespace android::net
387