1 //
2 // Copyright (C) 2012 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "shill/dns_client.h"
18
19 #include <arpa/inet.h>
20 #include <netdb.h>
21 #include <netinet/in.h>
22 #include <sys/socket.h>
23
24 #include <map>
25 #include <memory>
26 #include <set>
27 #include <string>
28 #include <vector>
29
30 #include <base/bind.h>
31 #include <base/bind_helpers.h>
32 #include <base/stl_util.h>
33 #include <base/strings/string_number_conversions.h>
34
35 #include "shill/logging.h"
36 #include "shill/net/shill_time.h"
37 #include "shill/shill_ares.h"
38
39 using base::Bind;
40 using base::Unretained;
41 using std::map;
42 using std::set;
43 using std::string;
44 using std::vector;
45
46 namespace shill {
47
48 namespace Logging {
49 static auto kModuleLogScope = ScopeLogger::kDNS;
ObjectID(DNSClient * d)50 static string ObjectID(DNSClient* d) { return d->interface_name(); }
51 }
52
53 const char DNSClient::kErrorNoData[] = "The query response contains no answers";
54 const char DNSClient::kErrorFormErr[] = "The server says the query is bad";
55 const char DNSClient::kErrorServerFail[] = "The server says it had a failure";
56 const char DNSClient::kErrorNotFound[] = "The queried-for domain was not found";
57 const char DNSClient::kErrorNotImp[] = "The server doesn't implement operation";
58 const char DNSClient::kErrorRefused[] = "The server replied, refused the query";
59 const char DNSClient::kErrorBadQuery[] = "Locally we could not format a query";
60 const char DNSClient::kErrorNetRefused[] = "The network connection was refused";
61 const char DNSClient::kErrorTimedOut[] = "The network connection was timed out";
62 const char DNSClient::kErrorUnknown[] = "DNS Resolver unknown internal error";
63
64 const int DNSClient::kDefaultDNSPort = 53;
65
66 // Private to the implementation of resolver so callers don't include ares.h
67 struct DNSClientState {
DNSClientStateshill::DNSClientState68 DNSClientState() : channel(nullptr), start_time{} {}
69
70 ares_channel channel;
71 map<ares_socket_t, std::shared_ptr<IOHandler>> read_handlers;
72 map<ares_socket_t, std::shared_ptr<IOHandler>> write_handlers;
73 struct timeval start_time;
74 };
75
DNSClient(IPAddress::Family family,const string & interface_name,const vector<string> & dns_servers,int timeout_ms,EventDispatcher * dispatcher,const ClientCallback & callback)76 DNSClient::DNSClient(IPAddress::Family family,
77 const string& interface_name,
78 const vector<string>& dns_servers,
79 int timeout_ms,
80 EventDispatcher* dispatcher,
81 const ClientCallback& callback)
82 : address_(IPAddress(family)),
83 interface_name_(interface_name),
84 dns_servers_(dns_servers),
85 dispatcher_(dispatcher),
86 callback_(callback),
87 timeout_ms_(timeout_ms),
88 running_(false),
89 weak_ptr_factory_(this),
90 ares_(Ares::GetInstance()),
91 time_(Time::GetInstance()) {}
92
~DNSClient()93 DNSClient::~DNSClient() {
94 Stop();
95 }
96
Start(const string & hostname,Error * error)97 bool DNSClient::Start(const string& hostname, Error* error) {
98 if (running_) {
99 Error::PopulateAndLog(FROM_HERE, error, Error::kInProgress,
100 "Only one DNS request is allowed at a time");
101 return false;
102 }
103
104 if (!resolver_state_.get()) {
105 struct ares_options options;
106 memset(&options, 0, sizeof(options));
107 options.timeout = timeout_ms_;
108
109 if (dns_servers_.empty()) {
110 Error::PopulateAndLog(FROM_HERE, error, Error::kInvalidArguments,
111 "No valid DNS server addresses");
112 return false;
113 }
114
115 resolver_state_.reset(new DNSClientState);
116 int status = ares_->InitOptions(&resolver_state_->channel,
117 &options,
118 ARES_OPT_TIMEOUTMS);
119 if (status != ARES_SUCCESS) {
120 Error::PopulateAndLog(FROM_HERE, error, Error::kOperationFailed,
121 "ARES initialization returns error code: " +
122 base::IntToString(status));
123 resolver_state_.reset();
124 return false;
125 }
126
127 // Format DNS server addresses string as "host:port[,host:port...]" to be
128 // used in call to ares_set_servers_csv for setting DNS server addresses.
129 // There is a bug in ares library when parsing IPv6 addresses, where it
130 // always assumes the port number are specified when address contains ":".
131 // So when IPv6 address are given without port number as "xx:xx:xx::yy",the
132 // parser would parse the address as "xx:xx:xx:" and port number as "yy".
133 // To work around this bug, port number are added to each address.
134 //
135 // Alternatively, we can use ares_set_servers instead, where we would
136 // explicitly construct a link list of ares_addr_node.
137 string server_addresses;
138 bool first = true;
139 for (const auto& ip : dns_servers_) {
140 if (!first) {
141 server_addresses += ",";
142 } else {
143 first = false;
144 }
145 server_addresses += (ip + ":" + base::IntToString(kDefaultDNSPort));
146 }
147 status = ares_->SetServersCsv(resolver_state_->channel,
148 server_addresses.c_str());
149 if (status != ARES_SUCCESS) {
150 Error::PopulateAndLog(FROM_HERE, error, Error::kOperationFailed,
151 "ARES set DNS servers error code: " +
152 base::IntToString(status));
153 resolver_state_.reset();
154 return false;
155 }
156
157 ares_->SetLocalDev(resolver_state_->channel, interface_name_.c_str());
158 }
159
160 running_ = true;
161 time_->GetTimeMonotonic(&resolver_state_->start_time);
162 ares_->GetHostByName(resolver_state_->channel, hostname.c_str(),
163 address_.family(), ReceiveDNSReplyCB, this);
164
165 if (!RefreshHandles()) {
166 LOG(ERROR) << "Impossibly short timeout.";
167 error->CopyFrom(error_);
168 Stop();
169 return false;
170 }
171
172 return true;
173 }
174
Stop()175 void DNSClient::Stop() {
176 SLOG(this, 3) << "In " << __func__;
177 if (!resolver_state_.get()) {
178 return;
179 }
180
181 running_ = false;
182 weak_ptr_factory_.InvalidateWeakPtrs();
183 error_.Reset();
184 address_.SetAddressToDefault();
185 ares_->Destroy(resolver_state_->channel);
186 resolver_state_.reset();
187 }
188
IsActive() const189 bool DNSClient::IsActive() const {
190 return running_;
191 }
192
193 // We delay our call to completion so that we exit all IOHandlers, and
194 // can clean up all of our local state before calling the callback, or
195 // during the process of the execution of the callee (which is free to
196 // call our destructor safely).
HandleCompletion()197 void DNSClient::HandleCompletion() {
198 SLOG(this, 3) << "In " << __func__;
199 Error error;
200 error.CopyFrom(error_);
201 IPAddress address(address_);
202 if (!error.IsSuccess()) {
203 // If the DNS request did not succeed, do not trust it for future
204 // attempts.
205 Stop();
206 } else {
207 // Prepare our state for the next request without destroying the
208 // current ARES state.
209 error_.Reset();
210 address_.SetAddressToDefault();
211 }
212 callback_.Run(error, address);
213 }
214
HandleDNSRead(int fd)215 void DNSClient::HandleDNSRead(int fd) {
216 ares_->ProcessFd(resolver_state_->channel, fd, ARES_SOCKET_BAD);
217 RefreshHandles();
218 }
219
HandleDNSWrite(int fd)220 void DNSClient::HandleDNSWrite(int fd) {
221 ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, fd);
222 RefreshHandles();
223 }
224
HandleTimeout()225 void DNSClient::HandleTimeout() {
226 ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD);
227 RefreshHandles();
228 }
229
// Handles the ARES reply for the current request.  It stores either the first
// returned address in |address_| or an Error in |error_|, then posts
// HandleCompletion() to report the outcome.  Replies that arrive when no
// request is running (e.g. while ARES is being shut down) are ignored.
void DNSClient::ReceiveDNSReply(int status, struct hostent* hostent) {
  if (!running_)
    return;

  SLOG(this, 3) << "In " << __func__;
  running_ = false;
  timeout_closure_.Cancel();
  dispatcher_->PostTask(Bind(&DNSClient::HandleCompletion,
                             weak_ptr_factory_.GetWeakPtr()));

  // The reply is usable only if it carries at least one address of the
  // family and length that was requested.  Evaluation order matters: each
  // check guards the dereferences that follow it.
  const bool has_usable_address =
      status == ARES_SUCCESS &&
      hostent != nullptr &&
      hostent->h_addrtype == address_.family() &&
      static_cast<size_t>(hostent->h_length) ==
          IPAddress::GetAddressLength(address_.family()) &&
      hostent->h_addr_list != nullptr &&
      hostent->h_addr_list[0] != nullptr;

  if (has_usable_address) {
    unsigned char* first_address =
        reinterpret_cast<unsigned char*>(hostent->h_addr_list[0]);
    address_ = IPAddress(address_.family(),
                         ByteString(first_address, hostent->h_length));
    return;
  }

  // Translate the ARES status into an Error for the client.
  switch (status) {
    case ARES_ENODATA:
      error_.Populate(Error::kOperationFailed, kErrorNoData);
      break;
    case ARES_EFORMERR:
      error_.Populate(Error::kOperationFailed, kErrorFormErr);
      break;
    case ARES_ESERVFAIL:
      error_.Populate(Error::kOperationFailed, kErrorServerFail);
      break;
    case ARES_ENOTFOUND:
      error_.Populate(Error::kOperationFailed, kErrorNotFound);
      break;
    case ARES_ENOTIMP:
      error_.Populate(Error::kOperationFailed, kErrorNotImp);
      break;
    case ARES_EREFUSED:
      error_.Populate(Error::kOperationFailed, kErrorRefused);
      break;
    case ARES_EBADQUERY:
    case ARES_EBADNAME:
    case ARES_EBADFAMILY:
    case ARES_EBADRESP:
      error_.Populate(Error::kOperationFailed, kErrorBadQuery);
      break;
    case ARES_ECONNREFUSED:
      error_.Populate(Error::kOperationFailed, kErrorNetRefused);
      break;
    case ARES_ETIMEOUT:
      error_.Populate(Error::kOperationTimeout, kErrorTimedOut);
      break;
    default:
      // Also reached for ARES_SUCCESS with an unusable hostent.
      error_.Populate(Error::kOperationFailed, kErrorUnknown);
      if (status != ARES_SUCCESS) {
        LOG(ERROR) << "ARES returned unhandled error status " << status;
      } else {
        LOG(ERROR) << "ARES returned success but hostent was invalid!";
      }
      break;
  }
}
294
ReceiveDNSReplyCB(void * arg,int status,int,struct hostent * hostent)295 void DNSClient::ReceiveDNSReplyCB(void* arg, int status,
296 int /*timeouts*/,
297 struct hostent* hostent) {
298 DNSClient* res = static_cast<DNSClient*>(arg);
299 res->ReceiveDNSReply(status, hostent);
300 }
301
RefreshHandles()302 bool DNSClient::RefreshHandles() {
303 map<ares_socket_t, std::shared_ptr<IOHandler>> old_read =
304 resolver_state_->read_handlers;
305 map<ares_socket_t, std::shared_ptr<IOHandler>> old_write =
306 resolver_state_->write_handlers;
307
308 resolver_state_->read_handlers.clear();
309 resolver_state_->write_handlers.clear();
310
311 ares_socket_t sockets[ARES_GETSOCK_MAXNUM];
312 int action_bits = ares_->GetSock(resolver_state_->channel, sockets,
313 ARES_GETSOCK_MAXNUM);
314
315 base::Callback<void(int)> read_callback(
316 Bind(&DNSClient::HandleDNSRead, weak_ptr_factory_.GetWeakPtr()));
317 base::Callback<void(int)> write_callback(
318 Bind(&DNSClient::HandleDNSWrite, weak_ptr_factory_.GetWeakPtr()));
319 for (int i = 0; i < ARES_GETSOCK_MAXNUM; i++) {
320 if (ARES_GETSOCK_READABLE(action_bits, i)) {
321 if (ContainsKey(old_read, sockets[i])) {
322 resolver_state_->read_handlers[sockets[i]] = old_read[sockets[i]];
323 } else {
324 resolver_state_->read_handlers[sockets[i]] =
325 std::shared_ptr<IOHandler> (
326 dispatcher_->CreateReadyHandler(sockets[i],
327 IOHandler::kModeInput,
328 read_callback));
329 }
330 }
331 if (ARES_GETSOCK_WRITABLE(action_bits, i)) {
332 if (ContainsKey(old_write, sockets[i])) {
333 resolver_state_->write_handlers[sockets[i]] = old_write[sockets[i]];
334 } else {
335 resolver_state_->write_handlers[sockets[i]] =
336 std::shared_ptr<IOHandler> (
337 dispatcher_->CreateReadyHandler(sockets[i],
338 IOHandler::kModeOutput,
339 write_callback));
340 }
341 }
342 }
343
344 if (!running_) {
345 // We are here just to clean up socket handles, and the ARES state was
346 // cleaned up during the last call to ares_->ProcessFd().
347 return false;
348 }
349
350 // Schedule timer event for the earlier of our timeout or one requested by
351 // the resolver library.
352 struct timeval now, elapsed_time, timeout_tv;
353 time_->GetTimeMonotonic(&now);
354 timersub(&now, &resolver_state_->start_time, &elapsed_time);
355 timeout_tv.tv_sec = timeout_ms_ / 1000;
356 timeout_tv.tv_usec = (timeout_ms_ % 1000) * 1000;
357 timeout_closure_.Cancel();
358
359 if (timercmp(&elapsed_time, &timeout_tv, >=)) {
360 // There are 3 cases of interest:
361 // - If we got here from Start(), when we return, Stop() will be
362 // called, so our cleanup task will not run, so we will not have the
363 // side-effect of both invoking the callback and returning False
364 // in Start().
365 // - If we got here from the tail of an IO event, we can't call
366 // Stop() since that will blow away the IOHandler we are running
367 // in. We will perform the cleanup in the posted task below.
368 // - If we got here from a timeout handler, we will perform cleanup
369 // in the posted task.
370 running_ = false;
371 error_.Populate(Error::kOperationTimeout, kErrorTimedOut);
372 dispatcher_->PostTask(Bind(&DNSClient::HandleCompletion,
373 weak_ptr_factory_.GetWeakPtr()));
374 return false;
375 } else {
376 struct timeval max, ret_tv;
377 timersub(&timeout_tv, &elapsed_time, &max);
378 struct timeval* tv = ares_->Timeout(resolver_state_->channel,
379 &max, &ret_tv);
380 timeout_closure_.Reset(
381 Bind(&DNSClient::HandleTimeout, weak_ptr_factory_.GetWeakPtr()));
382 dispatcher_->PostDelayedTask(timeout_closure_.callback(),
383 tv->tv_sec * 1000 + tv->tv_usec / 1000);
384 }
385
386 return true;
387 }
388
389 } // namespace shill
390