1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "discovery/mdns/mdns_probe_manager.h"

#include <algorithm>
#include <chrono>
#include <memory>
#include <utility>

#include "discovery/common/config.h"
#include "discovery/mdns/mdns_probe.h"
#include "discovery/mdns/mdns_querier.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_sender.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "platform/test/fake_clock.h"
#include "platform/test/fake_task_runner.h"
#include "platform/test/fake_udp_socket.h"
20 
21 using testing::_;
22 using testing::Invoke;
23 using testing::Return;
24 using testing::StrictMock;
25 
26 namespace openscreen {
27 namespace discovery {
28 
29 class MockDomainConfirmedProvider : public MdnsDomainConfirmedProvider {
30  public:
31   MOCK_METHOD2(OnDomainFound, void(const DomainName&, const DomainName&));
32 };
33 
34 class MockMdnsSender : public MdnsSender {
35  public:
MockMdnsSender(UdpSocket * socket)36   explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
37 
38   MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
39   MOCK_METHOD2(SendMessage,
40                Error(const MdnsMessage& message, const IPEndpoint& endpoint));
41 };
42 
43 class MockMdnsProbe : public MdnsProbe {
44  public:
MockMdnsProbe(DomainName target_name,IPAddress address)45   MockMdnsProbe(DomainName target_name, IPAddress address)
46       : MdnsProbe(std::move(target_name), std::move(address)) {}
47 
48   MOCK_METHOD1(Postpone, void(std::chrono::seconds));
49   MOCK_METHOD1(OnMessageReceived, void(const MdnsMessage&));
50 };
51 
52 class TestMdnsProbeManager : public MdnsProbeManagerImpl {
53  public:
54   using MdnsProbeManagerImpl::MdnsProbeManagerImpl;
55 
56   using MdnsProbeManagerImpl::OnProbeFailure;
57   using MdnsProbeManagerImpl::OnProbeSuccess;
58 
CreateProbe(DomainName name,IPAddress address)59   std::unique_ptr<MdnsProbe> CreateProbe(DomainName name,
60                                          IPAddress address) override {
61     return std::make_unique<StrictMock<MockMdnsProbe>>(std::move(name),
62                                                        std::move(address));
63   }
64 
GetOngoingMockProbeByTarget(const DomainName & target)65   StrictMock<MockMdnsProbe>* GetOngoingMockProbeByTarget(
66       const DomainName& target) {
67     const auto it =
68         std::find_if(ongoing_probes_.begin(), ongoing_probes_.end(),
69                      [&target](const OngoingProbe& ongoing) {
70                        return ongoing.probe->target_name() == target;
71                      });
72     if (it != ongoing_probes_.end()) {
73       return static_cast<StrictMock<MockMdnsProbe>*>(it->probe.get());
74     }
75     return nullptr;
76   }
77 
GetCompletedMockProbe(const DomainName & target)78   StrictMock<MockMdnsProbe>* GetCompletedMockProbe(const DomainName& target) {
79     const auto it = FindCompletedProbe(target);
80     if (it != completed_probes_.end()) {
81       return static_cast<StrictMock<MockMdnsProbe>*>(it->get());
82     }
83     return nullptr;
84   }
85 
HasOngoingProbe(const DomainName & target)86   bool HasOngoingProbe(const DomainName& target) {
87     return GetOngoingMockProbeByTarget(target) != nullptr;
88   }
89 
HasCompletedProbe(const DomainName & target)90   bool HasCompletedProbe(const DomainName& target) {
91     return GetCompletedMockProbe(target) != nullptr;
92   }
93 
GetOngoingProbeCount()94   size_t GetOngoingProbeCount() { return ongoing_probes_.size(); }
95 
GetCompletedProbeCount()96   size_t GetCompletedProbeCount() { return completed_probes_.size(); }
97 };
98 
99 class MdnsProbeManagerTests : public testing::Test {
100  public:
MdnsProbeManagerTests()101   MdnsProbeManagerTests()
102       : clock_(Clock::now()),
103         task_runner_(&clock_),
104         socket_(&task_runner_),
105         sender_(&socket_),
106         receiver_(config_),
107         manager_(&sender_,
108                  &receiver_,
109                  &random_,
110                  &task_runner_,
111                  FakeClock::now) {
112     ExpectProbeStopped(name_);
113     ExpectProbeStopped(name2_);
114     ExpectProbeStopped(name_retry_);
115   }
116 
117  protected:
CreateProbeQueryMessage(DomainName domain,const IPAddress & address)118   MdnsMessage CreateProbeQueryMessage(DomainName domain,
119                                       const IPAddress& address) {
120     MdnsMessage message(CreateMessageId(), MessageType::Query);
121     MdnsQuestion question(domain, DnsType::kANY, DnsClass::kANY,
122                           ResponseType::kUnicast);
123     MdnsRecord record = CreateAddressRecord(std::move(domain), address);
124     message.AddQuestion(std::move(question));
125     message.AddAuthorityRecord(std::move(record));
126     return message;
127   }
128 
ExpectProbeStopped(const DomainName & name)129   void ExpectProbeStopped(const DomainName& name) {
130     EXPECT_FALSE(manager_.HasOngoingProbe(name));
131     EXPECT_FALSE(manager_.HasCompletedProbe(name));
132     EXPECT_FALSE(manager_.IsDomainClaimed(name));
133   }
134 
ExpectProbeOngoing(const DomainName & name)135   StrictMock<MockMdnsProbe>* ExpectProbeOngoing(const DomainName& name) {
136     // Get around limitations of using an assert in a function with a return
137     // value.
138     auto validate = [this, &name]() {
139       ASSERT_TRUE(manager_.HasOngoingProbe(name));
140       EXPECT_FALSE(manager_.HasCompletedProbe(name));
141       EXPECT_FALSE(manager_.IsDomainClaimed(name));
142     };
143     validate();
144 
145     return manager_.GetOngoingMockProbeByTarget(name);
146   }
147 
ExpectProbeCompleted(const DomainName & name)148   StrictMock<MockMdnsProbe>* ExpectProbeCompleted(const DomainName& name) {
149     // Get around limitations of using an assert in a function with a return
150     // value.
151     auto validate = [this, &name]() {
152       EXPECT_FALSE(manager_.HasOngoingProbe(name));
153       ASSERT_TRUE(manager_.HasCompletedProbe(name));
154       EXPECT_TRUE(manager_.IsDomainClaimed(name));
155     };
156     validate();
157 
158     return manager_.GetCompletedMockProbe(name);
159   }
160 
SetUpCompletedProbe(const DomainName & name,const IPAddress & address)161   StrictMock<MockMdnsProbe>* SetUpCompletedProbe(const DomainName& name,
162                                                  const IPAddress& address) {
163     EXPECT_TRUE(manager_.StartProbe(&callback_, name, address).ok());
164     EXPECT_CALL(callback_, OnDomainFound(name, name));
165     StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name);
166     manager_.OnProbeSuccess(ongoing_probe);
167     ExpectProbeCompleted(name);
168     testing::Mock::VerifyAndClearExpectations(ongoing_probe);
169 
170     return ongoing_probe;
171   }
172 
173   Config config_;
174   FakeClock clock_;
175   FakeTaskRunner task_runner_;
176   FakeUdpSocket socket_;
177   StrictMock<MockMdnsSender> sender_;
178   MdnsReceiver receiver_;
179   MdnsRandom random_;
180   StrictMock<TestMdnsProbeManager> manager_;
181   MockDomainConfirmedProvider callback_;
182 
183   const DomainName name_{"test", "_googlecast", "_tcp", "local"};
184   const DomainName name_retry_{"test1", "_googlecast", "_tcp", "local"};
185   const DomainName name2_{"test2", "_googlecast", "_tcp", "local"};
186 
187   // When used to create address records A, B, C, A > B because comparison of
188   // the rdata in each results in the comparison of endpoints, for which
189   // address_b_ < address_a_. A < C because A is DnsType kA with value 1 and
190   // C is DnsType kAAAA with value 28.
191   const IPAddress address_a_{192, 168, 0, 0};
192   const IPAddress address_b_{190, 160, 0, 0};
193   const IPAddress address_c_{0x0102, 0x0304, 0x0506, 0x0708,
194                              0x090a, 0x0b0c, 0x0d0e, 0x0f10};
195   const IPEndpoint endpoint_{{192, 168, 0, 0}, 80};
196 };
197 
TEST_F(MdnsProbeManagerTests, StartProbeBeginsProbeWhenNoneExistsOnly) {
  // The first request for |name_| starts a probe; other names stay unclaimed.
  const auto first_start = manager_.StartProbe(&callback_, name_, address_a_);
  EXPECT_TRUE(first_start.ok());
  ExpectProbeOngoing(name_);
  EXPECT_FALSE(manager_.IsDomainClaimed(name2_));

  // Asking again while that probe is still running is rejected.
  const auto restart_while_running =
      manager_.StartProbe(&callback_, name_, address_a_);
  EXPECT_FALSE(restart_while_running.ok());

  // Succeeding the probe confirms |name_| to the caller unchanged.
  EXPECT_CALL(callback_, OnDomainFound(name_, name_));
  StrictMock<MockMdnsProbe>* const probe = ExpectProbeOngoing(name_);
  manager_.OnProbeSuccess(probe);
  EXPECT_FALSE(manager_.IsDomainClaimed(name2_));
  testing::Mock::VerifyAndClearExpectations(probe);

  // Once claimed, the name still cannot be probed again.
  const auto restart_after_claim =
      manager_.StartProbe(&callback_, name_, address_a_);
  EXPECT_FALSE(restart_after_claim.ok());

  // The probe that succeeded is the one now held as completed.
  StrictMock<MockMdnsProbe>* const claimed = ExpectProbeCompleted(name_);
  EXPECT_EQ(probe, claimed);
  EXPECT_FALSE(manager_.IsDomainClaimed(name2_));
}
217 
TEST_F(MdnsProbeManagerTests, StopProbeChangesOngoingProbesOnly) {
  // Stopping fails when no probe was ever started.
  const auto stop_unknown = manager_.StopProbe(name_);
  EXPECT_FALSE(stop_unknown.ok());
  ExpectProbeStopped(name_);

  // A running probe can be stopped, which forgets it entirely.
  const auto start = manager_.StartProbe(&callback_, name_, address_a_);
  EXPECT_TRUE(start.ok());
  ExpectProbeOngoing(name_);

  const auto stop_running = manager_.StopProbe(name_);
  EXPECT_TRUE(stop_running.ok());
  ExpectProbeStopped(name_);

  // A probe that already completed cannot be stopped and stays claimed.
  SetUpCompletedProbe(name_, address_a_);

  const auto stop_completed = manager_.StopProbe(name_);
  EXPECT_FALSE(stop_completed.ok());
  ExpectProbeCompleted(name_);
}
233 
TEST_F(MdnsProbeManagerTests, RespondToProbeQuerySendsNothingOnUnownedDomain) {
  // |sender_| is a StrictMock with no expectations, so any reply sent for a
  // name this host does not own fails the test.
  manager_.RespondToProbeQuery(CreateProbeQueryMessage(name_, address_c_),
                               endpoint_);
}
238 
TEST_F(MdnsProbeManagerTests, RespondToProbeQueryWorksForCompletedProbes) {
  SetUpCompletedProbe(name_, address_a_);

  // A probe query for a name this host has claimed gets a reply sent directly
  // to the querying endpoint, carrying exactly one A record for that name.
  const MdnsMessage incoming = CreateProbeQueryMessage(name_, address_c_);
  EXPECT_CALL(sender_, SendMessage(_, endpoint_))
      .WillOnce([this](const MdnsMessage& reply,
                       const IPEndpoint& destination) -> Error {
        const auto& answers = reply.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_EQ(answers[0].dns_type(), DnsType::kA);
        EXPECT_EQ(answers[0].name(), name_);
        return Error::None();
      });
  manager_.RespondToProbeQuery(incoming, endpoint_);
}
253 
TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForSingleRecordQueries) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* const probe = ExpectProbeOngoing(name_);

  // Our probe carries record A. The records order as B < A < C.
  const MdnsMessage same_as_ours = CreateProbeQueryMessage(name_, address_a_);
  const MdnsMessage lower_than_ours =
      CreateProbeQueryMessage(name_, address_b_);
  const MdnsMessage higher_than_ours =
      CreateProbeQueryMessage(name_, address_c_);

  // An identical probe leaves ours untouched.
  manager_.RespondToProbeQuery(same_as_ours, endpoint_);

  // A lower competing probe loses the tiebreak and is ignored.
  manager_.RespondToProbeQuery(lower_than_ours, endpoint_);

  // A higher competing probe wins, so our probe is postponed exactly once.
  EXPECT_CALL(*probe, Postpone(_)).Times(1);
  manager_.RespondToProbeQuery(higher_than_ours, endpoint_);
}
274 
TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForMultiRecordQueries) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* const probe = ExpectProbeOngoing(name_);

  // Builds a probe query for |name_| whose authority section holds an address
  // record for |first| followed by one for |second|.
  const auto two_record_query = [this](const IPAddress& first,
                                       const IPAddress& second) {
    MdnsMessage query = CreateProbeQueryMessage(name_, first);
    query.AddAuthorityRecord(CreateAddressRecord(name_, second));
    return query;
  };

  // Records built from |address_a_|, |address_b_| and |address_c_| (A, B, C)
  // order as B < A < C, and our probe carries only A. The incoming records
  // are sorted and compared against ours starting from the lowest.
  //
  // Whenever the incoming set contains B, its lowest record is below our A, so
  // the competing probe loses and is ignored, regardless of the order in which
  // it listed the records.
  manager_.RespondToProbeQuery(two_record_query(address_b_, address_c_),
                               endpoint_);
  manager_.RespondToProbeQuery(two_record_query(address_c_, address_b_),
                               endpoint_);

  MdnsMessage three_records = two_record_query(address_a_, address_b_);
  three_records.AddAuthorityRecord(CreateAddressRecord(name_, address_c_));
  manager_.RespondToProbeQuery(three_records, endpoint_);

  // If the incoming lowest record equals ours and the query carries more
  // records than we do, the query wins and our probe is postponed.
  EXPECT_CALL(*probe, Postpone(_)).Times(1);
  manager_.RespondToProbeQuery(two_record_query(address_a_, address_c_),
                               endpoint_);
  testing::Mock::VerifyAndClearExpectations(probe);

  EXPECT_CALL(*probe, Postpone(_)).Times(1);
  manager_.RespondToProbeQuery(two_record_query(address_c_, address_a_),
                               endpoint_);
}
311 
TEST_F(MdnsProbeManagerTests, ProbeSuccessAfterProbeRemovalNoOp) {
  const auto started = manager_.StartProbe(&callback_, name_, address_a_);
  EXPECT_TRUE(started.ok());
  StrictMock<MockMdnsProbe>* const stale_probe = ExpectProbeOngoing(name_);

  const auto stopped = manager_.StopProbe(name_);
  EXPECT_TRUE(stopped.ok());

  // A success reported by a probe the manager no longer tracks is ignored; no
  // callback is expected.
  manager_.OnProbeSuccess(stale_probe);
}
318 
TEST_F(MdnsProbeManagerTests, ProbeFailureAfterProbeRemovalNoOp) {
  const auto started = manager_.StartProbe(&callback_, name_, address_a_);
  EXPECT_TRUE(started.ok());
  StrictMock<MockMdnsProbe>* const stale_probe = ExpectProbeOngoing(name_);

  const auto stopped = manager_.StopProbe(name_);
  EXPECT_TRUE(stopped.ok());

  // A failure reported by a probe the manager no longer tracks is ignored; no
  // callback is expected.
  manager_.OnProbeFailure(stale_probe);
}
325 
TEST_F(MdnsProbeManagerTests, ProbeFailureCallsCallbackWhenAlreadyClaimed) {
  // Claim |name_retry_| up front. When the probe for |name_| later fails, the
  // next candidate name the manager generates is |name_retry_|, which this
  // host already owns.
  SetUpCompletedProbe(name_retry_, address_a_);

  // Because |name_retry_| is already claimed, the retry logic should skip
  // re-probing it and report it to the callback immediately.
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* const ongoing_probe = ExpectProbeOngoing(name_);
  EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_));
  manager_.OnProbeFailure(ongoing_probe);
  ExpectProbeStopped(name_);
  ExpectProbeCompleted(name_retry_);
}
342 
TEST_F(MdnsProbeManagerTests, ProbeFailureCreatesNewProbeIfNameUnclaimed) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* const first_probe = ExpectProbeOngoing(name_);

  // Failing the original probe drops |name_| and starts probing the next
  // candidate name, |name_retry_|, instead.
  manager_.OnProbeFailure(first_probe);
  ExpectProbeStopped(name_);
  StrictMock<MockMdnsProbe>* const retry_probe =
      ExpectProbeOngoing(name_retry_);
  EXPECT_EQ(retry_probe->target_name(), name_retry_);

  // When the retry succeeds, the callback receives both the requested name and
  // the name that was actually claimed.
  EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_));
  manager_.OnProbeSuccess(retry_probe);
  ExpectProbeCompleted(name_retry_);
  ExpectProbeStopped(name_);
}
356 
357 }  // namespace discovery
358 }  // namespace openscreen
359