1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef DISCOVERY_PUBLIC_DNS_SD_SERVICE_WATCHER_H_
6 #define DISCOVERY_PUBLIC_DNS_SD_SERVICE_WATCHER_H_
7 
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
14 
15 #include "discovery/dnssd/public/dns_sd_instance.h"
16 #include "discovery/dnssd/public/dns_sd_querier.h"
17 #include "discovery/dnssd/public/dns_sd_service.h"
18 #include "platform/base/error.h"
19 #include "util/hashing.h"
20 #include "util/osp_logging.h"
21 
22 namespace openscreen {
23 namespace discovery {
24 
25 // This class represents a top-level discovery API which sits on top of DNS-SD.
26 // T is the service-specific type which stores information regarding a specific
27 // service instance.
28 // TODO(rwkeane): Include reporting client as ctor parameter once parallel CLs
29 // are in.
30 // NOTE: This class is not thread-safe and calls will be made to DnsSdService in
31 // the same sequence and on the same threads from which these methods are
32 // called. This is to avoid forcing design decisions on embedders who write
33 // their own implementations of the DNS-SD layer.
34 template <typename T>
35 class DnsSdServiceWatcher : public DnsSdQuerier::Callback {
36  public:
37   using ConstRefT = std::reference_wrapper<const T>;
38 
39   enum class ServicesUpdatedState {
40       EndpointCreated,
41       EndpointUpdated,
42       EndpointDeleted,
43   };
44 
45   // The method which will be called when any new service instance is
46   // discovered, a service instance changes its data (such as TXT or A data), or
47   // a previously discovered service instance ceases to be available. The vector
48   // is the set of all currently active service instances which have been
49   // discovered so far.
50   // NOTE: This callback may not modify the DnsSdServiceWatcher instance from
51   // which it is called.
52   using ServicesUpdatedCallback =
53       std::function<void(std::vector<ConstRefT> services,
54                          ConstRefT service,
55                          ServicesUpdatedState state)>;
56 
57   // This function type is responsible for converting from a DNS service
58   // instance (received from another mDNS endpoint) to a T type to be returned
59   // to the caller.
60   using ServiceConverter =
61       std::function<ErrorOr<T>(const DnsSdInstanceEndpoint&)>;
62 
DnsSdServiceWatcher(DnsSdService * service,std::string service_name,ServiceConverter conversion,ServicesUpdatedCallback callback)63   DnsSdServiceWatcher(DnsSdService* service,
64                       std::string service_name,
65                       ServiceConverter conversion,
66                       ServicesUpdatedCallback callback)
67       : conversion_(conversion),
68         service_name_(std::move(service_name)),
69         callback_(std::move(callback)),
70         querier_(service ? service->GetQuerier() : nullptr) {
71     OSP_DCHECK(querier_);
72   }
73 
74   ~DnsSdServiceWatcher() = default;
75 
76   // Starts service discovery.
StartDiscovery()77   void StartDiscovery() {
78     OSP_DCHECK(!is_running_);
79     is_running_ = true;
80 
81     querier_->StartQuery(service_name_, this);
82   }
83 
84   // Stops service discovery.
StopDiscovery()85   void StopDiscovery() {
86     OSP_DCHECK(is_running_);
87     is_running_ = false;
88 
89     querier_->StopQuery(service_name_, this);
90   }
91 
92   // Returns whether or not discovery is currently ongoing.
is_running()93   bool is_running() const { return is_running_; }
94 
95   // Re-initializes the process of service discovery, even if the underlying
96   // implementation would not normally do so at this time. All previously
97   // received service data is discarded.
98   // NOTE: This call will return an error if StartDiscovery has not yet been
99   // called.
ForceRefresh()100   Error ForceRefresh() {
101     if (!is_running_) {
102       return Error::Code::kOperationInvalid;
103     }
104 
105     querier_->ReinitializeQueries(service_name_);
106     records_.clear();
107     return Error::None();
108   }
109 
110   // Re-initializes the process of service discovery, even if the underlying
111   // implementation would not normally do so at this time. All previously
112   // received service data is persisted.
113   // NOTE: This call will return an error if StartDiscovery has not yet been
114   // called.
DiscoverNow()115   Error DiscoverNow() {
116     if (!is_running_) {
117       return Error::Code::kOperationInvalid;
118     }
119 
120     querier_->ReinitializeQueries(service_name_);
121     return Error::None();
122   }
123 
124   // Returns the set of services which have been received so far.
GetServices()125   std::vector<ConstRefT> GetServices() const {
126     std::vector<ConstRefT> refs;
127     for (const auto& pair : records_) {
128       refs.push_back(*pair.second.get());
129     }
130 
131     OSP_DVLOG << "Currently " << records_.size()
132               << " known service instances: [" << GetInstanceNames() << "]";
133 
134     return refs;
135   }
136 
137  private:
138   friend class TestServiceWatcher;
139 
140   using EndpointKey = std::pair<std::string, NetworkInterfaceIndex>;
141 
142   // DnsSdQuerier::Callback overrides.
OnEndpointCreated(const DnsSdInstanceEndpoint & new_endpoint)143   void OnEndpointCreated(const DnsSdInstanceEndpoint& new_endpoint) override {
144     // NOTE: existence is not checked because records may be overwritten after
145     // querier_->ReinitializeQueries() is called.
146     ErrorOr<T> record = conversion_(new_endpoint);
147     if (record.is_error()) {
148       OSP_LOG_INFO << "Conversion of received record failed with error: "
149                    << record.error();
150       return;
151     }
152     records_[GetKey(new_endpoint)] =
153         std::make_unique<T>(std::move(record.value()));
154     callback_(GetServices(), *records_[GetKey(new_endpoint)].get(), ServicesUpdatedState::EndpointCreated);
155   }
156 
OnEndpointUpdated(const DnsSdInstanceEndpoint & modified_endpoint)157   void OnEndpointUpdated(
158       const DnsSdInstanceEndpoint& modified_endpoint) override {
159     auto it = records_.find(GetKey(modified_endpoint));
160     if (it != records_.end()) {
161       ErrorOr<T> record = conversion_(modified_endpoint);
162       if (record.is_error()) {
163         OSP_LOG_INFO << "Conversion of received record failed with error: "
164                      << record.error();
165         return;
166       }
167       auto ptr = std::make_unique<T>(std::move(record.value()));
168       it->second.swap(ptr);
169 
170       callback_(GetServices(), *it->second.get(), ServicesUpdatedState::EndpointUpdated);
171     } else {
172       OSP_LOG_INFO
173           << "Received modified record for non-existent DNS-SD Instance "
174           << modified_endpoint.instance_id();
175     }
176   }
177 
OnEndpointDeleted(const DnsSdInstanceEndpoint & old_endpoint)178   void OnEndpointDeleted(const DnsSdInstanceEndpoint& old_endpoint) override {
179     auto it = records_.find(GetKey(old_endpoint));
180     if (it != records_.end()) {
181       auto ptr = std::move(it->second);
182       records_.erase(it);
183       callback_(GetServices(), *ptr.get(), ServicesUpdatedState::EndpointDeleted);
184     } else {
185       OSP_LOG_INFO
186           << "Received deletion of record for non-existent DNS-SD Instance "
187           << old_endpoint.instance_id();
188     }
189   }
190 
GetKey(const DnsSdInstanceEndpoint & endpoint)191   EndpointKey GetKey(const DnsSdInstanceEndpoint& endpoint) const {
192     return std::make_pair(endpoint.instance_id(), endpoint.network_interface());
193   }
194 
GetInstanceNames()195   std::string GetInstanceNames() const {
196     auto it = records_.begin();
197     if (it == records_.end()) {
198       return "";
199     }
200 
201     std::stringstream ss;
202     ss << it->first.first << "/" << it->first.second;
203     while (++it != records_.end()) {
204       ss << ", " << it->first.first << "/" << it->first.second;
205     }
206     return ss.str();
207   }
208 
209   // Set of all instance ids found so far, mapped to the T type that it
210   // represents. unique_ptr<T> entities are used so that the const refs returned
211   // from GetServices() and the ServicesUpdatedCallback can persist even once
212   // this map is resized.
213   // NOTE: Unordered map is used because this set is in  many cases expected to
214   // be large.
215   std::unordered_map<EndpointKey, std::unique_ptr<T>, PairHash> records_;
216 
217   // Represents whether discovery is currently running or not.
218   bool is_running_ = false;
219 
220   // Converts from the DNS-SD representation of a service to the outside
221   // representation.
222   ServiceConverter conversion_;
223 
224   std::string service_name_;
225   ServicesUpdatedCallback callback_;
226   DnsSdQuerier* const querier_;
227 };
228 
229 }  // namespace discovery
230 }  // namespace openscreen
231 
232 #endif  // DISCOVERY_PUBLIC_DNS_SD_SERVICE_WATCHER_H_
233