1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include "discovery/mdns/mdns_probe_manager.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "discovery/common/config.h"
#include "discovery/mdns/mdns_probe.h"
#include "discovery/mdns/mdns_querier.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_sender.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "platform/test/fake_clock.h"
#include "platform/test/fake_task_runner.h"
#include "platform/test/fake_udp_socket.h"
20
21 using testing::_;
22 using testing::Invoke;
23 using testing::Return;
24 using testing::StrictMock;
25
26 namespace openscreen {
27 namespace discovery {
28
29 class MockDomainConfirmedProvider : public MdnsDomainConfirmedProvider {
30 public:
31 MOCK_METHOD2(OnDomainFound, void(const DomainName&, const DomainName&));
32 };
33
34 class MockMdnsSender : public MdnsSender {
35 public:
MockMdnsSender(UdpSocket * socket)36 explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
37
38 MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
39 MOCK_METHOD2(SendMessage,
40 Error(const MdnsMessage& message, const IPEndpoint& endpoint));
41 };
42
43 class MockMdnsProbe : public MdnsProbe {
44 public:
MockMdnsProbe(DomainName target_name,IPAddress address)45 MockMdnsProbe(DomainName target_name, IPAddress address)
46 : MdnsProbe(std::move(target_name), std::move(address)) {}
47
48 MOCK_METHOD1(Postpone, void(std::chrono::seconds));
49 MOCK_METHOD1(OnMessageReceived, void(const MdnsMessage&));
50 };
51
52 class TestMdnsProbeManager : public MdnsProbeManagerImpl {
53 public:
54 using MdnsProbeManagerImpl::MdnsProbeManagerImpl;
55
56 using MdnsProbeManagerImpl::OnProbeFailure;
57 using MdnsProbeManagerImpl::OnProbeSuccess;
58
CreateProbe(DomainName name,IPAddress address)59 std::unique_ptr<MdnsProbe> CreateProbe(DomainName name,
60 IPAddress address) override {
61 return std::make_unique<StrictMock<MockMdnsProbe>>(std::move(name),
62 std::move(address));
63 }
64
GetOngoingMockProbeByTarget(const DomainName & target)65 StrictMock<MockMdnsProbe>* GetOngoingMockProbeByTarget(
66 const DomainName& target) {
67 const auto it =
68 std::find_if(ongoing_probes_.begin(), ongoing_probes_.end(),
69 [&target](const OngoingProbe& ongoing) {
70 return ongoing.probe->target_name() == target;
71 });
72 if (it != ongoing_probes_.end()) {
73 return static_cast<StrictMock<MockMdnsProbe>*>(it->probe.get());
74 }
75 return nullptr;
76 }
77
GetCompletedMockProbe(const DomainName & target)78 StrictMock<MockMdnsProbe>* GetCompletedMockProbe(const DomainName& target) {
79 const auto it = FindCompletedProbe(target);
80 if (it != completed_probes_.end()) {
81 return static_cast<StrictMock<MockMdnsProbe>*>(it->get());
82 }
83 return nullptr;
84 }
85
HasOngoingProbe(const DomainName & target)86 bool HasOngoingProbe(const DomainName& target) {
87 return GetOngoingMockProbeByTarget(target) != nullptr;
88 }
89
HasCompletedProbe(const DomainName & target)90 bool HasCompletedProbe(const DomainName& target) {
91 return GetCompletedMockProbe(target) != nullptr;
92 }
93
GetOngoingProbeCount()94 size_t GetOngoingProbeCount() { return ongoing_probes_.size(); }
95
GetCompletedProbeCount()96 size_t GetCompletedProbeCount() { return completed_probes_.size(); }
97 };
98
99 class MdnsProbeManagerTests : public testing::Test {
100 public:
MdnsProbeManagerTests()101 MdnsProbeManagerTests()
102 : clock_(Clock::now()),
103 task_runner_(&clock_),
104 socket_(&task_runner_),
105 sender_(&socket_),
106 receiver_(config_),
107 manager_(&sender_,
108 &receiver_,
109 &random_,
110 &task_runner_,
111 FakeClock::now) {
112 ExpectProbeStopped(name_);
113 ExpectProbeStopped(name2_);
114 ExpectProbeStopped(name_retry_);
115 }
116
117 protected:
CreateProbeQueryMessage(DomainName domain,const IPAddress & address)118 MdnsMessage CreateProbeQueryMessage(DomainName domain,
119 const IPAddress& address) {
120 MdnsMessage message(CreateMessageId(), MessageType::Query);
121 MdnsQuestion question(domain, DnsType::kANY, DnsClass::kANY,
122 ResponseType::kUnicast);
123 MdnsRecord record = CreateAddressRecord(std::move(domain), address);
124 message.AddQuestion(std::move(question));
125 message.AddAuthorityRecord(std::move(record));
126 return message;
127 }
128
ExpectProbeStopped(const DomainName & name)129 void ExpectProbeStopped(const DomainName& name) {
130 EXPECT_FALSE(manager_.HasOngoingProbe(name));
131 EXPECT_FALSE(manager_.HasCompletedProbe(name));
132 EXPECT_FALSE(manager_.IsDomainClaimed(name));
133 }
134
ExpectProbeOngoing(const DomainName & name)135 StrictMock<MockMdnsProbe>* ExpectProbeOngoing(const DomainName& name) {
136 // Get around limitations of using an assert in a function with a return
137 // value.
138 auto validate = [this, &name]() {
139 ASSERT_TRUE(manager_.HasOngoingProbe(name));
140 EXPECT_FALSE(manager_.HasCompletedProbe(name));
141 EXPECT_FALSE(manager_.IsDomainClaimed(name));
142 };
143 validate();
144
145 return manager_.GetOngoingMockProbeByTarget(name);
146 }
147
ExpectProbeCompleted(const DomainName & name)148 StrictMock<MockMdnsProbe>* ExpectProbeCompleted(const DomainName& name) {
149 // Get around limitations of using an assert in a function with a return
150 // value.
151 auto validate = [this, &name]() {
152 EXPECT_FALSE(manager_.HasOngoingProbe(name));
153 ASSERT_TRUE(manager_.HasCompletedProbe(name));
154 EXPECT_TRUE(manager_.IsDomainClaimed(name));
155 };
156 validate();
157
158 return manager_.GetCompletedMockProbe(name);
159 }
160
SetUpCompletedProbe(const DomainName & name,const IPAddress & address)161 StrictMock<MockMdnsProbe>* SetUpCompletedProbe(const DomainName& name,
162 const IPAddress& address) {
163 EXPECT_TRUE(manager_.StartProbe(&callback_, name, address).ok());
164 EXPECT_CALL(callback_, OnDomainFound(name, name));
165 StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name);
166 manager_.OnProbeSuccess(ongoing_probe);
167 ExpectProbeCompleted(name);
168 testing::Mock::VerifyAndClearExpectations(ongoing_probe);
169
170 return ongoing_probe;
171 }
172
173 Config config_;
174 FakeClock clock_;
175 FakeTaskRunner task_runner_;
176 FakeUdpSocket socket_;
177 StrictMock<MockMdnsSender> sender_;
178 MdnsReceiver receiver_;
179 MdnsRandom random_;
180 StrictMock<TestMdnsProbeManager> manager_;
181 MockDomainConfirmedProvider callback_;
182
183 const DomainName name_{"test", "_googlecast", "_tcp", "local"};
184 const DomainName name_retry_{"test1", "_googlecast", "_tcp", "local"};
185 const DomainName name2_{"test2", "_googlecast", "_tcp", "local"};
186
187 // When used to create address records A, B, C, A > B because comparison of
188 // the rdata in each results in the comparison of endpoints, for which
189 // address_b_ < address_a_. A < C because A is DnsType kA with value 1 and
190 // C is DnsType kAAAA with value 28.
191 const IPAddress address_a_{192, 168, 0, 0};
192 const IPAddress address_b_{190, 160, 0, 0};
193 const IPAddress address_c_{0x0102, 0x0304, 0x0506, 0x0708,
194 0x090a, 0x0b0c, 0x0d0e, 0x0f10};
195 const IPEndpoint endpoint_{{192, 168, 0, 0}, 80};
196 };
197
// StartProbe() only begins a probe when no probe (ongoing or completed) exists
// for the name.
TEST_F(MdnsProbeManagerTests, StartProbeBeginsProbeWhenNoneExistsOnly) {
  // The first StartProbe() for a name begins a probe without claiming it.
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  ExpectProbeOngoing(name_);
  EXPECT_FALSE(manager_.IsDomainClaimed(name2_));

  // A second StartProbe() while the first is still running is rejected.
  EXPECT_FALSE(manager_.StartProbe(&callback_, name_, address_a_).ok());

  // Completing the probe reports the claimed name through the callback.
  EXPECT_CALL(callback_, OnDomainFound(name_, name_));
  StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
  manager_.OnProbeSuccess(ongoing_probe);
  EXPECT_FALSE(manager_.IsDomainClaimed(name2_));
  // The OnDomainFound() expectation was set on |callback_|, so verify it there.
  testing::Mock::VerifyAndClearExpectations(&callback_);
  testing::Mock::VerifyAndClearExpectations(ongoing_probe);

  // StartProbe() is also rejected once the name has been claimed.
  EXPECT_FALSE(manager_.StartProbe(&callback_, name_, address_a_).ok());

  // The completed probe is the same object that was running.
  StrictMock<MockMdnsProbe>* completed_probe = ExpectProbeCompleted(name_);
  EXPECT_EQ(ongoing_probe, completed_probe);
  EXPECT_FALSE(manager_.IsDomainClaimed(name2_));
}
217
// StopProbe() succeeds only for probes that are still in progress.
TEST_F(MdnsProbeManagerTests, StopProbeChangesOngoingProbesOnly) {
  // Stopping a name that was never probed is an error and changes nothing.
  EXPECT_FALSE(manager_.StopProbe(name_).ok());
  ExpectProbeStopped(name_);

  // An in-flight probe can be stopped, after which the name is fully released.
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  ExpectProbeOngoing(name_);
  EXPECT_TRUE(manager_.StopProbe(name_).ok());
  ExpectProbeStopped(name_);

  // A probe that has already completed cannot be stopped; the claim remains.
  SetUpCompletedProbe(name_, address_a_);
  EXPECT_FALSE(manager_.StopProbe(name_).ok());
  ExpectProbeCompleted(name_);
}
233
// A probe query for a name this host has never probed gets no response.
TEST_F(MdnsProbeManagerTests, RespondToProbeQuerySendsNothingOnUnownedDomain) {
  // |sender_| is a StrictMock with no expectations, so any SendMessage() or
  // SendMulticast() call triggered here fails the test.
  manager_.RespondToProbeQuery(CreateProbeQueryMessage(name_, address_c_),
                               endpoint_);
}
238
// A probe query for a name this host has already claimed is answered with the
// claimed address record, sent by unicast to the querying endpoint.
TEST_F(MdnsProbeManagerTests, RespondToProbeQueryWorksForCompletedProbes) {
  SetUpCompletedProbe(name_, address_a_);

  const MdnsMessage query = CreateProbeQueryMessage(name_, address_c_);
  EXPECT_CALL(sender_, SendMessage(_, endpoint_))
      .WillOnce([this](const MdnsMessage& message,
                       const IPEndpoint& /* endpoint */) -> Error {
        EXPECT_EQ(message.answers().size(), size_t{1});
        // ASSERT_* cannot be used in a lambda returning Error, so stop
        // explicitly rather than index into an empty answer list.
        if (message.answers().empty()) {
          return Error::None();
        }
        EXPECT_EQ(message.answers()[0].dns_type(), DnsType::kA);
        EXPECT_EQ(message.answers()[0].name(), name_);
        return Error::None();
      });
  manager_.RespondToProbeQuery(query, endpoint_);
}
253
// Simultaneous-probe tiebreaking when the incoming query proposes one record.
// Our own probe proposes the record for |address_a_|.
TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForSingleRecordQueries) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* const probe = ExpectProbeOngoing(name_);

  // Identical record: there is nothing to resolve, so the probe is untouched
  // (it is a StrictMock, so any call on it would fail the test).
  MdnsMessage incoming = CreateProbeQueryMessage(name_, address_a_);
  manager_.RespondToProbeQuery(incoming, endpoint_);

  // The incoming record sorts below ours: we win and ignore the query.
  incoming = CreateProbeQueryMessage(name_, address_b_);
  manager_.RespondToProbeQuery(incoming, endpoint_);

  // The incoming record sorts above ours: we lose and postpone our probe.
  incoming = CreateProbeQueryMessage(name_, address_c_);
  EXPECT_CALL(*probe, Postpone(_)).Times(1);
  manager_.RespondToProbeQuery(incoming, endpoint_);
}
274
// Simultaneous-probe tiebreaking when the incoming query proposes several
// records. Our own probe proposes only the record for |address_a_|.
TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForMultiRecordQueries) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* const probe = ExpectProbeOngoing(name_);

  // Builds a probe query for |name_|. Its authority section holds an address
  // record for |first|, followed by one address record per entry in |extra|.
  const auto make_query = [this](const IPAddress& first,
                                 std::initializer_list<IPAddress> extra) {
    MdnsMessage message = CreateProbeQueryMessage(name_, first);
    for (const IPAddress& address : extra) {
      message.AddAuthorityRecord(CreateAddressRecord(name_, address));
    }
    return message;
  };

  // Call the records built from |address_a_|, |address_b_| and |address_c_|
  // A, B and C; they order as B < A < C. The record sets are compared after
  // sorting, so the order in which records appear in a query does not matter.

  // The query's lowest record (B) sorts below ours (A). We win and the query is
  // ignored, however many other records it carries.
  manager_.RespondToProbeQuery(make_query(address_b_, {address_c_}),
                               endpoint_);
  manager_.RespondToProbeQuery(make_query(address_c_, {address_b_}),
                               endpoint_);
  manager_.RespondToProbeQuery(
      make_query(address_a_, {address_b_, address_c_}), endpoint_);

  // The query's lowest record equals ours (A) and it carries more records, so
  // the query wins and our probe is postponed.
  EXPECT_CALL(*probe, Postpone(_)).Times(1);
  manager_.RespondToProbeQuery(make_query(address_a_, {address_c_}),
                               endpoint_);
  testing::Mock::VerifyAndClearExpectations(probe);

  EXPECT_CALL(*probe, Postpone(_)).Times(1);
  manager_.RespondToProbeQuery(make_query(address_c_, {address_a_}),
                               endpoint_);
}
311
// Reporting success for a probe that was already stopped must have no effect.
TEST_F(MdnsProbeManagerTests, ProbeSuccessAfterProbeRemovalNoOp) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
  EXPECT_TRUE(manager_.StopProbe(name_).ok());

  // |callback_| is not a StrictMock, so an unexpected OnDomainFound() would
  // only log a warning. Forbid it explicitly so the no-op is actually checked.
  EXPECT_CALL(callback_, OnDomainFound(_, _)).Times(0);

  // NOTE: |ongoing_probe| may no longer refer to a live object here. The
  // manager is expected to treat it as unknown without dereferencing it.
  manager_.OnProbeSuccess(ongoing_probe);
  ExpectProbeStopped(name_);
}
318
// Reporting failure for a probe that was already stopped must have no effect.
// In particular, it must not start a retry probe or invoke the callback.
TEST_F(MdnsProbeManagerTests, ProbeFailureAfterProbeRemovalNoOp) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
  EXPECT_TRUE(manager_.StopProbe(name_).ok());

  // |callback_| is not a StrictMock, so an unexpected OnDomainFound() would
  // only log a warning. Forbid it explicitly so the no-op is actually checked.
  EXPECT_CALL(callback_, OnDomainFound(_, _)).Times(0);

  // NOTE: |ongoing_probe| may no longer refer to a live object here. The
  // manager is expected to treat it as unknown without dereferencing it.
  manager_.OnProbeFailure(ongoing_probe);

  // No replacement probe (e.g. for |name_retry_|) may have been started.
  ExpectProbeStopped(name_);
  ExpectProbeStopped(name_retry_);
  EXPECT_EQ(manager_.GetOngoingProbeCount(), size_t{0});
}
325
// When a probe fails and the generated retry name is already claimed by this
// host, the manager reports that name immediately instead of probing it again.
TEST_F(MdnsProbeManagerTests, ProbeFailureCallsCallbackWhenAlreadyClaimed) {
  // First claim |name_retry_|. That is the name the manager generates as the
  // retry candidate when a probe for |name_| fails.
  SetUpCompletedProbe(name_retry_, address_a_);

  // Because |name_retry_| is already claimed, the retry logic should skip
  // probing for it and go straight to success.
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* const ongoing_probe = ExpectProbeOngoing(name_);
  EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_));
  manager_.OnProbeFailure(ongoing_probe);
  ExpectProbeStopped(name_);
  ExpectProbeCompleted(name_retry_);
}
342
// When a probe fails and the generated retry name is free, the manager starts
// a new probe for that name and reports it once that probe succeeds.
TEST_F(MdnsProbeManagerTests, ProbeFailureCreatesNewProbeIfNameUnclaimed) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* const first_probe = ExpectProbeOngoing(name_);

  // Failing the original probe retires |name_| and starts a replacement probe
  // for the generated name |name_retry_|.
  manager_.OnProbeFailure(first_probe);
  ExpectProbeStopped(name_);
  StrictMock<MockMdnsProbe>* const retry_probe =
      ExpectProbeOngoing(name_retry_);
  EXPECT_EQ(retry_probe->target_name(), name_retry_);

  // When the replacement succeeds, the callback receives the originally
  // requested name together with the name that was actually claimed.
  EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_));
  manager_.OnProbeSuccess(retry_probe);
  ExpectProbeCompleted(name_retry_);
  ExpectProbeStopped(name_);
}
356
357 } // namespace discovery
358 } // namespace openscreen
359