1 //
2 // Copyright (C) 2012 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "shill/dns_client.h"
18 
19 #include <arpa/inet.h>
20 #include <netdb.h>
21 #include <netinet/in.h>
22 #include <sys/socket.h>
23 
24 #include <map>
25 #include <memory>
26 #include <set>
27 #include <string>
28 #include <vector>
29 
30 #include <base/bind.h>
31 #include <base/bind_helpers.h>
32 #include <base/stl_util.h>
33 #include <base/strings/string_number_conversions.h>
34 
35 #include "shill/logging.h"
36 #include "shill/net/shill_time.h"
37 #include "shill/shill_ares.h"
38 
39 using base::Bind;
40 using base::Unretained;
41 using std::map;
42 using std::set;
43 using std::string;
44 using std::vector;
45 
46 namespace shill {
47 
48 namespace Logging {
49 static auto kModuleLogScope = ScopeLogger::kDNS;
ObjectID(DNSClient * d)50 static string ObjectID(DNSClient* d) { return d->interface_name(); }
51 }
52 
53 const char DNSClient::kErrorNoData[] = "The query response contains no answers";
54 const char DNSClient::kErrorFormErr[] = "The server says the query is bad";
55 const char DNSClient::kErrorServerFail[] = "The server says it had a failure";
56 const char DNSClient::kErrorNotFound[] = "The queried-for domain was not found";
57 const char DNSClient::kErrorNotImp[] = "The server doesn't implement operation";
58 const char DNSClient::kErrorRefused[] = "The server replied, refused the query";
59 const char DNSClient::kErrorBadQuery[] = "Locally we could not format a query";
60 const char DNSClient::kErrorNetRefused[] = "The network connection was refused";
61 const char DNSClient::kErrorTimedOut[] = "The network connection was timed out";
62 const char DNSClient::kErrorUnknown[] = "DNS Resolver unknown internal error";
63 
64 const int DNSClient::kDefaultDNSPort = 53;
65 
66 // Private to the implementation of resolver so callers don't include ares.h
67 struct DNSClientState {
DNSClientStateshill::DNSClientState68   DNSClientState() : channel(nullptr), start_time{} {}
69 
70   ares_channel channel;
71   map<ares_socket_t, std::shared_ptr<IOHandler>> read_handlers;
72   map<ares_socket_t, std::shared_ptr<IOHandler>> write_handlers;
73   struct timeval start_time;
74 };
75 
DNSClient(IPAddress::Family family,const string & interface_name,const vector<string> & dns_servers,int timeout_ms,EventDispatcher * dispatcher,const ClientCallback & callback)76 DNSClient::DNSClient(IPAddress::Family family,
77                      const string& interface_name,
78                      const vector<string>& dns_servers,
79                      int timeout_ms,
80                      EventDispatcher* dispatcher,
81                      const ClientCallback& callback)
82     : address_(IPAddress(family)),
83       interface_name_(interface_name),
84       dns_servers_(dns_servers),
85       dispatcher_(dispatcher),
86       callback_(callback),
87       timeout_ms_(timeout_ms),
88       running_(false),
89       weak_ptr_factory_(this),
90       ares_(Ares::GetInstance()),
91       time_(Time::GetInstance()) {}
92 
~DNSClient()93 DNSClient::~DNSClient() {
94   Stop();
95 }
96 
Start(const string & hostname,Error * error)97 bool DNSClient::Start(const string& hostname, Error* error) {
98   if (running_) {
99     Error::PopulateAndLog(FROM_HERE, error, Error::kInProgress,
100                           "Only one DNS request is allowed at a time");
101     return false;
102   }
103 
104   if (!resolver_state_.get()) {
105     struct ares_options options;
106     memset(&options, 0, sizeof(options));
107     options.timeout = timeout_ms_;
108 
109     if (dns_servers_.empty()) {
110       Error::PopulateAndLog(FROM_HERE, error, Error::kInvalidArguments,
111                             "No valid DNS server addresses");
112       return false;
113     }
114 
115     resolver_state_.reset(new DNSClientState);
116     int status = ares_->InitOptions(&resolver_state_->channel,
117                                    &options,
118                                    ARES_OPT_TIMEOUTMS);
119     if (status != ARES_SUCCESS) {
120       Error::PopulateAndLog(FROM_HERE, error, Error::kOperationFailed,
121                             "ARES initialization returns error code: " +
122                             base::IntToString(status));
123       resolver_state_.reset();
124       return false;
125     }
126 
127     // Format DNS server addresses string as "host:port[,host:port...]" to be
128     // used in call to ares_set_servers_csv for setting DNS server addresses.
129     // There is a bug in ares library when parsing IPv6 addresses, where it
130     // always assumes the port number are specified when address contains ":".
131     // So when IPv6 address are given without port number as "xx:xx:xx::yy",the
132     // parser would parse the address as "xx:xx:xx:" and port number as "yy".
133     // To work around this bug, port number are added to each address.
134     //
135     // Alternatively, we can use ares_set_servers instead, where we would
136     // explicitly construct a link list of ares_addr_node.
137     string server_addresses;
138     bool first = true;
139     for (const auto& ip : dns_servers_) {
140       if (!first) {
141         server_addresses += ",";
142       } else {
143         first = false;
144       }
145       server_addresses += (ip + ":" + base::IntToString(kDefaultDNSPort));
146     }
147     status = ares_->SetServersCsv(resolver_state_->channel,
148                                   server_addresses.c_str());
149     if (status != ARES_SUCCESS) {
150       Error::PopulateAndLog(FROM_HERE, error, Error::kOperationFailed,
151                             "ARES set DNS servers error code: " +
152                             base::IntToString(status));
153       resolver_state_.reset();
154       return false;
155     }
156 
157     ares_->SetLocalDev(resolver_state_->channel, interface_name_.c_str());
158   }
159 
160   running_ = true;
161   time_->GetTimeMonotonic(&resolver_state_->start_time);
162   ares_->GetHostByName(resolver_state_->channel, hostname.c_str(),
163                        address_.family(), ReceiveDNSReplyCB, this);
164 
165   if (!RefreshHandles()) {
166     LOG(ERROR) << "Impossibly short timeout.";
167     error->CopyFrom(error_);
168     Stop();
169     return false;
170   }
171 
172   return true;
173 }
174 
Stop()175 void DNSClient::Stop() {
176   SLOG(this, 3) << "In " << __func__;
177   if (!resolver_state_.get()) {
178     return;
179   }
180 
181   running_ = false;
182   weak_ptr_factory_.InvalidateWeakPtrs();
183   error_.Reset();
184   address_.SetAddressToDefault();
185   ares_->Destroy(resolver_state_->channel);
186   resolver_state_.reset();
187 }
188 
IsActive() const189 bool DNSClient::IsActive() const {
190   return running_;
191 }
192 
193 // We delay our call to completion so that we exit all IOHandlers, and
194 // can clean up all of our local state before calling the callback, or
195 // during the process of the execution of the callee (which is free to
196 // call our destructor safely).
HandleCompletion()197 void DNSClient::HandleCompletion() {
198   SLOG(this, 3) << "In " << __func__;
199   Error error;
200   error.CopyFrom(error_);
201   IPAddress address(address_);
202   if (!error.IsSuccess()) {
203     // If the DNS request did not succeed, do not trust it for future
204     // attempts.
205     Stop();
206   } else {
207     // Prepare our state for the next request without destroying the
208     // current ARES state.
209     error_.Reset();
210     address_.SetAddressToDefault();
211   }
212   callback_.Run(error, address);
213 }
214 
HandleDNSRead(int fd)215 void DNSClient::HandleDNSRead(int fd) {
216   ares_->ProcessFd(resolver_state_->channel, fd, ARES_SOCKET_BAD);
217   RefreshHandles();
218 }
219 
HandleDNSWrite(int fd)220 void DNSClient::HandleDNSWrite(int fd) {
221   ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, fd);
222   RefreshHandles();
223 }
224 
HandleTimeout()225 void DNSClient::HandleTimeout() {
226   ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD);
227   RefreshHandles();
228 }
229 
// Records the outcome of the current request as reported by ARES.
//
// On a usable answer, the first address in |hostent| becomes |address_|.
// Otherwise |status| is translated into |error_|.  In either case the request
// is marked finished and HandleCompletion() is posted to report the result.
void DNSClient::ReceiveDNSReply(int status, struct hostent* hostent) {
  // ARES also calls back while a channel is being shut down; such late
  // events are dropped.
  if (!running_) {
    return;
  }
  SLOG(this, 3) << "In " << __func__;

  // Only one reply is accepted per request: finish it, cancel its timer and
  // queue the completion before recording the outcome below.
  running_ = false;
  timeout_closure_.Cancel();
  dispatcher_->PostTask(Bind(&DNSClient::HandleCompletion,
                             weak_ptr_factory_.GetWeakPtr()));

  const bool usable_answer =
      status == ARES_SUCCESS &&
      hostent != nullptr &&
      hostent->h_addrtype == address_.family() &&
      static_cast<size_t>(hostent->h_length) ==
          IPAddress::GetAddressLength(address_.family()) &&
      hostent->h_addr_list != nullptr &&
      hostent->h_addr_list[0] != nullptr;
  if (usable_answer) {
    unsigned char* first_address =
        reinterpret_cast<unsigned char*>(hostent->h_addr_list[0]);
    address_ = IPAddress(address_.family(),
                         ByteString(first_address, hostent->h_length));
    return;
  }

  // Timeouts are the only failure reported with a distinct error type.
  if (status == ARES_ETIMEOUT) {
    error_.Populate(Error::kOperationTimeout, kErrorTimedOut);
    return;
  }

  const char* reason = kErrorUnknown;
  switch (status) {
    case ARES_ENODATA:
      reason = kErrorNoData;
      break;
    case ARES_EFORMERR:
      reason = kErrorFormErr;
      break;
    case ARES_ESERVFAIL:
      reason = kErrorServerFail;
      break;
    case ARES_ENOTFOUND:
      reason = kErrorNotFound;
      break;
    case ARES_ENOTIMP:
      reason = kErrorNotImp;
      break;
    case ARES_EREFUSED:
      reason = kErrorRefused;
      break;
    case ARES_EBADQUERY:
    case ARES_EBADNAME:
    case ARES_EBADFAMILY:
    case ARES_EBADRESP:
      reason = kErrorBadQuery;
      break;
    case ARES_ECONNREFUSED:
      reason = kErrorNetRefused;
      break;
    default:
      // Either ARES claimed success with an unusable hostent, or it returned
      // a status this switch does not map.
      if (status == ARES_SUCCESS) {
        LOG(ERROR) << "ARES returned success but hostent was invalid!";
      } else {
        LOG(ERROR) << "ARES returned unhandled error status " << status;
      }
      break;
  }
  error_.Populate(Error::kOperationFailed, reason);
}
294 
ReceiveDNSReplyCB(void * arg,int status,int,struct hostent * hostent)295 void DNSClient::ReceiveDNSReplyCB(void* arg, int status,
296                                   int /*timeouts*/,
297                                   struct hostent* hostent) {
298   DNSClient* res = static_cast<DNSClient*>(arg);
299   res->ReceiveDNSReply(status, hostent);
300 }
301 
RefreshHandles()302 bool DNSClient::RefreshHandles() {
303   map<ares_socket_t, std::shared_ptr<IOHandler>> old_read =
304       resolver_state_->read_handlers;
305   map<ares_socket_t, std::shared_ptr<IOHandler>> old_write =
306       resolver_state_->write_handlers;
307 
308   resolver_state_->read_handlers.clear();
309   resolver_state_->write_handlers.clear();
310 
311   ares_socket_t sockets[ARES_GETSOCK_MAXNUM];
312   int action_bits = ares_->GetSock(resolver_state_->channel, sockets,
313                                    ARES_GETSOCK_MAXNUM);
314 
315   base::Callback<void(int)> read_callback(
316       Bind(&DNSClient::HandleDNSRead, weak_ptr_factory_.GetWeakPtr()));
317   base::Callback<void(int)> write_callback(
318       Bind(&DNSClient::HandleDNSWrite, weak_ptr_factory_.GetWeakPtr()));
319   for (int i = 0; i < ARES_GETSOCK_MAXNUM; i++) {
320     if (ARES_GETSOCK_READABLE(action_bits, i)) {
321       if (ContainsKey(old_read, sockets[i])) {
322         resolver_state_->read_handlers[sockets[i]] = old_read[sockets[i]];
323       } else {
324         resolver_state_->read_handlers[sockets[i]] =
325             std::shared_ptr<IOHandler> (
326                 dispatcher_->CreateReadyHandler(sockets[i],
327                                                 IOHandler::kModeInput,
328                                                 read_callback));
329       }
330     }
331     if (ARES_GETSOCK_WRITABLE(action_bits, i)) {
332       if (ContainsKey(old_write, sockets[i])) {
333         resolver_state_->write_handlers[sockets[i]] = old_write[sockets[i]];
334       } else {
335         resolver_state_->write_handlers[sockets[i]] =
336             std::shared_ptr<IOHandler> (
337                 dispatcher_->CreateReadyHandler(sockets[i],
338                                                 IOHandler::kModeOutput,
339                                                 write_callback));
340       }
341     }
342   }
343 
344   if (!running_) {
345     // We are here just to clean up socket handles, and the ARES state was
346     // cleaned up during the last call to ares_->ProcessFd().
347     return false;
348   }
349 
350   // Schedule timer event for the earlier of our timeout or one requested by
351   // the resolver library.
352   struct timeval now, elapsed_time, timeout_tv;
353   time_->GetTimeMonotonic(&now);
354   timersub(&now, &resolver_state_->start_time, &elapsed_time);
355   timeout_tv.tv_sec = timeout_ms_ / 1000;
356   timeout_tv.tv_usec = (timeout_ms_ % 1000) * 1000;
357   timeout_closure_.Cancel();
358 
359   if (timercmp(&elapsed_time, &timeout_tv, >=)) {
360     // There are 3 cases of interest:
361     //  - If we got here from Start(), when we return, Stop() will be
362     //    called, so our cleanup task will not run, so we will not have the
363     //    side-effect of both invoking the callback and returning False
364     //    in Start().
365     //  - If we got here from the tail of an IO event, we can't call
366     //    Stop() since that will blow away the IOHandler we are running
367     //    in.  We will perform the cleanup in the posted task below.
368     //  - If we got here from a timeout handler, we will perform cleanup
369     //    in the posted task.
370     running_ = false;
371     error_.Populate(Error::kOperationTimeout, kErrorTimedOut);
372     dispatcher_->PostTask(Bind(&DNSClient::HandleCompletion,
373                                weak_ptr_factory_.GetWeakPtr()));
374     return false;
375   } else {
376     struct timeval max, ret_tv;
377     timersub(&timeout_tv, &elapsed_time, &max);
378     struct timeval* tv = ares_->Timeout(resolver_state_->channel,
379                                         &max, &ret_tv);
380     timeout_closure_.Reset(
381         Bind(&DNSClient::HandleTimeout, weak_ptr_factory_.GetWeakPtr()));
382     dispatcher_->PostDelayedTask(timeout_closure_.callback(),
383                                  tv->tv_sec * 1000 + tv->tv_usec / 1000);
384   }
385 
386   return true;
387 }
388 
389 }  // namespace shill
390