1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "discovery/dnssd/impl/publisher_impl.h"

#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "discovery/common/reporting_client.h"
#include "discovery/dnssd/impl/conversion_layer.h"
#include "discovery/dnssd/impl/instance_key.h"
#include "discovery/dnssd/impl/network_interface_config.h"
#include "discovery/mdns/public/mdns_constants.h"
#include "platform/api/task_runner.h"
#include "platform/base/error.h"
#include "util/trace_logging.h"
21 
22 namespace openscreen {
23 namespace discovery {
24 namespace {
25 
CreateEndpoint(DnsSdInstance instance,InstanceKey key,const NetworkInterfaceConfig & network_config)26 DnsSdInstanceEndpoint CreateEndpoint(
27     DnsSdInstance instance,
28     InstanceKey key,
29     const NetworkInterfaceConfig& network_config) {
30   std::vector<IPEndpoint> endpoints;
31   if (network_config.HasAddressV4()) {
32     endpoints.push_back({network_config.address_v4(), instance.port()});
33   }
34   if (network_config.HasAddressV6()) {
35     endpoints.push_back({network_config.address_v6(), instance.port()});
36   }
37   return DnsSdInstanceEndpoint(
38       key.instance_id(), key.service_id(), key.domain_id(), instance.txt(),
39       network_config.network_interface(), std::move(endpoints));
40 }
41 
UpdateDomain(const DomainName & name,DnsSdInstance instance,const NetworkInterfaceConfig & network_config)42 DnsSdInstanceEndpoint UpdateDomain(
43     const DomainName& name,
44     DnsSdInstance instance,
45     const NetworkInterfaceConfig& network_config) {
46   return CreateEndpoint(std::move(instance), InstanceKey(name), network_config);
47 }
48 
CreateEndpoint(DnsSdInstance instance,const NetworkInterfaceConfig & network_config)49 DnsSdInstanceEndpoint CreateEndpoint(
50     DnsSdInstance instance,
51     const NetworkInterfaceConfig& network_config) {
52   InstanceKey key(instance);
53   return CreateEndpoint(std::move(instance), std::move(key), network_config);
54 }
55 
56 template <typename T>
FindKey(std::map<DnsSdInstance,T> * instances,const InstanceKey & key)57 inline typename std::map<DnsSdInstance, T>::iterator FindKey(
58     std::map<DnsSdInstance, T>* instances,
59     const InstanceKey& key) {
60   return std::find_if(instances->begin(), instances->end(),
61                       [&key](const std::pair<DnsSdInstance, T>& pair) {
62                         return key == InstanceKey(pair.first);
63                       });
64 }
65 
66 template <typename T>
EraseInstancesWithServiceId(std::map<DnsSdInstance,T> * instances,const std::string & service_id)67 int EraseInstancesWithServiceId(std::map<DnsSdInstance, T>* instances,
68                                 const std::string& service_id) {
69   int removed_count = 0;
70   for (auto it = instances->begin(); it != instances->end();) {
71     if (it->first.service_id() == service_id) {
72       removed_count++;
73       it = instances->erase(it);
74     } else {
75       it++;
76     }
77   }
78 
79   return removed_count;
80 }
81 
82 }  // namespace
83 
PublisherImpl(MdnsService * publisher,ReportingClient * reporting_client,TaskRunner * task_runner,const NetworkInterfaceConfig * network_config)84 PublisherImpl::PublisherImpl(MdnsService* publisher,
85                              ReportingClient* reporting_client,
86                              TaskRunner* task_runner,
87                              const NetworkInterfaceConfig* network_config)
88     : mdns_publisher_(publisher),
89       reporting_client_(reporting_client),
90       task_runner_(task_runner),
91       network_config_(network_config) {
92   OSP_DCHECK(mdns_publisher_);
93   OSP_DCHECK(reporting_client_);
94   OSP_DCHECK(task_runner_);
95 }
96 
97 PublisherImpl::~PublisherImpl() = default;
98 
Register(const DnsSdInstance & instance,Client * client)99 Error PublisherImpl::Register(const DnsSdInstance& instance, Client* client) {
100   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
101   OSP_DCHECK(client != nullptr);
102 
103   if (published_instances_.find(instance) != published_instances_.end()) {
104     UpdateRegistration(instance);
105   } else if (pending_instances_.find(instance) != pending_instances_.end()) {
106     return Error::Code::kOperationInProgress;
107   }
108 
109   InstanceKey key(instance);
110   const IPAddress& address = network_config_->GetAddress();
111   OSP_DCHECK(address);
112   pending_instances_.emplace(CreateEndpoint(instance, *network_config_),
113                              client);
114 
115   OSP_DVLOG << "Registering instance '" << instance.instance_id() << "'";
116 
117   return mdns_publisher_->StartProbe(this, GetDomainName(key), address);
118 }
119 
UpdateRegistration(const DnsSdInstance & instance)120 Error PublisherImpl::UpdateRegistration(const DnsSdInstance& instance) {
121   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
122 
123   // Check if the instance is still pending publication.
124   auto it = FindKey(&pending_instances_, InstanceKey(instance));
125 
126   OSP_DVLOG << "Updating instance '" << instance.instance_id() << "'";
127 
128   // If it is a pending instance, update it. Else, try to update a published
129   // instance.
130   if (it != pending_instances_.end()) {
131     // The instance, service, and domain ids have not changed, so only the
132     // remaining data needs to change. The ongoing probe does not need to be
133     // modified.
134     Client* const client = it->second;
135     pending_instances_.erase(it);
136     pending_instances_.emplace(CreateEndpoint(instance, *network_config_),
137                                client);
138     return Error::None();
139   } else {
140     return UpdatePublishedRegistration(instance);
141   }
142 }
143 
UpdatePublishedRegistration(const DnsSdInstance & instance)144 Error PublisherImpl::UpdatePublishedRegistration(
145     const DnsSdInstance& instance) {
146   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
147 
148   auto published_instance_it =
149       FindKey(&published_instances_, InstanceKey(instance));
150 
151   // Check preconditions called out in header. Specifically, the updated
152   // instance must be making changes to an already published instance.
153   if (published_instance_it == published_instances_.end()) {
154     return Error::Code::kParameterInvalid;
155   }
156 
157   const DnsSdInstanceEndpoint updated_endpoint =
158       UpdateDomain(GetDomainName(InstanceKey(published_instance_it->second)),
159                    instance, *network_config_);
160   if (published_instance_it->second == updated_endpoint) {
161     return Error::Code::kParameterInvalid;
162   }
163 
164   // Get all instances which have changed. By design, there an only be one
165   // instance of each DnsType, so use that here to simplify this step. First in
166   // each pair is the old instances, second is the new instance.
167   std::map<DnsType,
168            std::pair<absl::optional<MdnsRecord>, absl::optional<MdnsRecord>>>
169       changed_records;
170   const std::vector<MdnsRecord> old_records =
171       GetDnsRecords(published_instance_it->second);
172   const std::vector<MdnsRecord> new_records = GetDnsRecords(updated_endpoint);
173 
174   // Populate the first part of each pair in |changed_instances|.
175   for (size_t i = 0; i < old_records.size(); i++) {
176     const auto key = old_records[i].dns_type();
177     OSP_DCHECK(changed_records.find(key) == changed_records.end());
178     auto value = std::make_pair(std::move(old_records[i]), absl::nullopt);
179     changed_records.emplace(key, std::move(value));
180   }
181 
182   // Populate the second part of each pair in |changed_records|.
183   for (size_t i = 0; i < new_records.size(); i++) {
184     const auto key = new_records[i].dns_type();
185     auto find_it = changed_records.find(key);
186     if (find_it == changed_records.end()) {
187       std::pair<absl::optional<MdnsRecord>, absl::optional<MdnsRecord>> value(
188           absl::nullopt, std::move(new_records[i]));
189       changed_records.emplace(key, std::move(value));
190     } else {
191       find_it->second.second = std::move(new_records[i]);
192     }
193   }
194 
195   // Apply changes called out in |changed_records|.
196   Error total_result = Error::None();
197   for (const auto& pair : changed_records) {
198     OSP_DCHECK(pair.second.first != absl::nullopt ||
199                pair.second.second != absl::nullopt);
200     if (pair.second.first == absl::nullopt) {
201       TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.RegisterRecord");
202       auto error = mdns_publisher_->RegisterRecord(pair.second.second.value());
203       TRACE_SET_RESULT(error);
204       if (!error.ok()) {
205         total_result = error;
206       }
207     } else if (pair.second.second == absl::nullopt) {
208       TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UnregisterRecord");
209       auto error = mdns_publisher_->UnregisterRecord(pair.second.first.value());
210       TRACE_SET_RESULT(error);
211       if (!error.ok()) {
212         total_result = error;
213       }
214     } else if (pair.second.first.value() != pair.second.second.value()) {
215       TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UpdateRegisteredRecord");
216       auto error = mdns_publisher_->UpdateRegisteredRecord(
217           pair.second.first.value(), pair.second.second.value());
218       TRACE_SET_RESULT(error);
219       if (!error.ok()) {
220         total_result = error;
221       }
222     }
223   }
224 
225   // Replace the old instances with the new ones.
226   published_instances_.erase(published_instance_it);
227   published_instances_.emplace(instance, std::move(updated_endpoint));
228 
229   return total_result;
230 }
231 
DeregisterAll(const std::string & service)232 ErrorOr<int> PublisherImpl::DeregisterAll(const std::string& service) {
233   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
234 
235   OSP_DVLOG << "Deregistering all instances";
236 
237   int removed_count = 0;
238   Error error = Error::None();
239   for (auto it = published_instances_.begin();
240        it != published_instances_.end();) {
241     if (it->second.service_id() == service) {
242       for (const auto& mdns_record : GetDnsRecords(it->second)) {
243         TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UnregisterRecord");
244         auto publisher_error = mdns_publisher_->UnregisterRecord(mdns_record);
245         TRACE_SET_RESULT(error);
246         if (!publisher_error.ok()) {
247           error = publisher_error;
248         }
249       }
250       removed_count++;
251       it = published_instances_.erase(it);
252     } else {
253       it++;
254     }
255   }
256 
257   removed_count += EraseInstancesWithServiceId(&pending_instances_, service);
258 
259   if (!error.ok()) {
260     return error;
261   } else {
262     return removed_count;
263   }
264 }
265 
OnDomainFound(const DomainName & requested_name,const DomainName & confirmed_name)266 void PublisherImpl::OnDomainFound(const DomainName& requested_name,
267                                   const DomainName& confirmed_name) {
268   TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
269   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
270 
271   OSP_DVLOG << "Domain successfully claimed: '" << confirmed_name.ToString()
272             << "' based on requested name: '" << requested_name.ToString()
273             << "'";
274 
275   auto it = FindKey(&pending_instances_, InstanceKey(requested_name));
276 
277   if (it == pending_instances_.end()) {
278     // This will be hit if the instance was deregister'd before the probe phase
279     // was completed.
280     return;
281   }
282 
283   DnsSdInstance requested_instance = std::move(it->first);
284   DnsSdInstanceEndpoint endpoint =
285       CreateEndpoint(requested_instance, *network_config_);
286   Client* const client = it->second;
287   pending_instances_.erase(it);
288 
289   InstanceKey requested_key(requested_instance);
290 
291   if (requested_name != confirmed_name) {
292     OSP_DCHECK(HasValidDnsRecordAddress(confirmed_name));
293     endpoint =
294         UpdateDomain(confirmed_name, requested_instance, *network_config_);
295   }
296 
297   for (const auto& mdns_record : GetDnsRecords(endpoint)) {
298     TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.RegisterRecord");
299     Error result = mdns_publisher_->RegisterRecord(mdns_record);
300     if (!result.ok()) {
301       reporting_client_->OnRecoverableError(
302           Error(Error::Code::kRecordPublicationError, result.ToString()));
303     }
304   }
305 
306   auto pair = published_instances_.emplace(std::move(requested_instance),
307                                            std::move(endpoint));
308   client->OnEndpointClaimed(pair.first->first, pair.first->second);
309 }
310 
311 }  // namespace discovery
312 }  // namespace openscreen
313