1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/dnssd/impl/conversion_layer.h"
6 
#include <iterator>
#include <string>
#include <utility>
#include <vector>
8 
9 #include "absl/strings/str_join.h"
10 #include "absl/strings/str_split.h"
11 #include "absl/types/optional.h"
12 #include "absl/types/span.h"
13 #include "discovery/dnssd/impl/constants.h"
14 #include "discovery/dnssd/impl/instance_key.h"
15 #include "discovery/dnssd/impl/service_key.h"
16 #include "discovery/dnssd/public/dns_sd_instance.h"
17 #include "discovery/mdns/mdns_records.h"
18 #include "discovery/mdns/public/mdns_constants.h"
19 
20 namespace openscreen {
21 namespace discovery {
22 namespace {
23 
AddServiceInfoToLabels(const std::string & service,const std::string & domain,std::vector<std::string> * labels)24 void AddServiceInfoToLabels(const std::string& service,
25                             const std::string& domain,
26                             std::vector<std::string>* labels) {
27   std::vector<std::string> service_labels = absl::StrSplit(service, '.');
28   labels->insert(labels->end(), service_labels.begin(), service_labels.end());
29 
30   std::vector<std::string> domain_labels = absl::StrSplit(domain, '.');
31   labels->insert(labels->end(), domain_labels.begin(), domain_labels.end());
32 }
33 
GetPtrDomainName(const std::string & service,const std::string & domain)34 DomainName GetPtrDomainName(const std::string& service,
35                             const std::string& domain) {
36   std::vector<std::string> labels;
37   AddServiceInfoToLabels(service, domain, &labels);
38   return DomainName{std::move(labels)};
39 }
40 
GetInstanceDomainName(const std::string & instance,const std::string & service,const std::string & domain)41 DomainName GetInstanceDomainName(const std::string& instance,
42                                  const std::string& service,
43                                  const std::string& domain) {
44   std::vector<std::string> labels;
45   labels.emplace_back(instance);
46   AddServiceInfoToLabels(service, domain, &labels);
47   return DomainName{std::move(labels)};
48 }
49 
GetInstanceDomainName(const InstanceKey & key)50 inline DomainName GetInstanceDomainName(const InstanceKey& key) {
51   return GetInstanceDomainName(key.instance_id(), key.service_id(),
52                                key.domain_id());
53 }
54 
CreatePtrRecord(const DnsSdInstance & instance,const DomainName & domain)55 MdnsRecord CreatePtrRecord(const DnsSdInstance& instance,
56                            const DomainName& domain) {
57   PtrRecordRdata data(domain);
58   auto outer_domain =
59       GetPtrDomainName(instance.service_id(), instance.domain_id());
60   return MdnsRecord(std::move(outer_domain), DnsType::kPTR, DnsClass::kIN,
61                     RecordType::kShared, kPtrRecordTtl, std::move(data));
62 }
63 
CreateSrvRecord(const DnsSdInstance & instance,const DomainName & domain)64 MdnsRecord CreateSrvRecord(const DnsSdInstance& instance,
65                            const DomainName& domain) {
66   uint16_t port = instance.port();
67   SrvRecordRdata data(0, 0, port, domain);
68   return MdnsRecord(domain, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique,
69                     kSrvRecordTtl, std::move(data));
70 }
71 
CreateARecords(const DnsSdInstanceEndpoint & endpoint,const DomainName & domain)72 std::vector<MdnsRecord> CreateARecords(const DnsSdInstanceEndpoint& endpoint,
73                                        const DomainName& domain) {
74   std::vector<MdnsRecord> records;
75   for (const IPAddress& address : endpoint.addresses()) {
76     if (address.IsV4()) {
77       ARecordRdata data(address);
78       records.emplace_back(domain, DnsType::kA, DnsClass::kIN,
79                            RecordType::kUnique, kARecordTtl, std::move(data));
80     }
81   }
82 
83   return records;
84 }
85 
CreateAAAARecords(const DnsSdInstanceEndpoint & endpoint,const DomainName & domain)86 std::vector<MdnsRecord> CreateAAAARecords(const DnsSdInstanceEndpoint& endpoint,
87                                           const DomainName& domain) {
88   std::vector<MdnsRecord> records;
89   for (const IPAddress& address : endpoint.addresses()) {
90     if (address.IsV6()) {
91       AAAARecordRdata data(address);
92       records.emplace_back(domain, DnsType::kAAAA, DnsClass::kIN,
93                            RecordType::kUnique, kAAAARecordTtl,
94                            std::move(data));
95     }
96   }
97 
98   return records;
99 }
100 
CreateTxtRecord(const DnsSdInstance & endpoint,const DomainName & domain)101 MdnsRecord CreateTxtRecord(const DnsSdInstance& endpoint,
102                            const DomainName& domain) {
103   TxtRecordRdata data(endpoint.txt().GetData());
104   return MdnsRecord(domain, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique,
105                     kTXTRecordTtl, std::move(data));
106 }
107 
108 }  // namespace
109 
CreateFromDnsTxt(const TxtRecordRdata & txt_data)110 ErrorOr<DnsSdTxtRecord> CreateFromDnsTxt(const TxtRecordRdata& txt_data) {
111   DnsSdTxtRecord txt;
112   if (txt_data.texts().size() == 1 && txt_data.texts()[0] == "") {
113     return txt;
114   }
115 
116   // Iterate backwards so that the first key of each type is the one that is
117   // present at the end, as pet spec.
118   for (auto it = txt_data.texts().rbegin(); it != txt_data.texts().rend();
119        it++) {
120     const std::string& text = *it;
121     size_t index_of_eq = text.find_first_of('=');
122     if (index_of_eq != std::string::npos) {
123       if (index_of_eq == 0) {
124         return Error::Code::kParameterInvalid;
125       }
126       std::string key = text.substr(0, index_of_eq);
127       std::string value = text.substr(index_of_eq + 1);
128       absl::Span<const uint8_t> data(
129           reinterpret_cast<const uint8_t*>(value.data()), value.size());
130       const auto set_result =
131           txt.SetValue(key, std::vector<uint8_t>(data.begin(), data.end()));
132       if (!set_result.ok()) {
133         return set_result;
134       }
135     } else {
136       const auto set_result = txt.SetFlag(text, true);
137       if (!set_result.ok()) {
138         return set_result;
139       }
140     }
141   }
142 
143   return txt;
144 }
145 
GetDomainName(const InstanceKey & key)146 DomainName GetDomainName(const InstanceKey& key) {
147   return GetInstanceDomainName(key.instance_id(), key.service_id(),
148                                key.domain_id());
149 }
150 
GetDomainName(const MdnsRecord & record)151 DomainName GetDomainName(const MdnsRecord& record) {
152   return IsPtrRecord(record)
153              ? absl::get<PtrRecordRdata>(record.rdata()).ptr_domain()
154              : record.name();
155 }
156 
GetInstanceQueryInfo(const InstanceKey & key)157 DnsQueryInfo GetInstanceQueryInfo(const InstanceKey& key) {
158   return {GetDomainName(key), DnsType::kANY, DnsClass::kANY};
159 }
160 
GetPtrQueryInfo(const ServiceKey & key)161 DnsQueryInfo GetPtrQueryInfo(const ServiceKey& key) {
162   auto domain = GetPtrDomainName(key.service_id(), key.domain_id());
163   return {std::move(domain), DnsType::kPTR, DnsClass::kANY};
164 }
165 
HasValidDnsRecordAddress(const MdnsRecord & record)166 bool HasValidDnsRecordAddress(const MdnsRecord& record) {
167   return HasValidDnsRecordAddress(GetDomainName(record));
168 }
169 
HasValidDnsRecordAddress(const DomainName & domain)170 bool HasValidDnsRecordAddress(const DomainName& domain) {
171   return InstanceKey::TryCreate(domain).is_value() &&
172          IsInstanceValid(domain.labels()[0]);
173 }
174 
IsPtrRecord(const MdnsRecord & record)175 bool IsPtrRecord(const MdnsRecord& record) {
176   return record.dns_type() == DnsType::kPTR;
177 }
178 
GetDnsRecords(const DnsSdInstance & instance)179 std::vector<MdnsRecord> GetDnsRecords(const DnsSdInstance& instance) {
180   auto domain = GetInstanceDomainName(InstanceKey(instance));
181 
182   return {CreatePtrRecord(instance, domain), CreateSrvRecord(instance, domain),
183           CreateTxtRecord(instance, domain)};
184 }
185 
GetDnsRecords(const DnsSdInstanceEndpoint & endpoint)186 std::vector<MdnsRecord> GetDnsRecords(const DnsSdInstanceEndpoint& endpoint) {
187   auto domain = GetInstanceDomainName(InstanceKey(endpoint));
188 
189   std::vector<MdnsRecord> records =
190       GetDnsRecords(static_cast<DnsSdInstance>(endpoint));
191 
192   std::vector<MdnsRecord> v4 = CreateARecords(endpoint, domain);
193   std::vector<MdnsRecord> v6 = CreateAAAARecords(endpoint, domain);
194 
195   records.insert(records.end(), v4.begin(), v4.end());
196   records.insert(records.end(), v6.begin(), v6.end());
197 
198   return records;
199 }
200 
201 }  // namespace discovery
202 }  // namespace openscreen
203