1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/dnssd/impl/querier_impl.h"
6 
7 #include <algorithm>
8 #include <string>
9 #include <utility>
10 #include <vector>
11 
12 #include "discovery/common/reporting_client.h"
13 #include "discovery/dnssd/impl/conversion_layer.h"
14 #include "discovery/dnssd/impl/network_interface_config.h"
15 #include "platform/api/task_runner.h"
16 #include "util/osp_logging.h"
17 
18 namespace openscreen {
19 namespace discovery {
20 namespace {
21 
// Domain used for every ServiceKey built in this file.
constexpr char kLocalDomain[] = "local";
23 
24 // Removes all error instances from the below records, and calls the log
25 // function on all errors present in |new_endpoints|. Input vectors are expected
26 // to be sorted in ascending order.
ProcessErrors(std::vector<ErrorOr<DnsSdInstanceEndpoint>> * old_endpoints,std::vector<ErrorOr<DnsSdInstanceEndpoint>> * new_endpoints,std::function<void (Error)> log)27 void ProcessErrors(std::vector<ErrorOr<DnsSdInstanceEndpoint>>* old_endpoints,
28                    std::vector<ErrorOr<DnsSdInstanceEndpoint>>* new_endpoints,
29                    std::function<void(Error)> log) {
30   OSP_DCHECK(old_endpoints);
31   OSP_DCHECK(new_endpoints);
32 
33   auto old_it = old_endpoints->begin();
34   auto new_it = new_endpoints->begin();
35 
36   // Iterate across both vectors and log new errors in the process.
37   // NOTE: In sorted order, all errors will appear before all non-errors.
38   while (old_it != old_endpoints->end() && new_it != new_endpoints->end()) {
39     ErrorOr<DnsSdInstanceEndpoint>& old_ep = *old_it;
40     ErrorOr<DnsSdInstanceEndpoint>& new_ep = *new_it;
41 
42     if (new_ep.is_value()) {
43       break;
44     }
45 
46     // If they are equal, the element is in both |old_endpoints| and
47     // |new_endpoints|, so skip it in both vectors.
48     if (old_ep == new_ep) {
49       old_it++;
50       new_it++;
51       continue;
52     }
53 
54     // There's an error in |old_endpoints| not in |new_endpoints|, so skip it.
55     if (old_ep < new_ep) {
56       old_it++;
57       continue;
58     }
59 
60     // There's an error in |new_endpoints| not in |old_endpoints|, so it's a new
61     // error from the applied changes. Log it.
62     log(std::move(new_ep.error()));
63     new_it++;
64   }
65 
66   // Skip all remaining errors in the old vector.
67   for (; old_it != old_endpoints->end() && old_it->is_error(); old_it++) {
68   }
69 
70   // Log all errors remaining in the new vector.
71   for (; new_it != new_endpoints->end() && new_it->is_error(); new_it++) {
72     log(std::move(new_it->error()));
73   }
74 
75   // Erase errors.
76   old_endpoints->erase(old_endpoints->begin(), old_it);
77   new_endpoints->erase(new_endpoints->begin(), new_it);
78 }
79 
80 // Returns a vector containing the value of each ErrorOr<> instance provided.
81 // All ErrorOr<> values are expected to be non-errors.
GetValues(std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints)82 std::vector<DnsSdInstanceEndpoint> GetValues(
83     std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints) {
84   std::vector<DnsSdInstanceEndpoint> results;
85   results.reserve(endpoints.size());
86   for (ErrorOr<DnsSdInstanceEndpoint>& endpoint : endpoints) {
87     results.push_back(std::move(endpoint.value()));
88   }
89   return results;
90 }
91 
IsEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint> & first,const absl::optional<DnsSdInstanceEndpoint> & second)92 bool IsEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first,
93                      const absl::optional<DnsSdInstanceEndpoint>& second) {
94   if (!first.has_value() || !second.has_value()) {
95     return !first.has_value() && !second.has_value();
96   }
97 
98   // In the remaining case, both |first| and |second| must be values.
99   const DnsSdInstanceEndpoint& a = first.value();
100   const DnsSdInstanceEndpoint& b = second.value();
101 
102   // All endpoints from this querier should have the same network interface
103   // because the querier is only associated with a single network interface.
104   OSP_DCHECK_EQ(a.network_interface(), b.network_interface());
105 
106   // Function returns true if first < second.
107   return a.instance_id() == b.instance_id() &&
108          a.service_id() == b.service_id() && a.domain_id() == b.domain_id();
109 }
110 
IsNotEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint> & first,const absl::optional<DnsSdInstanceEndpoint> & second)111 bool IsNotEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first,
112                         const absl::optional<DnsSdInstanceEndpoint>& second) {
113   return !IsEqualOrUpdate(first, second);
114 }
115 
116 // Calculates the created, updated, and deleted elements using the provided
117 // sets, appending these values to the provided vectors. Each of the input
118 // vectors is expected to contain only elements such that
119 // |element|.is_error() == false. Additionally, input vectors are expected to
120 // be sorted in ascending order.
121 //
122 // NOTE: A lot of operations are used to do this, but each is only O(n) so the
123 // resulting algorithm is still fast.
CalculateChangeSets(std::vector<DnsSdInstanceEndpoint> old_endpoints,std::vector<DnsSdInstanceEndpoint> new_endpoints,std::vector<DnsSdInstanceEndpoint> * created_out,std::vector<DnsSdInstanceEndpoint> * updated_out,std::vector<DnsSdInstanceEndpoint> * deleted_out)124 void CalculateChangeSets(std::vector<DnsSdInstanceEndpoint> old_endpoints,
125                          std::vector<DnsSdInstanceEndpoint> new_endpoints,
126                          std::vector<DnsSdInstanceEndpoint>* created_out,
127                          std::vector<DnsSdInstanceEndpoint>* updated_out,
128                          std::vector<DnsSdInstanceEndpoint>* deleted_out) {
129   OSP_DCHECK(created_out);
130   OSP_DCHECK(updated_out);
131   OSP_DCHECK(deleted_out);
132 
133   // Use set difference with default operators to find the elements present in
134   // one list but not the others.
135   //
136   // NOTE: Because absl::optional<...> types are used here and below, calls to
137   // the ctor and dtor for empty elements are no-ops.
138   const int total_count = old_endpoints.size() + new_endpoints.size();
139 
140   // This is the set of elements that aren't in the old endpoints, meaning the
141   // old endpoint either didn't exist or had different TXT / Address / etc..
142   std::vector<absl::optional<DnsSdInstanceEndpoint>> created_or_updated(
143       total_count);
144   auto new_end = std::set_difference(new_endpoints.begin(), new_endpoints.end(),
145                                      old_endpoints.begin(), old_endpoints.end(),
146                                      created_or_updated.begin());
147   created_or_updated.erase(new_end, created_or_updated.end());
148 
149   // This is the set of elements that are only in the old endpoints, similar to
150   // the above.
151   std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted_or_updated(
152       total_count);
153   new_end = std::set_difference(old_endpoints.begin(), old_endpoints.end(),
154                                 new_endpoints.begin(), new_endpoints.end(),
155                                 deleted_or_updated.begin());
156   deleted_or_updated.erase(new_end, deleted_or_updated.end());
157 
158   // Next, find the elements which were updated.
159   const size_t max_count =
160       std::max(created_or_updated.size(), deleted_or_updated.size());
161   std::vector<absl::optional<DnsSdInstanceEndpoint>> updated(max_count);
162   new_end = std::set_intersection(
163       created_or_updated.begin(), created_or_updated.end(),
164       deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(),
165       IsNotEqualOrUpdate);
166   updated.erase(new_end, updated.end());
167 
168   // Use the updated elements to find all created and deleted elements.
169   std::vector<absl::optional<DnsSdInstanceEndpoint>> created(
170       created_or_updated.size());
171   new_end = std::set_difference(
172       created_or_updated.begin(), created_or_updated.end(), updated.begin(),
173       updated.end(), created.begin(), IsNotEqualOrUpdate);
174   created.erase(new_end, created.end());
175 
176   std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted(
177       deleted_or_updated.size());
178   new_end = std::set_difference(
179       deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(),
180       updated.end(), deleted.begin(), IsNotEqualOrUpdate);
181   deleted.erase(new_end, deleted.end());
182 
183   // Return the calculated elements back to the caller in the output variables.
184   created_out->reserve(created.size());
185   for (absl::optional<DnsSdInstanceEndpoint>& endpoint : created) {
186     OSP_DCHECK(endpoint.has_value());
187     created_out->push_back(std::move(endpoint.value()));
188   }
189 
190   updated_out->reserve(updated.size());
191   for (absl::optional<DnsSdInstanceEndpoint>& endpoint : updated) {
192     OSP_DCHECK(endpoint.has_value());
193     updated_out->push_back(std::move(endpoint.value()));
194   }
195 
196   deleted_out->reserve(deleted.size());
197   for (absl::optional<DnsSdInstanceEndpoint>& endpoint : deleted) {
198     OSP_DCHECK(endpoint.has_value());
199     deleted_out->push_back(std::move(endpoint.value()));
200   }
201 }
202 
203 }  // namespace
204 
QuerierImpl(MdnsService * mdns_querier,TaskRunner * task_runner,ReportingClient * reporting_client,const NetworkInterfaceConfig * network_config)205 QuerierImpl::QuerierImpl(MdnsService* mdns_querier,
206                          TaskRunner* task_runner,
207                          ReportingClient* reporting_client,
208                          const NetworkInterfaceConfig* network_config)
209     : mdns_querier_(mdns_querier),
210       task_runner_(task_runner),
211       reporting_client_(reporting_client) {
212   OSP_DCHECK(mdns_querier_);
213   OSP_DCHECK(task_runner_);
214 
215   OSP_DCHECK(network_config);
216   graph_ = DnsDataGraph::Create(network_config->network_interface());
217 }
218 
219 QuerierImpl::~QuerierImpl() = default;
220 
StartQuery(const std::string & service,Callback * callback)221 void QuerierImpl::StartQuery(const std::string& service, Callback* callback) {
222   OSP_DCHECK(callback);
223   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
224 
225   OSP_DVLOG << "Starting DNS-SD query for service '" << service << "'";
226 
227   // Start tracking the new callback
228   const ServiceKey key(service, kLocalDomain);
229   auto it = callback_map_.emplace(key, std::vector<Callback*>{}).first;
230   it->second.push_back(callback);
231 
232   const DomainName domain = key.GetName();
233 
234   // If the associated service isn't tracked yet, start tracking it and start
235   // queries for the relevant PTR records.
236   if (!graph_->IsTracked(domain)) {
237     std::function<void(const DomainName&)> mdns_query(
238         [this, &domain](const DomainName& changed_domain) {
239           OSP_DVLOG << "Starting mDNS query for '" << domain.ToString() << "'";
240           mdns_querier_->StartQuery(changed_domain, DnsType::kANY,
241                                     DnsClass::kANY, this);
242         });
243     graph_->StartTracking(domain, std::move(mdns_query));
244     return;
245   }
246 
247   // Else, it's already being tracked so fire creation callbacks for any already
248   // found service instances.
249   const std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints =
250       graph_->CreateEndpoints(DnsDataGraph::DomainGroup::kPtr, domain);
251   for (const auto& endpoint : endpoints) {
252     if (endpoint.is_value()) {
253       callback->OnEndpointCreated(endpoint.value());
254     }
255   }
256 }
257 
StopQuery(const std::string & service,Callback * callback)258 void QuerierImpl::StopQuery(const std::string& service, Callback* callback) {
259   OSP_DCHECK(callback);
260   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
261 
262   OSP_DVLOG << "Stopping DNS-SD query for service '" << service << "'";
263 
264   ServiceKey key(service, kLocalDomain);
265   const auto callbacks_it = callback_map_.find(key);
266   if (callbacks_it == callback_map_.end()) {
267     return;
268   }
269 
270   std::vector<Callback*>& callbacks = callbacks_it->second;
271   const auto it = std::find(callbacks.begin(), callbacks.end(), callback);
272   if (it == callbacks.end()) {
273     return;
274   }
275 
276   callbacks.erase(it);
277   if (callbacks.empty()) {
278     callback_map_.erase(callbacks_it);
279 
280     ServiceKey key(service, kLocalDomain);
281     DomainName domain = key.GetName();
282 
283     std::function<void(const DomainName&)> stop_mdns_query(
284         [this](const DomainName& changed_domain) {
285           OSP_DVLOG << "Stopping mDNS query for '" << changed_domain.ToString()
286                     << "'";
287           mdns_querier_->StopQuery(changed_domain, DnsType::kANY,
288                                    DnsClass::kANY, this);
289         });
290     graph_->StopTracking(domain, std::move(stop_mdns_query));
291   }
292 }
293 
IsQueryRunning(const std::string & service) const294 bool QuerierImpl::IsQueryRunning(const std::string& service) const {
295   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
296   const ServiceKey key(service, kLocalDomain);
297   return graph_->IsTracked(key.GetName());
298 }
299 
ReinitializeQueries(const std::string & service)300 void QuerierImpl::ReinitializeQueries(const std::string& service) {
301   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
302 
303   OSP_DVLOG << "Re-initializing query for service '" << service << "'";
304 
305   const ServiceKey key(service, kLocalDomain);
306   const DomainName domain = key.GetName();
307 
308   std::function<void(const DomainName&)> start_callback(
309       [this](const DomainName& domain) {
310         mdns_querier_->StartQuery(domain, DnsType::kANY, DnsClass::kANY, this);
311       });
312   std::function<void(const DomainName&)> stop_callback(
313       [this](const DomainName& domain) {
314         mdns_querier_->StopQuery(domain, DnsType::kANY, DnsClass::kANY, this);
315       });
316   graph_->StopTracking(domain, std::move(stop_callback));
317 
318   // Restart top-level queries.
319   mdns_querier_->ReinitializeQueries(GetPtrQueryInfo(key).name);
320 
321   graph_->StartTracking(domain, std::move(start_callback));
322 }
323 
OnRecordChanged(const MdnsRecord & record,RecordChangedEvent event)324 std::vector<PendingQueryChange> QuerierImpl::OnRecordChanged(
325     const MdnsRecord& record,
326     RecordChangedEvent event) {
327   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
328 
329   OSP_DVLOG << "Record " << record.ToString()
330             << " has received change of type '" << event << "'";
331 
332   std::function<void(Error)> log = [this](Error error) mutable {
333     reporting_client_->OnRecoverableError(
334         Error(Error::Code::kProcessReceivedRecordFailure));
335   };
336 
337   // Get the details to use for calling CreateEndpoints(). Special case PTR
338   // records to optimize performance.
339   const DomainName& create_endpoints_domain =
340       record.dns_type() != DnsType::kPTR
341           ? record.name()
342           : absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
343   const DnsDataGraph::DomainGroup create_endpoints_group =
344       record.dns_type() != DnsType::kPTR
345           ? DnsDataGraph::GetDomainGroup(record)
346           : DnsDataGraph::DomainGroup::kSrvAndTxt;
347 
348   // Get the current set of DnsSdInstanceEndpoints prior to this change. Special
349   // case PTR records to avoid iterating over unrelated child domains.
350   std::vector<ErrorOr<DnsSdInstanceEndpoint>> old_endpoints_or_errors =
351       graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain);
352 
353   // Apply the changes, creating a list of all pending changes that should be
354   // applied afterwards.
355   ErrorOr<std::vector<PendingQueryChange>> pending_changes_or_error =
356       ApplyRecordChanges(record, event);
357   if (pending_changes_or_error.is_error()) {
358     OSP_DVLOG << "Failed to apply changes for " << record.dns_type()
359               << " record change of type " << event << " with error "
360               << pending_changes_or_error.error();
361     log(std::move(pending_changes_or_error.error()));
362     return {};
363   }
364   std::vector<PendingQueryChange>& pending_changes =
365       pending_changes_or_error.value();
366 
367   // Get the new set of DnsSdInstanceEndpoints following this change.
368   std::vector<ErrorOr<DnsSdInstanceEndpoint>> new_endpoints_or_errors =
369       graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain);
370 
371   // Return early if the resulting sets are equal. This will frequently be the
372   // case, especially when both sets are empty.
373   std::sort(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end());
374   std::sort(new_endpoints_or_errors.begin(), new_endpoints_or_errors.end());
375   if (old_endpoints_or_errors.size() == new_endpoints_or_errors.size() &&
376       std::equal(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end(),
377                  new_endpoints_or_errors.begin())) {
378     return pending_changes;
379   }
380 
381   // Log all errors and erase them.
382   ProcessErrors(&old_endpoints_or_errors, &new_endpoints_or_errors,
383                 std::move(log));
384   const size_t old_endpoints_or_errors_count = old_endpoints_or_errors.size();
385   const size_t new_endpoints_or_errors_count = new_endpoints_or_errors.size();
386   std::vector<DnsSdInstanceEndpoint> old_endpoints =
387       GetValues(std::move(old_endpoints_or_errors));
388   std::vector<DnsSdInstanceEndpoint> new_endpoints =
389       GetValues(std::move(new_endpoints_or_errors));
390   OSP_DCHECK_EQ(old_endpoints.size(), old_endpoints_or_errors_count);
391   OSP_DCHECK_EQ(new_endpoints.size(), new_endpoints_or_errors_count);
392 
393   // Calculate the changes and call callbacks.
394   //
395   // NOTE: As the input sets are expected to be small, the generated sets will
396   // also be small.
397   std::vector<DnsSdInstanceEndpoint> created;
398   std::vector<DnsSdInstanceEndpoint> updated;
399   std::vector<DnsSdInstanceEndpoint> deleted;
400   CalculateChangeSets(std::move(old_endpoints), std::move(new_endpoints),
401                       &created, &updated, &deleted);
402 
403   InvokeChangeCallbacks(std::move(created), std::move(updated),
404                         std::move(deleted));
405   return pending_changes;
406 }
407 
InvokeChangeCallbacks(std::vector<DnsSdInstanceEndpoint> created,std::vector<DnsSdInstanceEndpoint> updated,std::vector<DnsSdInstanceEndpoint> deleted)408 void QuerierImpl::InvokeChangeCallbacks(
409     std::vector<DnsSdInstanceEndpoint> created,
410     std::vector<DnsSdInstanceEndpoint> updated,
411     std::vector<DnsSdInstanceEndpoint> deleted) {
412   // Find an endpoint and use it to create the key, or return if there is none.
413   DnsSdInstanceEndpoint* some_endpoint;
414   if (!created.empty()) {
415     some_endpoint = &created.front();
416   } else if (!updated.empty()) {
417     some_endpoint = &updated.front();
418   } else if (!deleted.empty()) {
419     some_endpoint = &deleted.front();
420   } else {
421     return;
422   }
423   ServiceKey key(some_endpoint->service_id(), some_endpoint->domain_id());
424 
425   // Find all callbacks.
426   auto it = callback_map_.find(key);
427   if (it == callback_map_.end()) {
428     return;
429   }
430 
431   // Call relevant callbacks.
432   std::vector<Callback*>& callbacks = it->second;
433   for (Callback* callback : callbacks) {
434     for (const DnsSdInstanceEndpoint& endpoint : created) {
435       callback->OnEndpointCreated(endpoint);
436     }
437     for (const DnsSdInstanceEndpoint& endpoint : updated) {
438       callback->OnEndpointUpdated(endpoint);
439     }
440     for (const DnsSdInstanceEndpoint& endpoint : deleted) {
441       callback->OnEndpointDeleted(endpoint);
442     }
443   }
444 }
445 
ApplyRecordChanges(const MdnsRecord & record,RecordChangedEvent event)446 ErrorOr<std::vector<PendingQueryChange>> QuerierImpl::ApplyRecordChanges(
447     const MdnsRecord& record,
448     RecordChangedEvent event) {
449   std::vector<PendingQueryChange> pending_changes;
450   std::function<void(DomainName)> creation_callback(
451       [this, &pending_changes](DomainName domain) mutable {
452         pending_changes.push_back({std::move(domain), DnsType::kANY,
453                                    DnsClass::kANY, this,
454                                    PendingQueryChange::kStartQuery});
455       });
456   std::function<void(DomainName)> deletion_callback(
457       [this, &pending_changes](DomainName domain) mutable {
458         pending_changes.push_back({std::move(domain), DnsType::kANY,
459                                    DnsClass::kANY, this,
460                                    PendingQueryChange::kStopQuery});
461       });
462   Error result =
463       graph_->ApplyDataRecordChange(record, event, std::move(creation_callback),
464                                     std::move(deletion_callback));
465   if (!result.ok()) {
466     return result;
467   }
468 
469   return pending_changes;
470 }
471 
472 }  // namespace discovery
473 }  // namespace openscreen
474