1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef DISCOVERY_MDNS_MDNS_RESPONDER_H_
6 #define DISCOVERY_MDNS_MDNS_RESPONDER_H_
7 
#include <functional>
#include <map>
#include <memory>
#include <vector>

#include "discovery/mdns/mdns_records.h"
#include "platform/api/time.h"
#include "platform/base/macros.h"
#include "util/alarm.h"
16 
17 namespace openscreen {
18 
// Forward declarations of openscreen types referenced below.
// NOTE(review): IPEndpoint is used by value in this header (as
// TruncatedQuery::src_ and as the key of |truncated_queries_|), so its full
// definition must already be visible through another include — presumably
// mdns_records.h. Consider including its defining header directly.
class TaskRunner;
struct IPEndpoint;
21 
22 namespace discovery {
23 
// Forward declarations of discovery types referenced below, kept in
// alphabetical order.
// NOTE(review): MdnsQuerier and MdnsRecordChangedCallback are not referenced
// anywhere in this header; they may be removable.
struct Config;
class MdnsMessage;
class MdnsProbeManager;
class MdnsQuerier;
class MdnsRandom;
class MdnsReceiver;
class MdnsRecordChangedCallback;
class MdnsSender;
32 
// This class is responsible for responding to incoming mDNS queries received
// via the OnMessageReceived() method. When responding, the generated
// MdnsMessage will contain the requested record(s) in the answers section or,
// for a query with any DnsType other than ANY, an NSEC record indicating that
// the requested record was not found. When records are found, the additional
// records section may be populated with additional records, as specified in
// RFCs 6762 and 6763.
40 class MdnsResponder {
41  public:
42   // Class to handle querying for existing records.
43   class RecordHandler {
44    public:
45     virtual ~RecordHandler();
46 
47     // Returns whether this service has one or more records matching the
48     // provided name, type, and class.
49     virtual bool HasRecords(const DomainName& name,
50                             DnsType type,
51                             DnsClass clazz) = 0;
52 
53     // Returns all records owned by this service with name, type, and class
54     // matching the provided values.
55     virtual std::vector<MdnsRecord::ConstRef> GetRecords(const DomainName& name,
56                                                          DnsType type,
57                                                          DnsClass clazz) = 0;
58 
59     // Enumerates all PTR records owned by this service.
60     virtual std::vector<MdnsRecord::ConstRef> GetPtrRecords(DnsClass clazz) = 0;
61   };
62 
63   // |record_handler|, |sender|, |receiver|, |task_runner|, |random_delay|, and
64   // |config| are expected to persist for the duration of this instance's
65   // lifetime.
66   MdnsResponder(RecordHandler* record_handler,
67                 MdnsProbeManager* ownership_handler,
68                 MdnsSender* sender,
69                 MdnsReceiver* receiver,
70                 TaskRunner* task_runner,
71                 ClockNowFunctionPtr now_function,
72                 MdnsRandom* random_delay,
73                 const Config& config);
74   ~MdnsResponder();
75 
76   OSP_DISALLOW_COPY_AND_ASSIGN(MdnsResponder);
77 
78  private:
79   // Class which handles processing and responding to queries segmented into
80   // multiple messages.
81   class TruncatedQuery {
82    public:
83     // |responder| and |task_runner| are expected to persist for the duration of
84     // this instance's lifetime.
85     TruncatedQuery(MdnsResponder* responder,
86                    TaskRunner* task_runner,
87                    ClockNowFunctionPtr now_function,
88                    IPEndpoint src,
89                    const MdnsMessage& message,
90                    const Config& config);
91     TruncatedQuery(const TruncatedQuery& other) = delete;
92     TruncatedQuery(TruncatedQuery&& other) = delete;
93 
94     TruncatedQuery& operator=(const TruncatedQuery& other) = delete;
95     TruncatedQuery& operator=(TruncatedQuery&& other) = delete;
96 
97     // Sets the query associated with this instance. Must only be called if no
98     // query has already been set, here or through the ctor.
99     void SetQuery(const MdnsMessage& message);
100 
101     // Adds additional known answers.
102     void AddKnownAnswers(const std::vector<MdnsRecord>& records);
103 
104     // Responds to the stored queries.
105     void SendResponse();
106 
src()107     const IPEndpoint& src() const { return src_; }
questions()108     const std::vector<MdnsQuestion>& questions() const { return questions_; }
known_answers()109     const std::vector<MdnsRecord>& known_answers() const {
110       return known_answers_;
111     }
112 
113    private:
114     void RescheduleSend();
115 
116     // The number of messages received so far associated with this known answer
117     // query.
118     int messages_received_so_far = 0;
119 
120     const int max_allowed_messages_;
121     const int max_allowed_records_;
122     const IPEndpoint src_;
123     MdnsResponder* const responder_;
124 
125     std::vector<MdnsQuestion> questions_;
126     std::vector<MdnsRecord> known_answers_;
127     Alarm alarm_;
128   };
129 
130   // Called when a new MdnsMessage is received.
131   void OnMessageReceived(const MdnsMessage& message, const IPEndpoint& src);
132 
133   // Responds a truncated query for which all known answers have been received.
134   void RespondToTruncatedQuery(TruncatedQuery* query);
135 
136   // Processes a message associated with a multi-packet truncated query.
137   void ProcessMultiPacketTruncatedMessage(const MdnsMessage& message,
138                                           const IPEndpoint& src);
139 
140   // Processes queries provided.
141   void ProcessQueries(const IPEndpoint& src,
142                       const std::vector<MdnsQuestion>& questions,
143                       const std::vector<MdnsRecord>& known_answers);
144 
145   // Sends the response to the provided query.
146   void SendResponse(const MdnsQuestion& question,
147                     const std::vector<MdnsRecord>& known_answers,
148                     std::function<void(const MdnsMessage&)> send_response,
149                     bool is_exclusive_owner);
150 
151   // Set of all truncated queries received so far. Per RFC 6762 section 7.1,
152   // matching of a query with additional known answers should be done based on
153   // the source address.
154   // NOTE: unique_ptrs used because TruncatedQuery is not movable.
155   std::map<IPEndpoint, std::unique_ptr<TruncatedQuery>> truncated_queries_;
156 
157   RecordHandler* const record_handler_;
158   MdnsProbeManager* const ownership_handler_;
159   MdnsSender* const sender_;
160   MdnsReceiver* const receiver_;
161   TaskRunner* const task_runner_;
162   const ClockNowFunctionPtr now_function_;
163   MdnsRandom* const random_delay_;
164   const Config& config_;
165 
166   friend class MdnsResponderTest;
167 };
168 
169 }  // namespace discovery
170 }  // namespace openscreen
171 
172 #endif  // DISCOVERY_MDNS_MDNS_RESPONDER_H_
173