1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef DISCOVERY_MDNS_MDNS_QUERIER_H_
6 #define DISCOVERY_MDNS_MDNS_QUERIER_H_
7 
#include <cstddef>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <vector>

#include "discovery/common/config.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_records.h"
#include "discovery/mdns/mdns_trackers.h"
#include "platform/api/task_runner.h"
19 
20 namespace openscreen {
21 namespace discovery {
22 
23 class MdnsRandom;
24 class MdnsSender;
25 class MdnsQuestionTracker;
26 class MdnsRecordTracker;
27 class ReportingClient;
28 
29 class MdnsQuerier : public MdnsReceiver::ResponseClient {
30  public:
31   MdnsQuerier(MdnsSender* sender,
32               MdnsReceiver* receiver,
33               TaskRunner* task_runner,
34               ClockNowFunctionPtr now_function,
35               MdnsRandom* random_delay,
36               ReportingClient* reporting_client,
37               Config config);
38   MdnsQuerier(const MdnsQuerier& other) = delete;
39   MdnsQuerier(MdnsQuerier&& other) noexcept = delete;
40   MdnsQuerier& operator=(const MdnsQuerier& other) = delete;
41   MdnsQuerier& operator=(MdnsQuerier&& other) noexcept = delete;
42   ~MdnsQuerier() override;
43 
44   // Starts an mDNS query with the given name, DNS type, and DNS class.  Updated
45   // records are passed to |callback|.  The caller must ensure |callback|
46   // remains alive while it is registered with a query.
47   // NOTE: This call is only valid for |dns_type| values:
48   // - DnsType::kA
49   // - DnsType::kPTR
50   // - DnsType::kTXT
51   // - DnsType::kAAAA
52   // - DnsType::kSRV
53   // - DnsType::kANY
54   void StartQuery(const DomainName& name,
55                   DnsType dns_type,
56                   DnsClass dns_class,
57                   MdnsRecordChangedCallback* callback);
58 
59   // Stops an mDNS query with the given name, DNS type, and DNS class.
60   // |callback| must be the same callback pointer that was previously passed to
61   // StartQuery.
62   void StopQuery(const DomainName& name,
63                  DnsType dns_type,
64                  DnsClass dns_class,
65                  MdnsRecordChangedCallback* callback);
66 
67   // Re-initializes the process of service discovery for the provided domain
68   // name. All ongoing queries for this domain are restarted and any previously
69   // received query results are discarded.
70   void ReinitializeQueries(const DomainName& name);
71 
72  private:
73   struct CallbackInfo {
74     MdnsRecordChangedCallback* const callback;
75     const DnsType dns_type;
76     const DnsClass dns_class;
77   };
78 
79   // Represents a Least Recently Used cache of MdnsRecordTrackers.
80   class RecordTrackerLruCache {
81    public:
82     using RecordTrackerConstRef =
83         std::reference_wrapper<const MdnsRecordTracker>;
84     using TrackerApplicableCheck =
85         std::function<bool(const MdnsRecordTracker&)>;
86     using TrackerChangeCallback = std::function<void(const MdnsRecordTracker&)>;
87 
88     RecordTrackerLruCache(MdnsQuerier* querier,
89                           MdnsSender* sender,
90                           MdnsRandom* random_delay,
91                           TaskRunner* task_runner,
92                           ClockNowFunctionPtr now_function,
93                           ReportingClient* reporting_client,
94                           const Config& config);
95 
96     // Returns all trackers with the associated |name| such that its type
97     // represents a type corresponding to |dns_type| and class corresponding to
98     // |dns_class|.
99     std::vector<RecordTrackerConstRef> Find(const DomainName& name);
100     std::vector<RecordTrackerConstRef> Find(const DomainName& name,
101                                             DnsType dns_type,
102                                             DnsClass dns_class);
103 
104     // Calls ExpireSoon on all record trackers in the provided domain which
105     // match the provided applicability check. Returns the number of trackers
106     // marked for expiry.
107     int ExpireSoon(const DomainName& name, TrackerApplicableCheck check);
108 
109     // Erases all record trackers in the provided domain which match the
110     // provided applicability check. Returns the number of trackers erased.
111     int Erase(const DomainName& name, TrackerApplicableCheck check);
112 
113     // Updates all record trackers in the domain |record.name()| which match the
114     // provided applicability check using the provided record. Returns the
115     // number of records successfully updated.
116     int Update(const MdnsRecord& record, TrackerApplicableCheck check);
117     int Update(const MdnsRecord& record,
118                TrackerApplicableCheck check,
119                TrackerChangeCallback on_rdata_update);
120 
121     // Creates a record tracker of the given type associated with the provided
122     // record.
123     const MdnsRecordTracker& StartTracking(MdnsRecord record, DnsType type);
124 
size()125     size_t size() { return records_.size(); }
126 
127    private:
128     using LruList = std::list<MdnsRecordTracker>;
129     using RecordMap = std::multimap<DomainName, LruList::iterator>;
130 
131     void MoveToBeginning(RecordMap::iterator iterator);
132     void MoveToEnd(RecordMap::iterator iterator);
133 
134     MdnsQuerier* const querier_;
135     MdnsSender* const sender_;
136     MdnsRandom* const random_delay_;
137     TaskRunner* const task_runner_;
138     ClockNowFunctionPtr now_function_;
139     ReportingClient* reporting_client_;
140     const Config& config_;
141 
142     // List of RecordTracker instances used by this instance where the least
143     // recently updated element (or next to be deleted element) appears at the
144     // end of the list.
145     LruList lru_order_;
146 
147     // A collection of active known record trackers, each is identified by
148     // domain name, DNS record type, and DNS record class. Multimap key is
149     // domain name only to allow easy support for wildcard processing for DNS
150     // record type and class and allow storing shared records that differ only
151     // in RDATA.
152     //
153     // MdnsRecordTracker instances are stored as unique_ptr so they are not
154     // moved around in memory when the collection is modified. This allows
155     // passing a pointer to MdnsQuestionTracker to a task running on the
156     // TaskRunner.
157     RecordMap records_;
158   };
159 
160   friend class MdnsQuerierTest;
161 
162   // MdnsReceiver::ResponseClient overrides.
163   void OnMessageReceived(const MdnsMessage& message) override;
164 
165   // Expires the record tracker provided. This callback is passed to owned
166   // MdnsRecordTracker instances in |records_|.
167   void OnRecordExpired(const MdnsRecordTracker* tracker,
168                        const MdnsRecord& record);
169 
170   // Determines whether a record received by this querier should be processed
171   // or dropped.
172   bool ShouldAnswerRecordBeProcessed(const MdnsRecord& answer);
173 
174   // Processes any record update, calling into the below methods as needed.
175   // NOTE: All records of type OPT are dropped, as they should not be cached per
176   // RFC6891.
177   void ProcessRecord(const MdnsRecord& records);
178 
179   // Processes a shared record update as a record of type |type|.
180   void ProcessSharedRecord(const MdnsRecord& record, DnsType type);
181 
182   // Processes a unique record update as a record of type |type|.
183   void ProcessUniqueRecord(const MdnsRecord& record, DnsType type);
184 
185   // Called when exactly one tracker is associated with a provided key.
186   // Determines the type of update being executed by this update call, then
187   // fires the appropriate callback.
188   void ProcessSinglyTrackedUniqueRecord(const MdnsRecord& record,
189                                         const MdnsRecordTracker& tracker);
190 
191   // Called when multiple records are associated with the same key. Expire all
192   // record with non-matching RDATA. Update the record with the matching RDATA
193   // if it exists, otherwise insert a new record.
194   void ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
195                                        DnsType dns_type);
196 
197   // Calls all callbacks associated with the provided record.
198   void ProcessCallbacks(const MdnsRecord& record, RecordChangedEvent event);
199 
200   // Begins tracking the provided question.
201   void AddQuestion(const MdnsQuestion& question);
202 
203   // Begins tracking the provided record.
204   void AddRecord(const MdnsRecord& record, DnsType type);
205 
206   // Applies the supplied pending changes.
207   void ApplyPendingChanges(std::vector<PendingQueryChange> pending_changes);
208 
209   MdnsSender* const sender_;
210   MdnsReceiver* const receiver_;
211   TaskRunner* const task_runner_;
212   const ClockNowFunctionPtr now_function_;
213   MdnsRandom* const random_delay_;
214   ReportingClient* reporting_client_;
215   Config config_;
216 
217   // A collection of active question trackers, each is uniquely identified by
218   // domain name, DNS record type, and DNS record class. Multimap key is domain
219   // name only to allow easy support for wildcard processing for DNS record type
220   // and class. MdnsQuestionTracker instances are stored as unique_ptr so they
221   // are not moved around in memory when the collection is modified. This allows
222   // passing a pointer to MdnsQuestionTracker to a task running on the
223   // TaskRunner.
224   std::multimap<DomainName, std::unique_ptr<MdnsQuestionTracker>> questions_;
225 
226   // Set of records tracked by this querier.
227   RecordTrackerLruCache records_;
228 
229   // A collection of callbacks passed to StartQuery method. Each is identified
230   // by domain name, DNS record type, and DNS record class, but there can be
231   // more than one callback for a particular query. Multimap key is domain name
232   // only to allow easy matching of records against callbacks that have wildcard
233   // DNS class and/or DNS type.
234   std::multimap<DomainName, CallbackInfo> callbacks_;
235 };
236 
237 }  // namespace discovery
238 }  // namespace openscreen
239 
240 #endif  // DISCOVERY_MDNS_MDNS_QUERIER_H_
241