1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/mdns/mdns_records.h"
6 
7 #include <algorithm>
8 #include <cctype>
9 #include <limits>
10 #include <sstream>
11 #include <vector>
12 
13 #include "absl/strings/ascii.h"
14 #include "absl/strings/match.h"
15 #include "absl/strings/str_join.h"
16 #include "discovery/mdns/mdns_writer.h"
17 
18 namespace openscreen {
19 namespace discovery {
20 
21 namespace {
22 
23 constexpr size_t kMaxRawRecordSize = std::numeric_limits<uint16_t>::max();
24 
25 constexpr size_t kMaxMessageFieldEntryCount =
26     std::numeric_limits<uint16_t>::max();
27 
CompareIgnoreCase(const std::string & x,const std::string & y)28 inline int CompareIgnoreCase(const std::string& x, const std::string& y) {
29   size_t i = 0;
30   for (; i < x.size(); i++) {
31     if (i == y.size()) {
32       return 1;
33     }
34     const char& x_char = std::tolower(x[i]);
35     const char& y_char = std::tolower(y[i]);
36     if (x_char < y_char) {
37       return -1;
38     } else if (y_char < x_char) {
39       return 1;
40     }
41   }
42   return i == y.size() ? 0 : -1;
43 }
44 
45 template <typename RDataType>
IsGreaterThan(const Rdata & lhs,const Rdata & rhs)46 bool IsGreaterThan(const Rdata& lhs, const Rdata& rhs) {
47   const RDataType& lhs_cast = absl::get<RDataType>(lhs);
48   const RDataType& rhs_cast = absl::get<RDataType>(rhs);
49 
50   // The Extra 2 in length is from the record size that Write() prepends to the
51   // result.
52   const size_t lhs_size = lhs_cast.MaxWireSize() + 2;
53   const size_t rhs_size = rhs_cast.MaxWireSize() + 2;
54 
55   uint8_t lhs_bytes[lhs_size];
56   uint8_t rhs_bytes[rhs_size];
57   MdnsWriter lhs_writer(lhs_bytes, lhs_size);
58   MdnsWriter rhs_writer(rhs_bytes, rhs_size);
59 
60   const bool lhs_write = lhs_writer.Write(lhs_cast);
61   const bool rhs_write = rhs_writer.Write(rhs_cast);
62   OSP_DCHECK(lhs_write);
63   OSP_DCHECK(rhs_write);
64 
65   // Skip the size bits.
66   const size_t min_size = std::min(lhs_writer.offset(), rhs_writer.offset());
67   for (size_t i = 2; i < min_size; i++) {
68     if (lhs_bytes[i] != rhs_bytes[i]) {
69       return lhs_bytes[i] > rhs_bytes[i];
70     }
71   }
72 
73   return lhs_size > rhs_size;
74 }
75 
IsGreaterThan(DnsType type,const Rdata & lhs,const Rdata & rhs)76 bool IsGreaterThan(DnsType type, const Rdata& lhs, const Rdata& rhs) {
77   switch (type) {
78     case DnsType::kA:
79       return IsGreaterThan<ARecordRdata>(lhs, rhs);
80     case DnsType::kPTR:
81       return IsGreaterThan<PtrRecordRdata>(lhs, rhs);
82     case DnsType::kTXT:
83       return IsGreaterThan<TxtRecordRdata>(lhs, rhs);
84     case DnsType::kAAAA:
85       return IsGreaterThan<AAAARecordRdata>(lhs, rhs);
86     case DnsType::kSRV:
87       return IsGreaterThan<SrvRecordRdata>(lhs, rhs);
88     case DnsType::kNSEC:
89       return IsGreaterThan<NsecRecordRdata>(lhs, rhs);
90     default:
91       return IsGreaterThan<RawRecordRdata>(lhs, rhs);
92   }
93 }
94 
95 }  // namespace
96 
IsValidDomainLabel(absl::string_view label)97 bool IsValidDomainLabel(absl::string_view label) {
98   const size_t label_size = label.size();
99   return label_size > 0 && label_size <= kMaxLabelLength;
100 }
101 
// Default constructor; all members take their default/in-class initial values.
DomainName::DomainName() = default;

// The next three constructors all delegate to the iterator-range constructor
// declared outside this file (presumably in mdns_records.h), which is where
// any label validation and max wire size computation happen.
// NOTE(review): |labels| is taken by value, but its elements are passed
// through ordinary (non-move) iterators, so each label is copied.
DomainName::DomainName(std::vector<std::string> labels)
    : DomainName(labels.begin(), labels.end()) {}

DomainName::DomainName(const std::vector<absl::string_view>& labels)
    : DomainName(labels.begin(), labels.end()) {}

DomainName::DomainName(std::initializer_list<absl::string_view> labels)
    : DomainName(labels.begin(), labels.end()) {}

// Stores |labels| and |max_wire_size| exactly as given. No validation is
// performed here, so the caller must supply a consistent |max_wire_size|.
DomainName::DomainName(std::vector<std::string> labels, size_t max_wire_size)
    : max_wire_size_(max_wire_size), labels_(std::move(labels)) {}

// Member-wise copy/move.
DomainName::DomainName(const DomainName& other) = default;

DomainName::DomainName(DomainName&& other) noexcept = default;

DomainName& DomainName::operator=(const DomainName& rhs) = default;

DomainName& DomainName::operator=(DomainName&& rhs) = default;
123 
ToString() const124 std::string DomainName::ToString() const {
125   return absl::StrJoin(labels_, ".");
126 }
127 
operator <(const DomainName & rhs) const128 bool DomainName::operator<(const DomainName& rhs) const {
129   size_t i = 0;
130   for (; i < labels_.size(); i++) {
131     if (i == rhs.labels_.size()) {
132       return false;
133     } else {
134       int result = CompareIgnoreCase(labels_[i], rhs.labels_[i]);
135       if (result < 0) {
136         return true;
137       } else if (result > 0) {
138         return false;
139       }
140     }
141   }
142   return i < rhs.labels_.size();
143 }
144 
operator <=(const DomainName & rhs) const145 bool DomainName::operator<=(const DomainName& rhs) const {
146   return (*this < rhs) || (*this == rhs);
147 }
148 
operator >(const DomainName & rhs) const149 bool DomainName::operator>(const DomainName& rhs) const {
150   return !(*this < rhs) && !(*this == rhs);
151 }
152 
operator >=(const DomainName & rhs) const153 bool DomainName::operator>=(const DomainName& rhs) const {
154   return !(*this < rhs);
155 }
156 
operator ==(const DomainName & rhs) const157 bool DomainName::operator==(const DomainName& rhs) const {
158   if (labels_.size() != rhs.labels_.size()) {
159     return false;
160   }
161   for (size_t i = 0; i < labels_.size(); i++) {
162     if (CompareIgnoreCase(labels_[i], rhs.labels_[i]) != 0) {
163       return false;
164     }
165   }
166   return true;
167 }
168 
operator !=(const DomainName & rhs) const169 bool DomainName::operator!=(const DomainName& rhs) const {
170   return !(*this == rhs);
171 }
172 
MaxWireSize() const173 size_t DomainName::MaxWireSize() const {
174   return max_wire_size_;
175 }
176 
177 // static
TryCreate(std::vector<uint8_t> rdata)178 ErrorOr<RawRecordRdata> RawRecordRdata::TryCreate(std::vector<uint8_t> rdata) {
179   if (rdata.size() > kMaxRawRecordSize) {
180     return Error::Code::kIndexOutOfBounds;
181   } else {
182     return RawRecordRdata(std::move(rdata));
183   }
184 }
185 
// Default constructor; members take their default/in-class initial values.
RawRecordRdata::RawRecordRdata() = default;

// Takes ownership of |rdata|. The size limit is only asserted with
// OSP_DCHECK here; use TryCreate() for a checked construction path that
// reports oversized input as an error.
RawRecordRdata::RawRecordRdata(std::vector<uint8_t> rdata)
    : rdata_(std::move(rdata)) {
  // Ensure RDATA length does not exceed the maximum allowed.
  OSP_DCHECK(rdata_.size() <= kMaxRawRecordSize);
}

// Copies |size| bytes starting at |begin| and delegates to the vector
// constructor above.
RawRecordRdata::RawRecordRdata(const uint8_t* begin, size_t size)
    : RawRecordRdata(std::vector<uint8_t>(begin, begin + size)) {}

// Member-wise copy/move.
RawRecordRdata::RawRecordRdata(const RawRecordRdata& other) = default;

RawRecordRdata::RawRecordRdata(RawRecordRdata&& other) noexcept = default;

RawRecordRdata& RawRecordRdata::operator=(const RawRecordRdata& rhs) = default;

RawRecordRdata& RawRecordRdata::operator=(RawRecordRdata&& rhs) = default;
204 
operator ==(const RawRecordRdata & rhs) const205 bool RawRecordRdata::operator==(const RawRecordRdata& rhs) const {
206   return rdata_ == rhs.rdata_;
207 }
208 
operator !=(const RawRecordRdata & rhs) const209 bool RawRecordRdata::operator!=(const RawRecordRdata& rhs) const {
210   return !(*this == rhs);
211 }
212 
MaxWireSize() const213 size_t RawRecordRdata::MaxWireSize() const {
214   // max_wire_size includes uint16_t record length field.
215   return sizeof(uint16_t) + rdata_.size();
216 }
217 
// Default constructor; members take their default/in-class initial values.
SrvRecordRdata::SrvRecordRdata() = default;

// Stores the SRV fields as given; |target| is moved into place. No
// validation is performed.
SrvRecordRdata::SrvRecordRdata(uint16_t priority,
                               uint16_t weight,
                               uint16_t port,
                               DomainName target)
    : priority_(priority),
      weight_(weight),
      port_(port),
      target_(std::move(target)) {}

// Member-wise copy/move.
SrvRecordRdata::SrvRecordRdata(const SrvRecordRdata& other) = default;

SrvRecordRdata::SrvRecordRdata(SrvRecordRdata&& other) noexcept = default;

SrvRecordRdata& SrvRecordRdata::operator=(const SrvRecordRdata& rhs) = default;

SrvRecordRdata& SrvRecordRdata::operator=(SrvRecordRdata&& rhs) = default;
236 
operator ==(const SrvRecordRdata & rhs) const237 bool SrvRecordRdata::operator==(const SrvRecordRdata& rhs) const {
238   return priority_ == rhs.priority_ && weight_ == rhs.weight_ &&
239          port_ == rhs.port_ && target_ == rhs.target_;
240 }
241 
operator !=(const SrvRecordRdata & rhs) const242 bool SrvRecordRdata::operator!=(const SrvRecordRdata& rhs) const {
243   return !(*this == rhs);
244 }
245 
MaxWireSize() const246 size_t SrvRecordRdata::MaxWireSize() const {
247   // max_wire_size includes uint16_t record length field.
248   return sizeof(uint16_t) + sizeof(priority_) + sizeof(weight_) +
249          sizeof(port_) + target_.MaxWireSize();
250 }
251 
// Default constructor; members take their default/in-class initial values.
ARecordRdata::ARecordRdata() = default;

// |ipv4_address| must be an IPv4 address; this is enforced with OSP_CHECK
// (not just a debug check).
ARecordRdata::ARecordRdata(IPAddress ipv4_address,
                           NetworkInterfaceIndex interface_index)
    : ipv4_address_(std::move(ipv4_address)),
      interface_index_(interface_index) {
  OSP_CHECK(ipv4_address_.IsV4());
}

// Member-wise copy/move.
ARecordRdata::ARecordRdata(const ARecordRdata& other) = default;

ARecordRdata::ARecordRdata(ARecordRdata&& other) noexcept = default;

ARecordRdata& ARecordRdata::operator=(const ARecordRdata& rhs) = default;

ARecordRdata& ARecordRdata::operator=(ARecordRdata&& rhs) = default;
268 
operator ==(const ARecordRdata & rhs) const269 bool ARecordRdata::operator==(const ARecordRdata& rhs) const {
270   return ipv4_address_ == rhs.ipv4_address_ &&
271          interface_index_ == rhs.interface_index_;
272 }
273 
operator !=(const ARecordRdata & rhs) const274 bool ARecordRdata::operator!=(const ARecordRdata& rhs) const {
275   return !(*this == rhs);
276 }
277 
MaxWireSize() const278 size_t ARecordRdata::MaxWireSize() const {
279   // max_wire_size includes uint16_t record length field.
280   return sizeof(uint16_t) + IPAddress::kV4Size;
281 }
282 
// Default constructor; members take their default/in-class initial values.
AAAARecordRdata::AAAARecordRdata() = default;

// |ipv6_address| must be an IPv6 address; this is enforced with OSP_CHECK
// (not just a debug check).
AAAARecordRdata::AAAARecordRdata(IPAddress ipv6_address,
                                 NetworkInterfaceIndex interface_index)
    : ipv6_address_(std::move(ipv6_address)),
      interface_index_(interface_index) {
  OSP_CHECK(ipv6_address_.IsV6());
}

// Member-wise copy/move.
AAAARecordRdata::AAAARecordRdata(const AAAARecordRdata& other) = default;

AAAARecordRdata::AAAARecordRdata(AAAARecordRdata&& other) noexcept = default;

AAAARecordRdata& AAAARecordRdata::operator=(const AAAARecordRdata& rhs) =
    default;

AAAARecordRdata& AAAARecordRdata::operator=(AAAARecordRdata&& rhs) = default;
300 
operator ==(const AAAARecordRdata & rhs) const301 bool AAAARecordRdata::operator==(const AAAARecordRdata& rhs) const {
302   return ipv6_address_ == rhs.ipv6_address_ &&
303          interface_index_ == rhs.interface_index_;
304 }
305 
operator !=(const AAAARecordRdata & rhs) const306 bool AAAARecordRdata::operator!=(const AAAARecordRdata& rhs) const {
307   return !(*this == rhs);
308 }
309 
MaxWireSize() const310 size_t AAAARecordRdata::MaxWireSize() const {
311   // max_wire_size includes uint16_t record length field.
312   return sizeof(uint16_t) + IPAddress::kV6Size;
313 }
314 
315 PtrRecordRdata::PtrRecordRdata() = default;
316 
PtrRecordRdata(DomainName ptr_domain)317 PtrRecordRdata::PtrRecordRdata(DomainName ptr_domain)
318     : ptr_domain_(ptr_domain) {}
319 
320 PtrRecordRdata::PtrRecordRdata(const PtrRecordRdata& other) = default;
321 
322 PtrRecordRdata::PtrRecordRdata(PtrRecordRdata&& other) noexcept = default;
323 
324 PtrRecordRdata& PtrRecordRdata::operator=(const PtrRecordRdata& rhs) = default;
325 
326 PtrRecordRdata& PtrRecordRdata::operator=(PtrRecordRdata&& rhs) = default;
327 
operator ==(const PtrRecordRdata & rhs) const328 bool PtrRecordRdata::operator==(const PtrRecordRdata& rhs) const {
329   return ptr_domain_ == rhs.ptr_domain_;
330 }
331 
operator !=(const PtrRecordRdata & rhs) const332 bool PtrRecordRdata::operator!=(const PtrRecordRdata& rhs) const {
333   return !(*this == rhs);
334 }
335 
MaxWireSize() const336 size_t PtrRecordRdata::MaxWireSize() const {
337   // max_wire_size includes uint16_t record length field.
338   return sizeof(uint16_t) + ptr_domain_.MaxWireSize();
339 }
340 
341 // static
TryCreate(std::vector<Entry> texts)342 ErrorOr<TxtRecordRdata> TxtRecordRdata::TryCreate(std::vector<Entry> texts) {
343   std::vector<std::string> str_texts;
344   size_t max_wire_size = 3;
345   if (texts.size() > 0) {
346     str_texts.reserve(texts.size());
347     // max_wire_size includes uint16_t record length field.
348     max_wire_size = sizeof(uint16_t);
349     for (const auto& text : texts) {
350       if (text.empty()) {
351         return Error::Code::kParameterInvalid;
352       }
353       str_texts.push_back(
354           std::string(reinterpret_cast<const char*>(text.data()), text.size()));
355       // Include the length byte in the size calculation.
356       max_wire_size += text.size() + 1;
357     }
358   }
359   return TxtRecordRdata(std::move(str_texts), max_wire_size);
360 }
361 
// Default constructor; members take their default/in-class initial values.
TxtRecordRdata::TxtRecordRdata() = default;

// Convenience constructor: validates |texts| with TryCreate() and adopts the
// result.
// NOTE(review): value() is called without checking for an error, so invalid
// input (e.g. an empty entry) presumably aborts instead of being reported.
// Callers handling untrusted data should use TryCreate() directly.
TxtRecordRdata::TxtRecordRdata(std::vector<Entry> texts) {
  ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts));
  *this = std::move(rdata.value());
}

// Used by TryCreate(): stores |texts| and the precomputed |max_wire_size| as
// given, without validating either.
TxtRecordRdata::TxtRecordRdata(std::vector<std::string> texts,
                               size_t max_wire_size)
    : max_wire_size_(max_wire_size), texts_(std::move(texts)) {}

// Member-wise copy/move.
TxtRecordRdata::TxtRecordRdata(const TxtRecordRdata& other) = default;

TxtRecordRdata::TxtRecordRdata(TxtRecordRdata&& other) noexcept = default;

TxtRecordRdata& TxtRecordRdata::operator=(const TxtRecordRdata& rhs) = default;

TxtRecordRdata& TxtRecordRdata::operator=(TxtRecordRdata&& rhs) = default;
380 
operator ==(const TxtRecordRdata & rhs) const381 bool TxtRecordRdata::operator==(const TxtRecordRdata& rhs) const {
382   return texts_ == rhs.texts_;
383 }
384 
operator !=(const TxtRecordRdata & rhs) const385 bool TxtRecordRdata::operator!=(const TxtRecordRdata& rhs) const {
386   return !(*this == rhs);
387 }
388 
MaxWireSize() const389 size_t TxtRecordRdata::MaxWireSize() const {
390   return max_wire_size_;
391 }
392 
// Default constructor; members take their default/in-class initial values.
NsecRecordRdata::NsecRecordRdata() = default;

// Builds NSEC rdata for |next_domain_name| that asserts the existence of
// |types|. |types| is sorted (duplicates are kept) and then encoded into the
// Type Bit Maps wire representation stored in |encoded_types_|.
NsecRecordRdata::NsecRecordRdata(DomainName next_domain_name,
                                 std::vector<DnsType> types)
    : types_(std::move(types)), next_domain_name_(std::move(next_domain_name)) {
  // Sort the types_ array for easier comparison later. Sorting also
  // guarantees that window blocks are visited in ascending order below, which
  // the encoding loop relies on.
  std::sort(types_.begin(), types_.end());

  // Calculate the bitmaps as described in RFC 4034 Section 4.1.2.
  // Each window block is emitted as: block number (1 byte), bitmap length
  // (1 byte), then the bitmap bytes. A bitmap only grows as far as its
  // highest set bit, so its last byte is never zero.
  std::vector<uint8_t> block_contents;
  uint8_t current_block = 0;
  for (auto type : types_) {
    const uint16_t type_int = static_cast<uint16_t>(type);
    // The high byte of the type selects the window block. The low byte is the
    // bit position within that block's bitmap (at most 32 bytes).
    const uint8_t block = static_cast<uint8_t>(type_int >> 8);
    const uint8_t block_position = static_cast<uint8_t>(type_int & 0xFF);
    const uint8_t byte_bit_is_at = block_position >> 3;         // First 5 bits.
    const uint8_t byte_mask = 0x80 >> (block_position & 0x07);  // Last 3 bits.

    // If the block has changed, write the previous block's info and all of its
    // contents to the |encoded_types_| vector.
    if (block > current_block) {
      if (!block_contents.empty()) {
        encoded_types_.push_back(current_block);
        encoded_types_.push_back(static_cast<uint8_t>(block_contents.size()));
        encoded_types_.insert(encoded_types_.end(), block_contents.begin(),
                              block_contents.end());
      }
      block_contents = std::vector<uint8_t>();
      current_block = block;
    }

    // Make sure |block_contents| is large enough to hold the bit representing
    // the new type, then set it.
    if (block_contents.size() <= byte_bit_is_at) {
      block_contents.insert(block_contents.end(),
                            byte_bit_is_at - block_contents.size() + 1, 0x00);
    }

    block_contents[byte_bit_is_at] |= byte_mask;
  }

  // Flush the final (highest-numbered) block, if any types were given.
  if (!block_contents.empty()) {
    encoded_types_.push_back(current_block);
    encoded_types_.push_back(static_cast<uint8_t>(block_contents.size()));
    encoded_types_.insert(encoded_types_.end(), block_contents.begin(),
                          block_contents.end());
  }
}

// Member-wise copy/move.
NsecRecordRdata::NsecRecordRdata(const NsecRecordRdata& other) = default;

NsecRecordRdata::NsecRecordRdata(NsecRecordRdata&& other) noexcept = default;

NsecRecordRdata& NsecRecordRdata::operator=(const NsecRecordRdata& rhs) =
    default;

NsecRecordRdata& NsecRecordRdata::operator=(NsecRecordRdata&& rhs) = default;
450 
operator ==(const NsecRecordRdata & rhs) const451 bool NsecRecordRdata::operator==(const NsecRecordRdata& rhs) const {
452   return types_ == rhs.types_ && next_domain_name_ == rhs.next_domain_name_;
453 }
454 
operator !=(const NsecRecordRdata & rhs) const455 bool NsecRecordRdata::operator!=(const NsecRecordRdata& rhs) const {
456   return !(*this == rhs);
457 }
458 
MaxWireSize() const459 size_t NsecRecordRdata::MaxWireSize() const {
460   return next_domain_name_.MaxWireSize() + encoded_types_.size();
461 }
462 
MaxWireSize() const463 size_t OptRecordRdata::Option::MaxWireSize() const {
464   // One uint16_t for each of OPTION-LENGTH and OPTION-CODE as defined in RFC
465   // 6891 section 6.1.2.
466   constexpr size_t kOptionLengthAndCodeSize = 2 * sizeof(uint16_t);
467   return data.size() + kOptionLengthAndCodeSize;
468 }
469 
operator >(const OptRecordRdata::Option & rhs) const470 bool OptRecordRdata::Option::operator>(
471     const OptRecordRdata::Option& rhs) const {
472   if (code != rhs.code) {
473     return code > rhs.code;
474   } else if (length != rhs.length) {
475     return length > rhs.length;
476   } else if (data.size() != rhs.data.size()) {
477     return data.size() > rhs.data.size();
478   }
479 
480   for (int i = 0; i < static_cast<int>(data.size()); i++) {
481     if (data[i] != rhs.data[i]) {
482       return data[i] > rhs.data[i];
483     }
484   }
485 
486   return false;
487 }
488 
operator <(const OptRecordRdata::Option & rhs) const489 bool OptRecordRdata::Option::operator<(
490     const OptRecordRdata::Option& rhs) const {
491   return rhs > *this;
492 }
493 
operator >=(const OptRecordRdata::Option & rhs) const494 bool OptRecordRdata::Option::operator>=(
495     const OptRecordRdata::Option& rhs) const {
496   return !(*this < rhs);
497 }
498 
operator <=(const OptRecordRdata::Option & rhs) const499 bool OptRecordRdata::Option::operator<=(
500     const OptRecordRdata::Option& rhs) const {
501   return !(*this > rhs);
502 }
503 
operator ==(const OptRecordRdata::Option & rhs) const504 bool OptRecordRdata::Option::operator==(
505     const OptRecordRdata::Option& rhs) const {
506   return *this >= rhs && *this <= rhs;
507 }
508 
operator !=(const OptRecordRdata::Option & rhs) const509 bool OptRecordRdata::Option::operator!=(
510     const OptRecordRdata::Option& rhs) const {
511   return !(*this == rhs);
512 }
513 
// Default constructor; members take their default/in-class initial values.
OptRecordRdata::OptRecordRdata() = default;

// Adds each option's MaxWireSize() to |max_wire_size_|, starting from the
// member's initial value (set outside this file). The options are then sorted,
// presumably so that equality comparison does not depend on input order.
OptRecordRdata::OptRecordRdata(std::vector<Option> options)
    : options_(std::move(options)) {
  for (const auto& option : options_) {
    max_wire_size_ += option.MaxWireSize();
  }
  std::sort(options_.begin(), options_.end());
}

// Member-wise copy/move.
OptRecordRdata::OptRecordRdata(const OptRecordRdata& other) = default;

OptRecordRdata::OptRecordRdata(OptRecordRdata&& other) noexcept = default;

OptRecordRdata& OptRecordRdata::operator=(const OptRecordRdata& rhs) = default;

OptRecordRdata& OptRecordRdata::operator=(OptRecordRdata&& rhs) = default;
531 
operator ==(const OptRecordRdata & rhs) const532 bool OptRecordRdata::operator==(const OptRecordRdata& rhs) const {
533   return options_ == rhs.options_;
534 }
535 
operator !=(const OptRecordRdata & rhs) const536 bool OptRecordRdata::operator!=(const OptRecordRdata& rhs) const {
537   return !(*this == rhs);
538 }
539 
540 // static
TryCreate(DomainName name,DnsType dns_type,DnsClass dns_class,RecordType record_type,std::chrono::seconds ttl,Rdata rdata)541 ErrorOr<MdnsRecord> MdnsRecord::TryCreate(DomainName name,
542                                           DnsType dns_type,
543                                           DnsClass dns_class,
544                                           RecordType record_type,
545                                           std::chrono::seconds ttl,
546                                           Rdata rdata) {
547   if (!IsValidConfig(name, dns_type, ttl, rdata)) {
548     return Error::Code::kParameterInvalid;
549   } else {
550     return MdnsRecord(std::move(name), dns_type, dns_class, record_type, ttl,
551                       std::move(rdata));
552   }
553 }
554 
// Default constructor; members take their default/in-class initial values.
MdnsRecord::MdnsRecord() = default;

// Constructs a record from its parts. The configuration is only verified with
// OSP_DCHECK (presumably debug-only); use TryCreate() for input that may be
// invalid.
MdnsRecord::MdnsRecord(DomainName name,
                       DnsType dns_type,
                       DnsClass dns_class,
                       RecordType record_type,
                       std::chrono::seconds ttl,
                       Rdata rdata)
    : name_(std::move(name)),
      dns_type_(dns_type),
      dns_class_(dns_class),
      record_type_(record_type),
      ttl_(ttl),
      rdata_(std::move(rdata)) {
  // Checks the members rather than the moved-from |name| / |rdata| arguments.
  OSP_DCHECK(IsValidConfig(name_, dns_type, ttl_, rdata_));
}

// Member-wise copy/move.
MdnsRecord::MdnsRecord(const MdnsRecord& other) = default;

MdnsRecord::MdnsRecord(MdnsRecord&& other) noexcept = default;

MdnsRecord& MdnsRecord::operator=(const MdnsRecord& rhs) = default;

MdnsRecord& MdnsRecord::operator=(MdnsRecord&& rhs) = default;
579 
580 // static
IsValidConfig(const DomainName & name,DnsType dns_type,std::chrono::seconds ttl,const Rdata & rdata)581 bool MdnsRecord::IsValidConfig(const DomainName& name,
582                                DnsType dns_type,
583                                std::chrono::seconds ttl,
584                                const Rdata& rdata) {
585   // NOTE: Although the name_ field was initially expected to be non-empty, this
586   // validation is no longer accurate for some record types (such as OPT
587   // records). To ensure that future record types correctly parse into
588   // RawRecordData types and do not invalidate the received message, this check
589   // has been removed.
590   return ttl.count() <= std::numeric_limits<uint32_t>::max() &&
591          ((dns_type == DnsType::kSRV &&
592            absl::holds_alternative<SrvRecordRdata>(rdata)) ||
593           (dns_type == DnsType::kA &&
594            absl::holds_alternative<ARecordRdata>(rdata)) ||
595           (dns_type == DnsType::kAAAA &&
596            absl::holds_alternative<AAAARecordRdata>(rdata)) ||
597           (dns_type == DnsType::kPTR &&
598            absl::holds_alternative<PtrRecordRdata>(rdata)) ||
599           (dns_type == DnsType::kTXT &&
600            absl::holds_alternative<TxtRecordRdata>(rdata)) ||
601           (dns_type == DnsType::kNSEC &&
602            absl::holds_alternative<NsecRecordRdata>(rdata)) ||
603           (dns_type == DnsType::kOPT &&
604            absl::holds_alternative<OptRecordRdata>(rdata)) ||
605           absl::holds_alternative<RawRecordRdata>(rdata));
606 }
607 
operator ==(const MdnsRecord & rhs) const608 bool MdnsRecord::operator==(const MdnsRecord& rhs) const {
609   return IsReannouncementOf(rhs) && ttl_ == rhs.ttl_;
610 }
611 
operator !=(const MdnsRecord & rhs) const612 bool MdnsRecord::operator!=(const MdnsRecord& rhs) const {
613   return !(*this == rhs);
614 }
615 
operator >(const MdnsRecord & rhs) const616 bool MdnsRecord::operator>(const MdnsRecord& rhs) const {
617   // Returns the record which is lexicographically later. The determination of
618   // "lexicographically later" is performed by first comparing the record class,
619   // then the record type, then raw comparison of the binary content of the
620   // rdata without regard for meaning or structure.
621   // NOTE: Per RFC, the TTL is not included in this comparison.
622   if (name() != rhs.name()) {
623     return name() > rhs.name();
624   }
625 
626   if (record_type() != rhs.record_type()) {
627     return record_type() == RecordType::kUnique;
628   }
629 
630   if (dns_class() != rhs.dns_class()) {
631     return dns_class() > rhs.dns_class();
632   }
633 
634   uint16_t this_type = static_cast<uint16_t>(dns_type()) & kClassMask;
635   uint16_t other_type = static_cast<uint16_t>(rhs.dns_type()) & kClassMask;
636   if (this_type != other_type) {
637     return this_type > other_type;
638   }
639 
640   return IsGreaterThan(dns_type(), rdata(), rhs.rdata());
641 }
642 
operator <(const MdnsRecord & rhs) const643 bool MdnsRecord::operator<(const MdnsRecord& rhs) const {
644   return rhs > *this;
645 }
646 
operator <=(const MdnsRecord & rhs) const647 bool MdnsRecord::operator<=(const MdnsRecord& rhs) const {
648   return !(*this > rhs);
649 }
650 
operator >=(const MdnsRecord & rhs) const651 bool MdnsRecord::operator>=(const MdnsRecord& rhs) const {
652   return !(*this < rhs);
653 }
654 
IsReannouncementOf(const MdnsRecord & rhs) const655 bool MdnsRecord::IsReannouncementOf(const MdnsRecord& rhs) const {
656   return dns_type_ == rhs.dns_type_ && dns_class_ == rhs.dns_class_ &&
657          record_type_ == rhs.record_type_ && name_ == rhs.name_ &&
658          rdata_ == rhs.rdata_;
659 }
660 
MaxWireSize() const661 size_t MdnsRecord::MaxWireSize() const {
662   auto wire_size_visitor = [](auto&& arg) { return arg.MaxWireSize(); };
663   // NAME size, 2-byte TYPE, 2-byte CLASS, 4-byte TTL, RDATA size
664   return name_.MaxWireSize() + absl::visit(wire_size_visitor, rdata_) + 8;
665 }
666 
ToString() const667 std::string MdnsRecord::ToString() const {
668   std::stringstream ss;
669   ss << "name: '" << name_.ToString() << "'";
670   ss << ", type: " << dns_type_;
671 
672   if (dns_type_ == DnsType::kPTR) {
673     const DomainName& target = absl::get<PtrRecordRdata>(rdata_).ptr_domain();
674     ss << ", target: '" << target.ToString() << "'";
675   } else if (dns_type_ == DnsType::kSRV) {
676     const DomainName& target = absl::get<SrvRecordRdata>(rdata_).target();
677     ss << ", target: '" << target.ToString() << "'";
678   } else if (dns_type_ == DnsType::kNSEC) {
679     const auto& nsec_rdata = absl::get<NsecRecordRdata>(rdata_);
680     std::vector<DnsType> types = nsec_rdata.types();
681     ss << ", representing [";
682     if (!types.empty()) {
683       auto it = types.begin();
684       ss << *it++;
685       while (it != types.end()) {
686         ss << ", " << *it++;
687       }
688       ss << "]";
689     }
690   }
691 
692   return ss.str();
693 }
694 
CreateAddressRecord(DomainName name,const IPAddress & address)695 MdnsRecord CreateAddressRecord(DomainName name, const IPAddress& address) {
696   Rdata rdata;
697   DnsType type;
698   std::chrono::seconds ttl;
699   if (address.IsV4()) {
700     type = DnsType::kA;
701     rdata = ARecordRdata(address);
702     ttl = kARecordTtl;
703   } else {
704     type = DnsType::kAAAA;
705     rdata = AAAARecordRdata(address);
706     ttl = kAAAARecordTtl;
707   }
708 
709   return MdnsRecord(std::move(name), type, DnsClass::kIN, RecordType::kUnique,
710                     ttl, std::move(rdata));
711 }
712 
713 // static
TryCreate(DomainName name,DnsType dns_type,DnsClass dns_class,ResponseType response_type)714 ErrorOr<MdnsQuestion> MdnsQuestion::TryCreate(DomainName name,
715                                               DnsType dns_type,
716                                               DnsClass dns_class,
717                                               ResponseType response_type) {
718   if (name.empty()) {
719     return Error::Code::kParameterInvalid;
720   }
721 
722   return MdnsQuestion(std::move(name), dns_type, dns_class, response_type);
723 }
724 
MdnsQuestion(DomainName name,DnsType dns_type,DnsClass dns_class,ResponseType response_type)725 MdnsQuestion::MdnsQuestion(DomainName name,
726                            DnsType dns_type,
727                            DnsClass dns_class,
728                            ResponseType response_type)
729     : name_(std::move(name)),
730       dns_type_(dns_type),
731       dns_class_(dns_class),
732       response_type_(response_type) {
733   OSP_CHECK(!name_.empty());
734 }
735 
operator ==(const MdnsQuestion & rhs) const736 bool MdnsQuestion::operator==(const MdnsQuestion& rhs) const {
737   return dns_type_ == rhs.dns_type_ && dns_class_ == rhs.dns_class_ &&
738          response_type_ == rhs.response_type_ && name_ == rhs.name_;
739 }
740 
operator !=(const MdnsQuestion & rhs) const741 bool MdnsQuestion::operator!=(const MdnsQuestion& rhs) const {
742   return !(*this == rhs);
743 }
744 
MaxWireSize() const745 size_t MdnsQuestion::MaxWireSize() const {
746   // NAME size, 2-byte TYPE, 2-byte CLASS
747   return name_.MaxWireSize() + 4;
748 }
749 
750 // static
TryCreate(uint16_t id,MessageType type,std::vector<MdnsQuestion> questions,std::vector<MdnsRecord> answers,std::vector<MdnsRecord> authority_records,std::vector<MdnsRecord> additional_records)751 ErrorOr<MdnsMessage> MdnsMessage::TryCreate(
752     uint16_t id,
753     MessageType type,
754     std::vector<MdnsQuestion> questions,
755     std::vector<MdnsRecord> answers,
756     std::vector<MdnsRecord> authority_records,
757     std::vector<MdnsRecord> additional_records) {
758   if (questions.size() >= kMaxMessageFieldEntryCount ||
759       answers.size() >= kMaxMessageFieldEntryCount ||
760       authority_records.size() >= kMaxMessageFieldEntryCount ||
761       additional_records.size() >= kMaxMessageFieldEntryCount) {
762     return Error::Code::kParameterInvalid;
763   }
764 
765   return MdnsMessage(id, type, std::move(questions), std::move(answers),
766                      std::move(authority_records),
767                      std::move(additional_records));
768 }
769 
MdnsMessage(uint16_t id,MessageType type)770 MdnsMessage::MdnsMessage(uint16_t id, MessageType type)
771     : id_(id), type_(type) {}
772 
MdnsMessage(uint16_t id,MessageType type,std::vector<MdnsQuestion> questions,std::vector<MdnsRecord> answers,std::vector<MdnsRecord> authority_records,std::vector<MdnsRecord> additional_records)773 MdnsMessage::MdnsMessage(uint16_t id,
774                          MessageType type,
775                          std::vector<MdnsQuestion> questions,
776                          std::vector<MdnsRecord> answers,
777                          std::vector<MdnsRecord> authority_records,
778                          std::vector<MdnsRecord> additional_records)
779     : id_(id),
780       type_(type),
781       questions_(std::move(questions)),
782       answers_(std::move(answers)),
783       authority_records_(std::move(authority_records)),
784       additional_records_(std::move(additional_records)) {
785   OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount);
786   OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount);
787   OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount);
788   OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount);
789 
790   for (const MdnsQuestion& question : questions_) {
791     max_wire_size_ += question.MaxWireSize();
792   }
793   for (const MdnsRecord& record : answers_) {
794     max_wire_size_ += record.MaxWireSize();
795   }
796   for (const MdnsRecord& record : authority_records_) {
797     max_wire_size_ += record.MaxWireSize();
798   }
799   for (const MdnsRecord& record : additional_records_) {
800     max_wire_size_ += record.MaxWireSize();
801   }
802 }
803 
operator ==(const MdnsMessage & rhs) const804 bool MdnsMessage::operator==(const MdnsMessage& rhs) const {
805   return id_ == rhs.id_ && type_ == rhs.type_ && questions_ == rhs.questions_ &&
806          answers_ == rhs.answers_ &&
807          authority_records_ == rhs.authority_records_ &&
808          additional_records_ == rhs.additional_records_;
809 }
810 
operator !=(const MdnsMessage & rhs) const811 bool MdnsMessage::operator!=(const MdnsMessage& rhs) const {
812   return !(*this == rhs);
813 }
814 
IsProbeQuery() const815 bool MdnsMessage::IsProbeQuery() const {
816   // A message is a probe query if it contains records in the authority section
817   // which answer the question being asked.
818   if (questions().empty() || authority_records().empty()) {
819     return false;
820   }
821 
822   for (const MdnsQuestion& question : questions_) {
823     for (const MdnsRecord& record : authority_records_) {
824       if (question.name() == record.name() &&
825           ((question.dns_type() == record.dns_type()) ||
826            (question.dns_type() == DnsType::kANY)) &&
827           ((question.dns_class() == record.dns_class()) ||
828            (question.dns_class() == DnsClass::kANY))) {
829         return true;
830       }
831     }
832   }
833 
834   return false;
835 }
836 
MaxWireSize() const837 size_t MdnsMessage::MaxWireSize() const {
838   return max_wire_size_;
839 }
840 
AddQuestion(MdnsQuestion question)841 void MdnsMessage::AddQuestion(MdnsQuestion question) {
842   OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount);
843   max_wire_size_ += question.MaxWireSize();
844   questions_.emplace_back(std::move(question));
845 }
846 
AddAnswer(MdnsRecord record)847 void MdnsMessage::AddAnswer(MdnsRecord record) {
848   OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount);
849   max_wire_size_ += record.MaxWireSize();
850   answers_.emplace_back(std::move(record));
851 }
852 
AddAuthorityRecord(MdnsRecord record)853 void MdnsMessage::AddAuthorityRecord(MdnsRecord record) {
854   OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount);
855   max_wire_size_ += record.MaxWireSize();
856   authority_records_.emplace_back(std::move(record));
857 }
858 
AddAdditionalRecord(MdnsRecord record)859 void MdnsMessage::AddAdditionalRecord(MdnsRecord record) {
860   OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount);
861   max_wire_size_ += record.MaxWireSize();
862   additional_records_.emplace_back(std::move(record));
863 }
864 
CanAddRecord(const MdnsRecord & record)865 bool MdnsMessage::CanAddRecord(const MdnsRecord& record) {
866   return (max_wire_size_ + record.MaxWireSize()) < kMaxMulticastMessageSize;
867 }
868 
// Returns message ids 0, 1, 2, ... on successive calls, wrapping from 65535
// back to 0. The counter is a plain function-local static with no
// synchronization.
uint16_t CreateMessageId() {
  static uint16_t next_id = 0;
  const uint16_t issued = next_id;
  next_id = static_cast<uint16_t>(next_id + 1);
  return issued;
}
873 
CanBePublished(DnsType type)874 bool CanBePublished(DnsType type) {
875   // NOTE: A 'default' switch statement has intentionally been avoided below to
876   // enforce that new DnsTypes added must be added below through a compile-time
877   // check.
878   switch (type) {
879     case DnsType::kA:
880     case DnsType::kAAAA:
881     case DnsType::kPTR:
882     case DnsType::kTXT:
883     case DnsType::kSRV:
884       return true;
885     case DnsType::kOPT:
886     case DnsType::kNSEC:
887     case DnsType::kANY:
888       break;
889   }
890 
891   return false;
892 }
893 
CanBePublished(const MdnsRecord & record)894 bool CanBePublished(const MdnsRecord& record) {
895   return CanBePublished(record.dns_type());
896 }
897 
CanBeQueried(DnsType type)898 bool CanBeQueried(DnsType type) {
899   // NOTE: A 'default' switch statement has intentionally been avoided below to
900   // enforce that new DnsTypes added must be added below through a compile-time
901   // check.
902   switch (type) {
903     case DnsType::kA:
904     case DnsType::kAAAA:
905     case DnsType::kPTR:
906     case DnsType::kTXT:
907     case DnsType::kSRV:
908     case DnsType::kANY:
909       return true;
910     case DnsType::kOPT:
911     case DnsType::kNSEC:
912       break;
913   }
914 
915   return false;
916 }
917 
CanBeProcessed(DnsType type)918 bool CanBeProcessed(DnsType type) {
919   // NOTE: A 'default' switch statement has intentionally been avoided below to
920   // enforce that new DnsTypes added must be added below through a compile-time
921   // check.
922   switch (type) {
923     case DnsType::kA:
924     case DnsType::kAAAA:
925     case DnsType::kPTR:
926     case DnsType::kTXT:
927     case DnsType::kSRV:
928     case DnsType::kNSEC:
929       return true;
930     case DnsType::kOPT:
931     case DnsType::kANY:
932       break;
933   }
934 
935   return false;
936 }
937 
938 }  // namespace discovery
939 }  // namespace openscreen
940