1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/mdns/mdns_publisher.h"
6 
#include <algorithm>
#include <chrono>
#include <vector>
9 
10 #include "discovery/common/config.h"
11 #include "discovery/mdns/mdns_probe_manager.h"
12 #include "discovery/mdns/mdns_sender.h"
13 #include "discovery/mdns/testing/mdns_test_util.h"
14 #include "platform/test/fake_task_runner.h"
15 #include "platform/test/fake_udp_socket.h"
16 
17 using testing::_;
18 using testing::Invoke;
19 using testing::Return;
20 using testing::StrictMock;
21 
22 namespace openscreen {
23 namespace discovery {
24 namespace {
25 
26 constexpr Clock::duration kAnnounceGoodbyeDelay = std::chrono::milliseconds(25);
27 
ContainsRecord(const std::vector<MdnsRecord::ConstRef> & records,MdnsRecord record)28 bool ContainsRecord(const std::vector<MdnsRecord::ConstRef>& records,
29                     MdnsRecord record) {
30   return std::find_if(records.begin(), records.end(),
31                       [&record](const MdnsRecord& ref) {
32                         return ref == record;
33                       }) != records.end();
34 }
35 
36 }  // namespace
37 
38 class MockMdnsSender : public MdnsSender {
39  public:
MockMdnsSender(UdpSocket * socket)40   explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
41 
42   MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
43   MOCK_METHOD2(SendMessage,
44                Error(const MdnsMessage& message, const IPEndpoint& endpoint));
45 };
46 
47 class MockProbeManager : public MdnsProbeManager {
48  public:
49   MOCK_CONST_METHOD1(IsDomainClaimed, bool(const DomainName&));
50   MOCK_METHOD2(RespondToProbeQuery,
51                void(const MdnsMessage&, const IPEndpoint&));
52 };
53 
54 class MdnsPublisherTesting : public MdnsPublisher {
55  public:
56   using MdnsPublisher::GetPtrRecords;
57   using MdnsPublisher::GetRecords;
58   using MdnsPublisher::MdnsPublisher;
59 
IsNonPtrRecordPresent(const DomainName & name)60   bool IsNonPtrRecordPresent(const DomainName& name) {
61     auto it = records_.find(name);
62     if (it == records_.end()) {
63       return false;
64     }
65 
66     return std::find_if(it->second.begin(), it->second.end(),
67                         [](const RecordAnnouncerPtr& announcer) {
68                           return announcer->record().dns_type() !=
69                                  DnsType::kPTR;
70                         }) != it->second.end();
71   }
72 };
73 
74 class MdnsPublisherTest : public testing::Test {
75  public:
MdnsPublisherTest()76   MdnsPublisherTest()
77       : clock_(Clock::now()),
78         task_runner_(&clock_),
79         socket_(&task_runner_),
80         sender_(&socket_),
81         publisher_(&sender_,
82                    &probe_manager_,
83                    &task_runner_,
84                    FakeClock::now,
85                    config_) {}
86 
~MdnsPublisherTest()87   ~MdnsPublisherTest() {
88     // Clear out any remaining calls in the task runner queue.
89     clock_.Advance(Clock::to_duration(std::chrono::seconds(1)));
90   }
91 
92  protected:
IsAnnounced(const MdnsRecord & original,const MdnsMessage & message)93   Error IsAnnounced(const MdnsRecord& original, const MdnsMessage& message) {
94     EXPECT_EQ(message.type(), MessageType::Response);
95     EXPECT_EQ(message.questions().size(), size_t{0});
96     EXPECT_EQ(message.authority_records().size(), size_t{0});
97     EXPECT_EQ(message.additional_records().size(), size_t{0});
98     EXPECT_EQ(message.answers().size(), size_t{1});
99 
100     const MdnsRecord& sent = message.answers()[0];
101     EXPECT_EQ(original.name(), sent.name());
102     EXPECT_EQ(original.dns_type(), sent.dns_type());
103     EXPECT_EQ(original.dns_class(), sent.dns_class());
104     EXPECT_EQ(original.record_type(), sent.record_type());
105     EXPECT_EQ(original.rdata(), sent.rdata());
106     EXPECT_EQ(original.ttl(), sent.ttl());
107     return Error::None();
108   }
109 
IsGoodbyeRecord(const MdnsRecord & original,const MdnsMessage & message)110   Error IsGoodbyeRecord(const MdnsRecord& original,
111                         const MdnsMessage& message) {
112     EXPECT_EQ(message.type(), MessageType::Response);
113     EXPECT_EQ(message.questions().size(), size_t{0});
114     EXPECT_EQ(message.authority_records().size(), size_t{0});
115     EXPECT_EQ(message.additional_records().size(), size_t{0});
116     EXPECT_EQ(message.answers().size(), size_t{1});
117 
118     const MdnsRecord& sent = message.answers()[0];
119     EXPECT_EQ(original.name(), sent.name());
120     EXPECT_EQ(original.dns_type(), sent.dns_type());
121     EXPECT_EQ(original.dns_class(), sent.dns_class());
122     EXPECT_EQ(original.record_type(), sent.record_type());
123     EXPECT_EQ(original.rdata(), sent.rdata());
124     EXPECT_EQ(std::chrono::seconds(0), sent.ttl());
125     return Error::None();
126   }
127 
CheckPublishedRecords(const DomainName & domain,DnsType type,std::vector<MdnsRecord> expected_records)128   void CheckPublishedRecords(const DomainName& domain,
129                              DnsType type,
130                              std::vector<MdnsRecord> expected_records) {
131     EXPECT_EQ(publisher_.GetRecordCount(), expected_records.size());
132     auto records = publisher_.GetRecords(domain, type, DnsClass::kIN);
133     for (const auto& record : expected_records) {
134       EXPECT_TRUE(ContainsRecord(records, record));
135     }
136   }
137 
TestUniqueRecordRegistrationWorkflow(MdnsRecord record,MdnsRecord record2)138   void TestUniqueRecordRegistrationWorkflow(MdnsRecord record,
139                                             MdnsRecord record2) {
140     EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_))
141         .WillRepeatedly(Return(true));
142     DnsType type = record.dns_type();
143 
144     // Check preconditions.
145     ASSERT_EQ(record.dns_type(), record2.dns_type());
146     auto records = publisher_.GetRecords(domain_, type, DnsClass::kIN);
147     ASSERT_EQ(publisher_.GetRecordCount(), size_t{0});
148     ASSERT_EQ(records.size(), size_t{0});
149     ASSERT_NE(record, record2);
150     ASSERT_TRUE(records.empty());
151 
152     // Register a new record.
153     EXPECT_CALL(sender_, SendMulticast(_))
154         .WillOnce([this, &record](const MdnsMessage& message) -> Error {
155           return IsAnnounced(record, message);
156         });
157     EXPECT_TRUE(publisher_.RegisterRecord(record).ok());
158     clock_.Advance(kAnnounceGoodbyeDelay);
159     testing::Mock::VerifyAndClearExpectations(&sender_);
160     CheckPublishedRecords(domain_, type, {record});
161     EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
162 
163     // Re-register the same record.
164     EXPECT_FALSE(publisher_.RegisterRecord(record).ok());
165     clock_.Advance(kAnnounceGoodbyeDelay);
166     testing::Mock::VerifyAndClearExpectations(&sender_);
167     CheckPublishedRecords(domain_, type, {record});
168     EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
169 
170     // Update a record that doesn't exist
171     EXPECT_FALSE(publisher_.UpdateRegisteredRecord(record2, record).ok());
172     clock_.Advance(kAnnounceGoodbyeDelay);
173     testing::Mock::VerifyAndClearExpectations(&sender_);
174     CheckPublishedRecords(domain_, type, {record});
175     EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
176 
177     // Update an existing record.
178     EXPECT_CALL(sender_, SendMulticast(_))
179         .WillOnce([this, &record2](const MdnsMessage& message) -> Error {
180           return IsAnnounced(record2, message);
181         });
182     EXPECT_TRUE(publisher_.UpdateRegisteredRecord(record, record2).ok());
183     clock_.Advance(kAnnounceGoodbyeDelay);
184     testing::Mock::VerifyAndClearExpectations(&sender_);
185     CheckPublishedRecords(domain_, type, {record2});
186     EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
187 
188     // Add back the original record
189     EXPECT_CALL(sender_, SendMulticast(_))
190         .WillOnce([this, &record](const MdnsMessage& message) -> Error {
191           return IsAnnounced(record, message);
192         });
193     EXPECT_TRUE(publisher_.RegisterRecord(record).ok());
194     clock_.Advance(kAnnounceGoodbyeDelay);
195     testing::Mock::VerifyAndClearExpectations(&sender_);
196     CheckPublishedRecords(domain_, type, {record, record2});
197     EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
198 
199     // Delete an existing record.
200     EXPECT_CALL(sender_, SendMulticast(_))
201         .WillOnce([this, &record2](const MdnsMessage& message) -> Error {
202           return IsGoodbyeRecord(record2, message);
203         });
204     EXPECT_TRUE(publisher_.UnregisterRecord(record2).ok());
205     clock_.Advance(kAnnounceGoodbyeDelay);
206     testing::Mock::VerifyAndClearExpectations(&sender_);
207     CheckPublishedRecords(domain_, type, {record});
208     EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
209 
210     // Delete a non-existing record.
211     EXPECT_FALSE(publisher_.UnregisterRecord(record2).ok());
212     clock_.Advance(kAnnounceGoodbyeDelay);
213     testing::Mock::VerifyAndClearExpectations(&sender_);
214     CheckPublishedRecords(domain_, type, {record});
215     EXPECT_TRUE(publisher_.IsNonPtrRecordPresent(domain_));
216 
217     // Delete the last record
218     EXPECT_CALL(sender_, SendMulticast(_))
219         .WillOnce([this, &record](const MdnsMessage& message) -> Error {
220           return IsGoodbyeRecord(record, message);
221         });
222     EXPECT_TRUE(publisher_.UnregisterRecord(record).ok());
223     clock_.Advance(kAnnounceGoodbyeDelay);
224     testing::Mock::VerifyAndClearExpectations(&sender_);
225     CheckPublishedRecords(domain_, type, {});
226     EXPECT_FALSE(publisher_.IsNonPtrRecordPresent(domain_));
227   }
228 
229   FakeClock clock_;
230   FakeTaskRunner task_runner_;
231   FakeUdpSocket socket_;
232   StrictMock<MockMdnsSender> sender_;
233   StrictMock<MockProbeManager> probe_manager_;
234   Config config_;
235   MdnsPublisherTesting publisher_;
236 
237   DomainName domain_{"instance", "_googlecast", "_tcp", "local"};
238   DomainName ptr_domain_{"_googlecast", "_tcp", "local"};
239 };
240 
TEST_F(MdnsPublisherTest, ARecordRegistrationWorkflow) {
  // Second argument for the alternate record (presumably its TTL), which makes
  // it differ from the default-constructed one.
  constexpr std::chrono::seconds kAlternateTtl{1000};
  const MdnsRecord initial = GetFakeARecord(domain_);
  const MdnsRecord replacement = GetFakeARecord(domain_, kAlternateTtl);
  TestUniqueRecordRegistrationWorkflow(initial, replacement);
}
247 
TEST_F(MdnsPublisherTest, AAAARecordRegistrationWorkflow) {
  // Second argument for the alternate record (presumably its TTL), which makes
  // it differ from the default-constructed one.
  constexpr std::chrono::seconds kAlternateTtl{1000};
  const MdnsRecord initial = GetFakeAAAARecord(domain_);
  const MdnsRecord replacement = GetFakeAAAARecord(domain_, kAlternateTtl);
  TestUniqueRecordRegistrationWorkflow(initial, replacement);
}
254 
TEST_F(MdnsPublisherTest, TXTRecordRegistrationWorkflow) {
  // Second argument for the alternate record (presumably its TTL), which makes
  // it differ from the default-constructed one.
  constexpr std::chrono::seconds kAlternateTtl{1000};
  const MdnsRecord initial = GetFakeTxtRecord(domain_);
  const MdnsRecord replacement = GetFakeTxtRecord(domain_, kAlternateTtl);
  TestUniqueRecordRegistrationWorkflow(initial, replacement);
}
261 
TEST_F(MdnsPublisherTest, SRVRecordRegistrationWorkflow) {
  // Second argument for the alternate record (presumably its TTL), which makes
  // it differ from the default-constructed one.
  constexpr std::chrono::seconds kAlternateTtl{1000};
  const MdnsRecord initial = GetFakeSrvRecord(domain_);
  const MdnsRecord replacement = GetFakeSrvRecord(domain_, kAlternateTtl);
  TestUniqueRecordRegistrationWorkflow(initial, replacement);
}
268 
TEST_F(MdnsPublisherTest,PTRRecordRegistrationWorkflow)269 TEST_F(MdnsPublisherTest, PTRRecordRegistrationWorkflow) {
270   const MdnsRecord record = GetFakePtrRecord(domain_);
271   const MdnsRecord record2 =
272       GetFakePtrRecord(domain_, std::chrono::seconds(1000));
273 
274   EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_))
275       .WillRepeatedly(Return(true));
276   DnsType type = DnsType::kPTR;
277 
278   // Check preconditions.
279   ASSERT_EQ(record.dns_type(), record2.dns_type());
280   ASSERT_EQ(publisher_.GetRecordCount(), size_t{0});
281   auto records = publisher_.GetRecords(domain_, type, DnsClass::kIN);
282   ASSERT_EQ(records.size(), size_t{0});
283   records = publisher_.GetRecords(ptr_domain_, type, DnsClass::kIN);
284   ASSERT_EQ(records.size(), size_t{0});
285   ASSERT_NE(record, record2);
286   ASSERT_TRUE(records.empty());
287   ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{0});
288 
289   // Register a new record.
290   EXPECT_CALL(sender_, SendMulticast(_))
291       .WillOnce([this, &record](const MdnsMessage& message) -> Error {
292         return IsAnnounced(record, message);
293       });
294   EXPECT_TRUE(publisher_.RegisterRecord(record).ok());
295   clock_.Advance(kAnnounceGoodbyeDelay);
296   testing::Mock::VerifyAndClearExpectations(&sender_);
297   CheckPublishedRecords(ptr_domain_, type, {record});
298   ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1});
299 
300   // Re-register the same record.
301   EXPECT_FALSE(publisher_.RegisterRecord(record).ok());
302   clock_.Advance(kAnnounceGoodbyeDelay);
303   testing::Mock::VerifyAndClearExpectations(&sender_);
304   CheckPublishedRecords(ptr_domain_, type, {record});
305   ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1});
306 
307   // Register a second record.
308   EXPECT_CALL(sender_, SendMulticast(_))
309       .WillOnce([this, &record2](const MdnsMessage& message) -> Error {
310         return IsAnnounced(record2, message);
311       });
312   EXPECT_TRUE(publisher_.RegisterRecord(record2).ok());
313   clock_.Advance(kAnnounceGoodbyeDelay);
314   testing::Mock::VerifyAndClearExpectations(&sender_);
315   CheckPublishedRecords(ptr_domain_, type, {record, record2});
316   ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{2});
317 
318   // Delete an existing record.
319   EXPECT_CALL(sender_, SendMulticast(_))
320       .WillOnce([this, &record2](const MdnsMessage& message) -> Error {
321         return IsGoodbyeRecord(record2, message);
322       });
323   EXPECT_TRUE(publisher_.UnregisterRecord(record2).ok());
324   clock_.Advance(kAnnounceGoodbyeDelay);
325   testing::Mock::VerifyAndClearExpectations(&sender_);
326   CheckPublishedRecords(ptr_domain_, type, {record});
327   ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1});
328 
329   // Delete a non-existing record.
330   EXPECT_FALSE(publisher_.UnregisterRecord(record2).ok());
331   clock_.Advance(kAnnounceGoodbyeDelay);
332   testing::Mock::VerifyAndClearExpectations(&sender_);
333   CheckPublishedRecords(ptr_domain_, type, {record});
334   ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{1});
335 
336   // Delete the last record
337   EXPECT_CALL(sender_, SendMulticast(_))
338       .WillOnce([this, &record](const MdnsMessage& message) -> Error {
339         return IsGoodbyeRecord(record, message);
340       });
341   EXPECT_TRUE(publisher_.UnregisterRecord(record).ok());
342   clock_.Advance(kAnnounceGoodbyeDelay);
343   testing::Mock::VerifyAndClearExpectations(&sender_);
344   CheckPublishedRecords(ptr_domain_, type, {});
345   ASSERT_EQ(publisher_.GetPtrRecords(DnsClass::kANY).size(), size_t{0});
346 }
347 
TEST_F(MdnsPublisherTest, RegisteringUnownedRecordsFail) {
  // The probe manager reports |domain_| as unclaimed, so registering any
  // record type for it must fail.
  EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_))
      .WillRepeatedly(Return(false));

  const std::vector<MdnsRecord> unowned_records = {
      GetFakePtrRecord(domain_), GetFakeSrvRecord(domain_),
      GetFakeTxtRecord(domain_), GetFakeARecord(domain_),
      GetFakeAAAARecord(domain_)};
  for (const MdnsRecord& unowned : unowned_records) {
    EXPECT_FALSE(publisher_.RegisterRecord(unowned).ok());
  }
}
357 
TEST_F(MdnsPublisherTest,RegistrationAnnouncesEightTimes)358 TEST_F(MdnsPublisherTest, RegistrationAnnouncesEightTimes) {
359   EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_))
360       .WillRepeatedly(Return(true));
361   constexpr Clock::duration kOneSecond =
362       Clock::to_duration(std::chrono::seconds(1));
363 
364   // First announce, at registration.
365   const MdnsRecord record = GetFakeARecord(domain_);
366   EXPECT_CALL(sender_, SendMulticast(_))
367       .WillOnce([this, &record](const MdnsMessage& message) -> Error {
368         return IsAnnounced(record, message);
369       });
370   EXPECT_TRUE(publisher_.RegisterRecord(record).ok());
371   clock_.Advance(kAnnounceGoodbyeDelay);
372 
373   // Second announce, at 2 seconds.
374   testing::Mock::VerifyAndClearExpectations(&sender_);
375   EXPECT_CALL(sender_, SendMulticast(_))
376       .WillOnce([this, &record](const MdnsMessage& message) -> Error {
377         return IsAnnounced(record, message);
378       });
379   clock_.Advance(kOneSecond);
380   clock_.Advance(kAnnounceGoodbyeDelay);
381   testing::Mock::VerifyAndClearExpectations(&sender_);
382 
383   // Third announce, at 4 seconds.
384   clock_.Advance(kOneSecond);
385   clock_.Advance(kAnnounceGoodbyeDelay);
386   EXPECT_CALL(sender_, SendMulticast(_))
387       .WillOnce([this, &record](const MdnsMessage& message) -> Error {
388         return IsAnnounced(record, message);
389       });
390   clock_.Advance(kOneSecond);
391   clock_.Advance(kAnnounceGoodbyeDelay);
392   testing::Mock::VerifyAndClearExpectations(&sender_);
393 
394   // Fourth announce, at 8 seconds.
395   clock_.Advance(kOneSecond * 3);
396   clock_.Advance(kAnnounceGoodbyeDelay);
397   EXPECT_CALL(sender_, SendMulticast(_))
398       .WillOnce([this, &record](const MdnsMessage& message) -> Error {
399         return IsAnnounced(record, message);
400       });
401   clock_.Advance(kOneSecond);
402   clock_.Advance(kAnnounceGoodbyeDelay);
403   testing::Mock::VerifyAndClearExpectations(&sender_);
404 
405   // Fifth announce, at 16 seconds.
406   clock_.Advance(kOneSecond * 7);
407   clock_.Advance(kAnnounceGoodbyeDelay);
408   EXPECT_CALL(sender_, SendMulticast(_))
409       .WillOnce([this, &record](const MdnsMessage& message) -> Error {
410         return IsAnnounced(record, message);
411       });
412   clock_.Advance(kOneSecond);
413   clock_.Advance(kAnnounceGoodbyeDelay);
414   testing::Mock::VerifyAndClearExpectations(&sender_);
415 
416   // Sixth announce, at 32 seconds.
417   clock_.Advance(kOneSecond * 15);
418   clock_.Advance(kAnnounceGoodbyeDelay);
419   EXPECT_CALL(sender_, SendMulticast(_))
420       .WillOnce([this, &record](const MdnsMessage& message) -> Error {
421         return IsAnnounced(record, message);
422       });
423   clock_.Advance(kOneSecond);
424   clock_.Advance(kAnnounceGoodbyeDelay);
425   testing::Mock::VerifyAndClearExpectations(&sender_);
426 
427   // Seventh announce, at 64 seconds.
428   clock_.Advance(kOneSecond * 31);
429   clock_.Advance(kAnnounceGoodbyeDelay);
430   EXPECT_CALL(sender_, SendMulticast(_))
431       .WillOnce([this, &record](const MdnsMessage& message) -> Error {
432         return IsAnnounced(record, message);
433       });
434   clock_.Advance(kOneSecond);
435   clock_.Advance(kAnnounceGoodbyeDelay);
436   testing::Mock::VerifyAndClearExpectations(&sender_);
437 
438   // Eighth announce, at 128 seconds.
439   clock_.Advance(kOneSecond * 63);
440   clock_.Advance(kAnnounceGoodbyeDelay);
441   EXPECT_CALL(sender_, SendMulticast(_))
442       .WillOnce([this, &record](const MdnsMessage& message) -> Error {
443         return IsAnnounced(record, message);
444       });
445   clock_.Advance(kOneSecond);
446   clock_.Advance(kAnnounceGoodbyeDelay);
447   testing::Mock::VerifyAndClearExpectations(&sender_);
448 
449   // No more announcements
450   clock_.Advance(kOneSecond * 1024);
451   clock_.Advance(kAnnounceGoodbyeDelay);
452   testing::Mock::VerifyAndClearExpectations(&sender_);
453 
454   // Sends goodbye message when removed.
455   EXPECT_CALL(sender_, SendMulticast(_))
456       .WillOnce([this, &record](const MdnsMessage& message) -> Error {
457         return IsGoodbyeRecord(record, message);
458       });
459   EXPECT_TRUE(publisher_.UnregisterRecord(record).ok());
460   clock_.Advance(kAnnounceGoodbyeDelay);
461 }
462 
463 }  // namespace discovery
464 }  // namespace openscreen
465