1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <gmock/gmock.h> 18 #include <gtest/gtest.h> 19 20 #include "PrivateDnsConfiguration.h" 21 #include "tests/dns_responder/dns_responder.h" 22 #include "tests/dns_responder/dns_tls_frontend.h" 23 #include "tests/resolv_test_utils.h" 24 25 namespace android::net { 26 27 using namespace std::chrono_literals; 28 29 class PrivateDnsConfigurationTest : public ::testing::Test { 30 public: 31 using ServerIdentity = PrivateDnsConfiguration::ServerIdentity; 32 33 static void SetUpTestSuite() { 34 // stopServer() will be called in their destructor. 35 ASSERT_TRUE(tls1.startServer()); 36 ASSERT_TRUE(tls2.startServer()); 37 ASSERT_TRUE(backend.startServer()); 38 ASSERT_TRUE(backend1ForUdpProbe.startServer()); 39 ASSERT_TRUE(backend2ForUdpProbe.startServer()); 40 } 41 42 void SetUp() { 43 mPdc.setObserver(&mObserver); 44 mPdc.mBackoffBuilder.withInitialRetransmissionTime(std::chrono::seconds(1)) 45 .withMaximumRetransmissionTime(std::chrono::seconds(1)); 46 47 // The default and sole action when the observer is notified of onValidationStateUpdate. 48 // Don't override the action. In other words, don't use WillOnce() or WillRepeatedly() 49 // when mObserver.onValidationStateUpdate is expected to be called, like: 50 // 51 // EXPECT_CALL(mObserver, onValidationStateUpdate).WillOnce(Return()); 52 // 53 // This is to ensure that tests can monitor how many validation threads are running. Tests 54 // must wait until every validation thread finishes. 55 ON_CALL(mObserver, onValidationStateUpdate) 56 .WillByDefault([&](const std::string& server, Validation validation, uint32_t) { 57 if (validation == Validation::in_process) { 58 std::lock_guard guard(mObserver.lock); 59 auto it = mObserver.serverStateMap.find(server); 60 if (it == mObserver.serverStateMap.end() || 61 it->second != Validation::in_process) { 62 // Increment runningThreads only when receive the first in_process 63 // notification. The rest of the continuous in_process notifications 64 // are due to probe retry which runs on the same thread. 65 // TODO: consider adding onValidationThreadStart() and 66 // onValidationThreadEnd() callbacks. 67 mObserver.runningThreads++; 68 } 69 } else if (validation == Validation::success || 70 validation == Validation::fail) { 71 mObserver.runningThreads--; 72 } 73 std::lock_guard guard(mObserver.lock); 74 mObserver.serverStateMap[server] = validation; 75 }); 76 } 77 78 protected: 79 class MockObserver : public PrivateDnsValidationObserver { 80 public: 81 MOCK_METHOD(void, onValidationStateUpdate, 82 (const std::string& serverIp, Validation validation, uint32_t netId), 83 (override)); 84 85 std::map<std::string, Validation> getServerStateMap() const { 86 std::lock_guard guard(lock); 87 return serverStateMap; 88 } 89 90 void removeFromServerStateMap(const std::string& server) { 91 std::lock_guard guard(lock); 92 if (const auto it = serverStateMap.find(server); it != serverStateMap.end()) 93 serverStateMap.erase(it); 94 } 95 96 // The current number of validation threads running. 97 std::atomic<int> runningThreads = 0; 98 99 mutable std::mutex lock; 100 std::map<std::string, Validation> serverStateMap GUARDED_BY(lock); 101 }; 102 103 void expectPrivateDnsStatus(PrivateDnsMode mode) { 104 // Use PollForCondition because mObserver is notified asynchronously. 105 EXPECT_TRUE(PollForCondition([&]() { return checkPrivateDnsStatus(mode); })); 106 } 107 108 bool checkPrivateDnsStatus(PrivateDnsMode mode) { 109 const PrivateDnsStatus status = mPdc.getStatus(kNetId); 110 if (status.mode != mode) return false; 111 112 std::map<std::string, Validation> serverStateMap; 113 for (const auto& [server, validation] : status.serversMap) { 114 serverStateMap[ToString(&server.ss)] = validation; 115 } 116 return (serverStateMap == mObserver.getServerStateMap()); 117 } 118 119 bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) { 120 return mPdc.getPrivateDns(identity, netId).ok(); 121 } 122 123 static constexpr uint32_t kNetId = 30; 124 static constexpr uint32_t kMark = 30; 125 static constexpr char kBackend[] = "127.0.2.1"; 126 static constexpr char kServer1[] = "127.0.2.2"; 127 static constexpr char kServer2[] = "127.0.2.3"; 128 129 MockObserver mObserver; 130 PrivateDnsConfiguration mPdc; 131 132 // TODO: Because incorrect CAs result in validation failed in strict mode, have 133 // PrivateDnsConfiguration run mocked code rather than DnsTlsTransport::validate(). 134 inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"}; 135 inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"}; 136 inline static test::DNSResponder backend{kBackend, "53"}; 137 inline static test::DNSResponder backend1ForUdpProbe{kServer1, "53"}; 138 inline static test::DNSResponder backend2ForUdpProbe{kServer2, "53"}; 139 }; 140 141 TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) { 142 testing::InSequence seq; 143 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId)); 144 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); 145 146 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0); 147 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 148 149 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); 150 } 151 152 TEST_F(PrivateDnsConfigurationTest, ValidationFail_Opportunistic) { 153 ASSERT_TRUE(backend.stopServer()); 154 155 testing::InSequence seq; 156 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId)); 157 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId)); 158 159 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0); 160 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 161 162 // Strictly wait for all of the validation finish; otherwise, the test can crash somehow. 163 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); 164 ASSERT_TRUE(backend.startServer()); 165 } 166 167 TEST_F(PrivateDnsConfigurationTest, Revalidation_Opportunistic) { 168 const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853)); 169 170 // Step 1: Set up and wait for validation complete. 171 testing::InSequence seq; 172 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId)); 173 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); 174 175 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0); 176 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 177 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); 178 179 // Step 2: Simulate the DNS is temporarily broken, and then request a validation. 180 // Expect the validation to run as follows: 181 // 1. DnsResolver notifies of Validation::in_process when the validation is about to run. 182 // 2. The first probing fails. DnsResolver notifies of Validation::in_process. 183 // 3. One second later, the second probing begins and succeeds. DnsResolver notifies of 184 // Validation::success. 185 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId)) 186 .Times(2); 187 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); 188 189 std::thread t([] { 190 std::this_thread::sleep_for(1000ms); 191 backend.startServer(); 192 }); 193 backend.stopServer(); 194 EXPECT_TRUE(mPdc.requestValidation(kNetId, ServerIdentity(server), kMark).ok()); 195 196 t.join(); 197 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 198 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); 199 } 200 201 TEST_F(PrivateDnsConfigurationTest, ValidationBlock) { 202 backend.setDeferredResp(true); 203 204 // onValidationStateUpdate() is called in sequence. 205 { 206 testing::InSequence seq; 207 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId)); 208 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0); 209 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 1; })); 210 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 211 212 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::in_process, kNetId)); 213 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer2}, {}, {}), 0); 214 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 2; })); 215 mObserver.removeFromServerStateMap(kServer1); 216 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 217 218 // No duplicate validation as long as not in OFF mode; otherwise, an unexpected 219 // onValidationStateUpdate() will be caught. 220 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0); 221 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1, kServer2}, {}, {}), 0); 222 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer2}, {}, {}), 0); 223 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 224 225 // The status keeps unchanged if pass invalid arguments. 226 EXPECT_EQ(mPdc.set(kNetId, kMark, {"invalid_addr"}, {}, {}), -EINVAL); 227 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 228 } 229 230 // The update for |kServer1| will be Validation::fail because |kServer1| is not an expected 231 // server for the network. 232 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId)); 233 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::success, kNetId)); 234 backend.setDeferredResp(false); 235 236 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); 237 238 // kServer1 is not a present server and thus should not be available from 239 // PrivateDnsConfiguration::getStatus(). 240 mObserver.removeFromServerStateMap(kServer1); 241 242 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 243 } 244 245 TEST_F(PrivateDnsConfigurationTest, Validation_NetworkDestroyedOrOffMode) { 246 for (const std::string_view config : {"OFF", "NETWORK_DESTROYED"}) { 247 SCOPED_TRACE(config); 248 backend.setDeferredResp(true); 249 250 testing::InSequence seq; 251 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId)); 252 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0); 253 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 1; })); 254 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 255 256 if (config == "OFF") { 257 EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}, {}), 0); 258 } else if (config == "NETWORK_DESTROYED") { 259 mPdc.clear(kNetId); 260 } 261 262 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId)); 263 backend.setDeferredResp(false); 264 265 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); 266 mObserver.removeFromServerStateMap(kServer1); 267 expectPrivateDnsStatus(PrivateDnsMode::OFF); 268 } 269 } 270 271 TEST_F(PrivateDnsConfigurationTest, NoValidation) { 272 // If onValidationStateUpdate() is called, the test will fail with uninteresting mock 273 // function calls in the end of the test. 274 275 const auto expectStatus = [&]() { 276 const PrivateDnsStatus status = mPdc.getStatus(kNetId); 277 EXPECT_EQ(status.mode, PrivateDnsMode::OFF); 278 EXPECT_THAT(status.serversMap, testing::IsEmpty()); 279 }; 280 281 EXPECT_EQ(mPdc.set(kNetId, kMark, {"invalid_addr"}, {}, {}), -EINVAL); 282 expectStatus(); 283 284 EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}, {}), 0); 285 expectStatus(); 286 } 287 288 TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) { 289 DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853)); 290 server.name = "dns.example.com"; 291 292 // Different socket address. 293 DnsTlsServer other = server; 294 EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); 295 other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353); 296 EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); 297 other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853); 298 EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); 299 300 // Different provider hostname. 301 other = server; 302 EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); 303 other.name = "other.example.com"; 304 EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); 305 other.name = ""; 306 EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); 307 } 308 309 TEST_F(PrivateDnsConfigurationTest, RequestValidation) { 310 const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853)); 311 const ServerIdentity identity(server); 312 313 testing::InSequence seq; 314 315 for (const std::string_view config : {"SUCCESS", "IN_PROGRESS", "FAIL"}) { 316 SCOPED_TRACE(config); 317 318 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId)); 319 if (config == "SUCCESS") { 320 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); 321 } else if (config == "IN_PROGRESS") { 322 backend.setDeferredResp(true); 323 } else { 324 // config = "FAIL" 325 ASSERT_TRUE(backend.stopServer()); 326 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId)); 327 } 328 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0); 329 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 330 331 // Wait until the validation state is transitioned. 332 const int runningThreads = (config == "IN_PROGRESS") ? 1 : 0; 333 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == runningThreads; })); 334 335 if (config == "SUCCESS") { 336 EXPECT_CALL(mObserver, 337 onValidationStateUpdate(kServer1, Validation::in_process, kNetId)); 338 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); 339 EXPECT_TRUE(mPdc.requestValidation(kNetId, identity, kMark).ok()); 340 } else if (config == "IN_PROGRESS") { 341 EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); 342 EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok()); 343 } else if (config == "FAIL") { 344 EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok()); 345 } 346 347 // Resending the same request or requesting nonexistent servers are denied. 348 EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok()); 349 EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark + 1).ok()); 350 EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, identity, kMark).ok()); 351 352 // Reset the test state. 353 backend.setDeferredResp(false); 354 backend.startServer(); 355 356 // Ensure the status of mObserver is synced. 357 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 358 359 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); 360 mPdc.clear(kNetId); 361 } 362 } 363 364 TEST_F(PrivateDnsConfigurationTest, GetPrivateDns) { 365 const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853)); 366 const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853)); 367 368 EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId)); 369 EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId)); 370 371 // Suppress the warning. 372 EXPECT_CALL(mObserver, onValidationStateUpdate).Times(2); 373 374 EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0); 375 expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); 376 377 EXPECT_TRUE(hasPrivateDnsServer(ServerIdentity(server1), kNetId)); 378 EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId)); 379 EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId + 1)); 380 381 ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); 382 } 383 384 // TODO: add ValidationFail_Strict test. 385 386 } // namespace android::net 387