1 /*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "dns_responder.h"
18
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <sys/epoll.h>
26 #include <sys/eventfd.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30 #include <span>
31
32 #include <chrono>
33 #include <iostream>
34 #include <set>
35 #include <vector>
36
37 #define LOG_TAG "DNSResponder"
38 #include <android-base/logging.h>
39 #include <android-base/strings.h>
40 #include <netdutils/BackoffSequence.h>
41 #include <netdutils/InternetAddresses.h>
42 #include <netdutils/SocketOption.h>
43
44 using android::base::ErrnoError;
45 using android::base::Result;
46 using android::base::unique_fd;
47 using android::netdutils::BackoffSequence;
48 using android::netdutils::enableSockopt;
49 using android::netdutils::ScopedAddrinfo;
50 using std::chrono::milliseconds;
51
52 namespace test {
53
// Returns the human-readable description of the current |errno|.
std::string errno2str() {
    char scratch[512] = {};
    // This relies on the GNU flavour of strerror_r(), which returns |char*| (either |scratch| or a
    // static string) instead of the XSI |int| status. PLOG would avoid this but needs the callers
    // to move from ALOGx() to LOG(x).
    const char* const message = strerror_r(errno, scratch, sizeof scratch);
    return std::string(message);
}
60
// Formats the host part of |sa| as a numeric address string (e.g. "192.0.2.1" or "::1").
// Returns an empty string if getnameinfo() rejects the address.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char host[NI_MAXHOST] = {};
    const int status = getnameinfo(sa, sa_len, host, sizeof host, nullptr, 0, NI_NUMERICHOST);
    return status == 0 ? std::string(host) : std::string();
}
67
bytesToHexStr(std::span<const uint8_t> bytes)68 std::string bytesToHexStr(std::span<const uint8_t> bytes) {
69 static char const hex[16] = {'0', '1', '2', '3', '4', '5', '6', '7',
70 '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};
71 std::string str;
72 str.reserve(bytes.size() * 2);
73 for (uint8_t ch : bytes) {
74 str.append({hex[(ch & 0xf0) >> 4], hex[ch & 0xf]});
75 }
76 return str;
77 }
78
79 // Because The address might still being set up (b/186181084), This is a wrapper function
80 // that retries bind() if errno is EADDRNOTAVAIL
bindSocket(int socket,const sockaddr * address,socklen_t address_len)81 Result<void> bindSocket(int socket, const sockaddr* address, socklen_t address_len) {
82 // Set the wrapper to try bind() at most 6 times with backoff time
83 // (100 ms, 200 ms, ..., 1600 ms).
84 auto backoff = BackoffSequence<milliseconds>::Builder()
85 .withInitialRetransmissionTime(milliseconds(100))
86 .withMaximumRetransmissionCount(5)
87 .build();
88
89 while (true) {
90 if (0 == bind(socket, address, address_len)) return {};
91 if (errno != EADDRNOTAVAIL) return ErrnoError();
92 if (!backoff.hasNextTimeout()) return ErrnoError();
93
94 LOG(WARNING) << "Retry to bind " << addr2str(address, address_len);
95 std::this_thread::sleep_for(backoff.getNextTimeout());
96 }
97 }
98
99 /* DNS struct helpers */
100
dnstype2str(unsigned dnstype)101 const char* dnstype2str(unsigned dnstype) {
102 static std::unordered_map<unsigned, const char*> kTypeStrs = {
103 {ns_type::ns_t_a, "A"},
104 {ns_type::ns_t_ns, "NS"},
105 {ns_type::ns_t_md, "MD"},
106 {ns_type::ns_t_mf, "MF"},
107 {ns_type::ns_t_cname, "CNAME"},
108 {ns_type::ns_t_soa, "SOA"},
109 {ns_type::ns_t_mb, "MB"},
110 {ns_type::ns_t_mb, "MG"},
111 {ns_type::ns_t_mr, "MR"},
112 {ns_type::ns_t_null, "NULL"},
113 {ns_type::ns_t_wks, "WKS"},
114 {ns_type::ns_t_ptr, "PTR"},
115 {ns_type::ns_t_hinfo, "HINFO"},
116 {ns_type::ns_t_minfo, "MINFO"},
117 {ns_type::ns_t_mx, "MX"},
118 {ns_type::ns_t_txt, "TXT"},
119 {ns_type::ns_t_rp, "RP"},
120 {ns_type::ns_t_afsdb, "AFSDB"},
121 {ns_type::ns_t_x25, "X25"},
122 {ns_type::ns_t_isdn, "ISDN"},
123 {ns_type::ns_t_rt, "RT"},
124 {ns_type::ns_t_nsap, "NSAP"},
125 {ns_type::ns_t_nsap_ptr, "NSAP-PTR"},
126 {ns_type::ns_t_sig, "SIG"},
127 {ns_type::ns_t_key, "KEY"},
128 {ns_type::ns_t_px, "PX"},
129 {ns_type::ns_t_gpos, "GPOS"},
130 {ns_type::ns_t_aaaa, "AAAA"},
131 {ns_type::ns_t_loc, "LOC"},
132 {ns_type::ns_t_nxt, "NXT"},
133 {ns_type::ns_t_eid, "EID"},
134 {ns_type::ns_t_nimloc, "NIMLOC"},
135 {ns_type::ns_t_srv, "SRV"},
136 {ns_type::ns_t_naptr, "NAPTR"},
137 {ns_type::ns_t_kx, "KX"},
138 {ns_type::ns_t_cert, "CERT"},
139 {ns_type::ns_t_a6, "A6"},
140 {ns_type::ns_t_dname, "DNAME"},
141 {ns_type::ns_t_sink, "SINK"},
142 {ns_type::ns_t_opt, "OPT"},
143 {ns_type::ns_t_apl, "APL"},
144 {ns_type::ns_t_tkey, "TKEY"},
145 {ns_type::ns_t_tsig, "TSIG"},
146 {ns_type::ns_t_ixfr, "IXFR"},
147 {ns_type::ns_t_axfr, "AXFR"},
148 {ns_type::ns_t_mailb, "MAILB"},
149 {ns_type::ns_t_maila, "MAILA"},
150 {ns_type::ns_t_any, "ANY"},
151 {ns_type::ns_t_zxfr, "ZXFR"},
152 };
153 auto it = kTypeStrs.find(dnstype);
154 static const char* kUnknownStr{"UNKNOWN"};
155 if (it == kTypeStrs.end()) return kUnknownStr;
156 return it->second;
157 }
158
// Returns a human-readable name for DNS class |dnsclass|, or "UNKNOWN" for unlisted values.
// The returned pointer refers to a string literal.
const char* dnsclass2str(unsigned dnsclass) {
    switch (dnsclass) {
        case ns_class::ns_c_in:
            return "Internet";
        case 2:  // CSNET, obsolete
            return "CSNet";
        case ns_class::ns_c_chaos:
            return "ChaosNet";
        case ns_class::ns_c_hs:
            return "Hesiod";
        case ns_class::ns_c_none:
            return "none";
        case ns_class::ns_c_any:
            return "any";
        default:
            return "UNKNOWN";
    }
}
169
// Returns "TCP" or "UDP" for the matching IPPROTO_* value, and "UNKNOWN" for anything else.
const char* dnsproto2str(int protocol) {
    if (protocol == IPPROTO_TCP) return "TCP";
    if (protocol == IPPROTO_UDP) return "UDP";
    return "UNKNOWN";
}
180
read(const char * buffer,const char * buffer_end)181 const char* DNSName::read(const char* buffer, const char* buffer_end) {
182 const char* cur = buffer;
183 bool last = false;
184 do {
185 cur = parseField(cur, buffer_end, &last);
186 if (cur == nullptr) {
187 LOG(ERROR) << "parsing failed at line " << __LINE__;
188 return nullptr;
189 }
190 } while (!last);
191 return cur;
192 }
193
write(char * buffer,const char * buffer_end) const194 char* DNSName::write(char* buffer, const char* buffer_end) const {
195 char* buffer_cur = buffer;
196 for (size_t pos = 0; pos < name.size();) {
197 size_t dot_pos = name.find('.', pos);
198 if (dot_pos == std::string::npos) {
199 // Soundness check, should never happen unless parseField is broken.
200 LOG(ERROR) << "logic error: all names are expected to end with a '.'";
201 return nullptr;
202 }
203 const size_t len = dot_pos - pos;
204 if (len >= 256) {
205 LOG(ERROR) << "name component '" << name.substr(pos, dot_pos - pos) << "' is " << len
206 << " long, but max is 255";
207 return nullptr;
208 }
209 if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
210 LOG(ERROR) << "buffer overflow at line " << __LINE__;
211 return nullptr;
212 }
213 *buffer_cur++ = len;
214 buffer_cur = std::copy(std::next(name.begin(), pos), std::next(name.begin(), dot_pos),
215 buffer_cur);
216 pos = dot_pos + 1;
217 }
218 // Write final zero.
219 *buffer_cur++ = 0;
220 return buffer_cur;
221 }
222
parseField(const char * buffer,const char * buffer_end,bool * last)223 const char* DNSName::parseField(const char* buffer, const char* buffer_end, bool* last) {
224 if (buffer + sizeof(uint8_t) > buffer_end) {
225 LOG(ERROR) << "parsing failed at line " << __LINE__;
226 return nullptr;
227 }
228 unsigned field_type = *buffer >> 6;
229 unsigned ofs = *buffer & 0x3F;
230 const char* cur = buffer + sizeof(uint8_t);
231 if (field_type == 0) {
232 // length + name component
233 if (ofs == 0) {
234 *last = true;
235 return cur;
236 }
237 if (cur + ofs > buffer_end) {
238 LOG(ERROR) << "parsing failed at line " << __LINE__;
239 return nullptr;
240 }
241 name.append(cur, ofs);
242 name.push_back('.');
243 return cur + ofs;
244 } else if (field_type == 3) {
245 LOG(ERROR) << "name compression not implemented";
246 return nullptr;
247 }
248 LOG(ERROR) << "invalid name field type";
249 return nullptr;
250 }
251
read(const char * buffer,const char * buffer_end)252 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
253 const char* cur = qname.read(buffer, buffer_end);
254 if (cur == nullptr) {
255 LOG(ERROR) << "parsing failed at line " << __LINE__;
256 return nullptr;
257 }
258 if (cur + 2 * sizeof(uint16_t) > buffer_end) {
259 LOG(ERROR) << "parsing failed at line " << __LINE__;
260 return nullptr;
261 }
262 qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
263 qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
264 return cur + 2 * sizeof(uint16_t);
265 }
266
write(char * buffer,const char * buffer_end) const267 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
268 char* buffer_cur = qname.write(buffer, buffer_end);
269 if (buffer_cur == nullptr) return nullptr;
270 if (buffer_cur + 2 * sizeof(uint16_t) > buffer_end) {
271 LOG(ERROR) << "buffer overflow on line " << __LINE__;
272 return nullptr;
273 }
274 *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
275 *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) = htons(qclass);
276 return buffer_cur + 2 * sizeof(uint16_t);
277 }
278
toString() const279 std::string DNSQuestion::toString() const {
280 char buffer[16384];
281 int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.name.c_str(),
282 dnstype2str(qtype), dnsclass2str(qclass));
283 return std::string(buffer, len);
284 }
285
read(const char * buffer,const char * buffer_end)286 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
287 const char* cur = name.read(buffer, buffer_end);
288 if (cur == nullptr) {
289 LOG(ERROR) << "parsing failed at line " << __LINE__;
290 return nullptr;
291 }
292 unsigned rdlen = 0;
293 cur = readIntFields(cur, buffer_end, &rdlen);
294 if (cur == nullptr) {
295 LOG(ERROR) << "parsing failed at line " << __LINE__;
296 return nullptr;
297 }
298 if (cur + rdlen > buffer_end) {
299 LOG(ERROR) << "parsing failed at line " << __LINE__;
300 return nullptr;
301 }
302 rdata.assign(cur, cur + rdlen);
303 return cur + rdlen;
304 }
305
write(char * buffer,const char * buffer_end) const306 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
307 char* buffer_cur = name.write(buffer, buffer_end);
308 if (buffer_cur == nullptr) return nullptr;
309 buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
310 if (buffer_cur == nullptr) return nullptr;
311 if (buffer_cur + rdata.size() > buffer_end) {
312 LOG(ERROR) << "buffer overflow on line " << __LINE__;
313 return nullptr;
314 }
315 return std::copy(rdata.begin(), rdata.end(), buffer_cur);
316 }
317
toString() const318 std::string DNSRecord::toString() const {
319 char buffer[16384];
320 int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.name.c_str(), dnstype2str(rtype),
321 dnsclass2str(rclass));
322 return std::string(buffer, len);
323 }
324
readIntFields(const char * buffer,const char * buffer_end,unsigned * rdlen)325 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end, unsigned* rdlen) {
326 if (buffer + sizeof(IntFields) > buffer_end) {
327 LOG(ERROR) << "parsing failed at line " << __LINE__;
328 return nullptr;
329 }
330 const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
331 rtype = ntohs(intfields.rtype);
332 rclass = ntohs(intfields.rclass);
333 ttl = ntohl(intfields.ttl);
334 *rdlen = ntohs(intfields.rdlen);
335 return buffer + sizeof(IntFields);
336 }
337
writeIntFields(unsigned rdlen,char * buffer,const char * buffer_end) const338 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer, const char* buffer_end) const {
339 if (buffer + sizeof(IntFields) > buffer_end) {
340 LOG(ERROR) << "buffer overflow on line " << __LINE__;
341 return nullptr;
342 }
343 auto& intfields = *reinterpret_cast<IntFields*>(buffer);
344 intfields.rtype = htons(rtype);
345 intfields.rclass = htons(rclass);
346 intfields.ttl = htonl(ttl);
347 intfields.rdlen = htons(rdlen);
348 return buffer + sizeof(IntFields);
349 }
350
read(const char * buffer,const char * buffer_end)351 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
352 unsigned qdcount;
353 unsigned ancount;
354 unsigned nscount;
355 unsigned arcount;
356 const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount, &nscount, &arcount);
357 if (cur == nullptr) {
358 LOG(ERROR) << "parsing failed at line " << __LINE__;
359 return nullptr;
360 }
361 if (qdcount) {
362 questions.resize(qdcount);
363 for (unsigned i = 0; i < qdcount; ++i) {
364 cur = questions[i].read(cur, buffer_end);
365 if (cur == nullptr) {
366 LOG(ERROR) << "parsing failed at line " << __LINE__;
367 return nullptr;
368 }
369 }
370 }
371 if (ancount) {
372 answers.resize(ancount);
373 for (unsigned i = 0; i < ancount; ++i) {
374 cur = answers[i].read(cur, buffer_end);
375 if (cur == nullptr) {
376 LOG(ERROR) << "parsing failed at line " << __LINE__;
377 return nullptr;
378 }
379 }
380 }
381 if (nscount) {
382 authorities.resize(nscount);
383 for (unsigned i = 0; i < nscount; ++i) {
384 cur = authorities[i].read(cur, buffer_end);
385 if (cur == nullptr) {
386 LOG(ERROR) << "parsing failed at line " << __LINE__;
387 return nullptr;
388 }
389 }
390 }
391 if (arcount) {
392 additionals.resize(arcount);
393 for (unsigned i = 0; i < arcount; ++i) {
394 cur = additionals[i].read(cur, buffer_end);
395 if (cur == nullptr) {
396 LOG(ERROR) << "parsing failed at line " << __LINE__;
397 return nullptr;
398 }
399 }
400 }
401 return cur;
402 }
403
write(char * buffer,const char * buffer_end) const404 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
405 if (buffer + sizeof(Header) > buffer_end) {
406 LOG(ERROR) << "buffer overflow on line " << __LINE__;
407 return nullptr;
408 }
409 Header& header = *reinterpret_cast<Header*>(buffer);
410 // bytes 0-1
411 header.id = htons(id);
412 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
413 header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
414 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
415 // Fake behavior: if the query set the "ad" bit, set it in the response too.
416 // In a real server, this should be set only if the data is authentic and the
417 // query contained an "ad" bit or DNSSEC extensions.
418 header.flags1 = (ad << 5) | rcode;
419 // rest of header
420 header.qdcount = htons(questions.size());
421 header.ancount = htons(answers.size());
422 header.nscount = htons(authorities.size());
423 header.arcount = htons(additionals.size());
424 char* buffer_cur = buffer + sizeof(Header);
425 for (const DNSQuestion& question : questions) {
426 buffer_cur = question.write(buffer_cur, buffer_end);
427 if (buffer_cur == nullptr) return nullptr;
428 }
429 for (const DNSRecord& answer : answers) {
430 buffer_cur = answer.write(buffer_cur, buffer_end);
431 if (buffer_cur == nullptr) return nullptr;
432 }
433 for (const DNSRecord& authority : authorities) {
434 buffer_cur = authority.write(buffer_cur, buffer_end);
435 if (buffer_cur == nullptr) return nullptr;
436 }
437 for (const DNSRecord& additional : additionals) {
438 buffer_cur = additional.write(buffer_cur, buffer_end);
439 if (buffer_cur == nullptr) return nullptr;
440 }
441 return buffer_cur;
442 }
443
444 // TODO: convert all callers to this interface, then delete the old one.
write(std::vector<uint8_t> * out) const445 bool DNSHeader::write(std::vector<uint8_t>* out) const {
446 char buffer[16384];
447 char* end = this->write(buffer, buffer + sizeof buffer);
448 if (end == nullptr) return false;
449 out->insert(out->end(), buffer, end);
450 return true;
451 }
452
toString() const453 std::string DNSHeader::toString() const {
454 // TODO
455 return std::string();
456 }
457
readHeader(const char * buffer,const char * buffer_end,unsigned * qdcount,unsigned * ancount,unsigned * nscount,unsigned * arcount)458 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end, unsigned* qdcount,
459 unsigned* ancount, unsigned* nscount, unsigned* arcount) {
460 if (buffer + sizeof(Header) > buffer_end) return nullptr;
461 const auto& header = *reinterpret_cast<const Header*>(buffer);
462 // bytes 0-1
463 id = ntohs(header.id);
464 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
465 qr = header.flags0 >> 7;
466 opcode = (header.flags0 >> 3) & 0x0F;
467 aa = (header.flags0 >> 2) & 1;
468 tr = (header.flags0 >> 1) & 1;
469 rd = header.flags0 & 1;
470 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
471 ra = header.flags1 >> 7;
472 ad = (header.flags1 >> 5) & 1;
473 rcode = header.flags1 & 0xF;
474 // rest of header
475 *qdcount = ntohs(header.qdcount);
476 *ancount = ntohs(header.ancount);
477 *nscount = ntohs(header.nscount);
478 *arcount = ntohs(header.arcount);
479 return buffer + sizeof(Header);
480 }
481
482 /* DNS responder */
483
// Stores the configuration only. The body is empty, so no socket is opened until startServer().
// |error_rcode| is the RCODE that handleDNSRequest() returns for requests it decides not to
// answer normally; a negative value means no response is sent at all.
// |mapping_type| is only stored here. NOTE(review): its use is outside this part of the file.
DNSResponder::DNSResponder(std::string listen_address, std::string listen_service,
                           ns_rcode error_rcode, MappingType mapping_type)
    : listen_address_(std::move(listen_address)),
      listen_service_(std::move(listen_service)),
      error_rcode_(error_rcode),
      mapping_type_(mapping_type) {}
490
~DNSResponder()491 DNSResponder::~DNSResponder() {
492 stopServer();
493 }
494
addMapping(const std::string & name,ns_type type,const std::string & addr)495 void DNSResponder::addMapping(const std::string& name, ns_type type, const std::string& addr) {
496 std::lock_guard lock(mappings_mutex_);
497 mappings_[{name, type}] = addr;
498 }
499
addMappingDnsHeader(const std::string & name,ns_type type,const DNSHeader & header)500 void DNSResponder::addMappingDnsHeader(const std::string& name, ns_type type,
501 const DNSHeader& header) {
502 std::lock_guard lock(mappings_mutex_);
503 dnsheader_mappings_[{name, type}] = header;
504 }
505
addMappingBinaryPacket(const std::vector<uint8_t> & query,const std::vector<uint8_t> & response)506 void DNSResponder::addMappingBinaryPacket(const std::vector<uint8_t>& query,
507 const std::vector<uint8_t>& response) {
508 std::lock_guard lock(mappings_mutex_);
509 packet_mappings_[query] = response;
510 }
511
removeMapping(const std::string & name,ns_type type)512 void DNSResponder::removeMapping(const std::string& name, ns_type type) {
513 std::lock_guard lock(mappings_mutex_);
514 if (!mappings_.erase({name, type})) {
515 LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
516 << "), not present in registered mappings";
517 }
518 }
519
removeMappingDnsHeader(const std::string & name,ns_type type)520 void DNSResponder::removeMappingDnsHeader(const std::string& name, ns_type type) {
521 std::lock_guard lock(mappings_mutex_);
522 if (!dnsheader_mappings_.erase({name, type})) {
523 LOG(ERROR) << "Cannot remove mapping from (" << name << ", " << dnstype2str(type)
524 << "), not present in registered DnsHeader mappings";
525 }
526 }
527
removeMappingBinaryPacket(const std::vector<uint8_t> & query)528 void DNSResponder::removeMappingBinaryPacket(const std::vector<uint8_t>& query) {
529 std::lock_guard lock(mappings_mutex_);
530 if (!packet_mappings_.erase(query)) {
531 LOG(ERROR) << "Cannot remove mapping, not present in registered BinaryPacket mappings";
532 LOG(INFO) << "Hex dump:";
533 LOG(INFO) << bytesToHexStr(query);
534 }
535 }
536
537 // Set response probability on all supported protocols.
setResponseProbability(double response_probability)538 void DNSResponder::setResponseProbability(double response_probability) {
539 setResponseProbability(response_probability, IPPROTO_TCP);
540 setResponseProbability(response_probability, IPPROTO_UDP);
541 }
542
setResponseDelayMs(unsigned timeMs)543 void DNSResponder::setResponseDelayMs(unsigned timeMs) {
544 response_delayed_ms_ = timeMs;
545 }
546
547 // Set response probability on specific protocol. It's caller's duty to ensure that the |protocol|
548 // can be supported by DNSResponder.
setResponseProbability(double response_probability,int protocol)549 void DNSResponder::setResponseProbability(double response_probability, int protocol) {
550 switch (protocol) {
551 case IPPROTO_TCP:
552 response_probability_tcp_ = response_probability;
553 break;
554 case IPPROTO_UDP:
555 response_probability_udp_ = response_probability;
556 break;
557 default:
558 LOG(FATAL) << "Unsupported protocol " << protocol; // abort() by log level FATAL
559 }
560 }
561
getResponseProbability(int protocol) const562 double DNSResponder::getResponseProbability(int protocol) const {
563 switch (protocol) {
564 case IPPROTO_TCP:
565 return response_probability_tcp_;
566 case IPPROTO_UDP:
567 return response_probability_udp_;
568 default:
569 LOG(FATAL) << "Unsupported protocol " << protocol; // abort() by log level FATAL
570 // unreachable
571 return -1;
572 }
573 }
574
setEdns(Edns edns)575 void DNSResponder::setEdns(Edns edns) {
576 edns_ = edns;
577 }
578
setTtl(unsigned ttl)579 void DNSResponder::setTtl(unsigned ttl) {
580 answer_record_ttl_sec_ = ttl;
581 }
582
running() const583 bool DNSResponder::running() const {
584 if (listen_service_ == kDefaultMdnsListenService)
585 return udp_socket_.ok();
586 else
587 return (udp_socket_.ok()) && (tcp_socket_.ok());
588 }
589
startServer()590 bool DNSResponder::startServer() {
591 if (running()) {
592 LOG(ERROR) << "server already running";
593 return false;
594 }
595
596 // Create UDP, TCP socket
597 if (udp_socket_ = createListeningSocket(SOCK_DGRAM); udp_socket_.get() < 0) {
598 PLOG(ERROR) << "failed to create UDP socket";
599 return false;
600 }
601
602 if (listen_service_ != kDefaultMdnsListenService) {
603 if (tcp_socket_ = createListeningSocket(SOCK_STREAM); tcp_socket_.get() < 0) {
604 PLOG(ERROR) << "failed to create TCP socket";
605 return false;
606 }
607
608 if (listen(tcp_socket_.get(), 1) < 0) {
609 PLOG(ERROR) << "failed to listen TCP socket";
610 return false;
611 }
612 }
613
614 // Set up eventfd socket.
615 event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
616 if (event_fd_.get() == -1) {
617 PLOG(ERROR) << "failed to create eventfd";
618 return false;
619 }
620
621 // Set up epoll socket.
622 epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
623 if (epoll_fd_.get() < 0) {
624 PLOG(ERROR) << "epoll_create1() failed on fd";
625 return false;
626 }
627
628 LOG(INFO) << "adding UDP socket to epoll";
629 if (!addFd(udp_socket_.get(), EPOLLIN)) {
630 LOG(ERROR) << "failed to add the UDP socket to epoll";
631 return false;
632 }
633
634 if (listen_service_ != kDefaultMdnsListenService) {
635 LOG(INFO) << "adding TCP socket to epoll";
636 if (!addFd(tcp_socket_.get(), EPOLLIN)) {
637 LOG(ERROR) << "failed to add the TCP socket to epoll";
638 return false;
639 }
640 }
641
642 LOG(INFO) << "adding eventfd to epoll";
643 if (!addFd(event_fd_.get(), EPOLLIN)) {
644 LOG(ERROR) << "failed to add the eventfd to epoll";
645 return false;
646 }
647
648 {
649 std::lock_guard lock(update_mutex_);
650 handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
651 }
652 LOG(INFO) << "server started successfully";
653 return true;
654 }
655
stopServer()656 bool DNSResponder::stopServer() {
657 std::lock_guard lock(update_mutex_);
658 if (!running()) {
659 LOG(ERROR) << "server not running";
660 return false;
661 }
662 LOG(INFO) << "stopping server";
663 if (!sendToEventFd()) {
664 return false;
665 }
666 handler_thread_.join();
667 epoll_fd_.reset();
668 event_fd_.reset();
669 udp_socket_.reset();
670 tcp_socket_.reset();
671 LOG(INFO) << "server stopped successfully";
672 return true;
673 }
674
queries() const675 std::vector<DNSResponder::QueryInfo> DNSResponder::queries() const {
676 std::lock_guard lock(queries_mutex_);
677 return queries_;
678 }
679
dumpQueries() const680 std::string DNSResponder::dumpQueries() const {
681 std::lock_guard lock(queries_mutex_);
682 std::string out;
683
684 for (const auto& q : queries_) {
685 out += "{\"" + q.name + "\", " + std::to_string(q.type) + "\", " +
686 dnsproto2str(q.protocol) + "} ";
687 }
688 return out;
689 }
690
clearQueries()691 void DNSResponder::clearQueries() {
692 std::lock_guard lock(queries_mutex_);
693 queries_.clear();
694 }
695
hasOptPseudoRR(DNSHeader * header) const696 bool DNSResponder::hasOptPseudoRR(DNSHeader* header) const {
697 if (header->additionals.empty()) return false;
698
699 // OPT RR may be placed anywhere within the additional section. See RFC 6891 section 6.1.1.
700 auto found = std::find_if(header->additionals.begin(), header->additionals.end(),
701 [](const auto& a) { return a.rtype == ns_type::ns_t_opt; });
702 return found != header->additionals.end();
703 }
704
requestHandler()705 void DNSResponder::requestHandler() {
706 epoll_event evs[EPOLL_MAX_EVENTS];
707 while (true) {
708 int n = epoll_wait(epoll_fd_.get(), evs, EPOLL_MAX_EVENTS, -1);
709 if (n <= 0) {
710 PLOG(ERROR) << "epoll_wait() failed, n=" << n;
711 return;
712 }
713
714 for (int i = 0; i < n; i++) {
715 const int fd = evs[i].data.fd;
716 const uint32_t events = evs[i].events;
717 if (fd == event_fd_.get() && (events & (EPOLLIN | EPOLLERR))) {
718 handleEventFd();
719 return;
720 } else if (fd == udp_socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
721 handleQuery(IPPROTO_UDP);
722 } else if (fd == tcp_socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
723 handleQuery(IPPROTO_TCP);
724 } else {
725 LOG(WARNING) << "unexpected epoll events " << events << " on fd " << fd;
726 }
727 }
728 }
729 }
730
handleDNSRequest(const char * buffer,ssize_t len,int protocol,char * response,size_t * response_len) const731 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len, int protocol, char* response,
732 size_t* response_len) const {
733 LOG(DEBUG) << "request: '"
734 << bytesToHexStr(std::span(reinterpret_cast<const uint8_t*>(buffer), len))
735 << "', on " << dnsproto2str(protocol);
736 const char* buffer_end = buffer + len;
737 DNSHeader header;
738 const char* cur = header.read(buffer, buffer_end);
739 // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
740 if (cur == nullptr) {
741 LOG(ERROR) << "failed to parse query";
742 return false;
743 }
744 if (header.qr) {
745 LOG(ERROR) << "response received instead of a query";
746 return false;
747 }
748 if (header.opcode != ns_opcode::ns_o_query) {
749 LOG(INFO) << "unsupported request opcode received";
750 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response, response_len);
751 }
752 if (header.questions.empty()) {
753 LOG(INFO) << "no questions present";
754 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
755 }
756 if (!header.answers.empty()) {
757 LOG(INFO) << "already " << header.answers.size() << " answers present in query";
758 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
759 }
760
761 if (edns_ == Edns::FORMERR_UNCOND) {
762 LOG(INFO) << "force to return RCODE FORMERR";
763 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
764 }
765
766 if (!header.additionals.empty() && edns_ != Edns::ON) {
767 LOG(INFO) << "DNS request has an additional section (assumed EDNS). Simulating an ancient "
768 "(pre-EDNS) server, and returning "
769 << (edns_ == Edns::FORMERR_ON_EDNS ? "RCODE FORMERR." : "no response.");
770 if (edns_ == Edns::FORMERR_ON_EDNS) {
771 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, response_len);
772 }
773 // No response.
774 return false;
775 }
776 {
777 std::lock_guard lock(queries_mutex_);
778 for (const DNSQuestion& question : header.questions) {
779 queries_.push_back({question.qname.name, ns_type(question.qtype), protocol});
780 }
781 }
782 // Ignore requests with the preset probability.
783 auto constexpr bound = std::numeric_limits<unsigned>::max();
784 if (arc4random_uniform(bound) > bound * getResponseProbability(protocol)) {
785 if (error_rcode_ < 0) {
786 LOG(ERROR) << "Returning no response";
787 return false;
788 } else {
789 LOG(INFO) << "returning RCODE " << static_cast<int>(error_rcode_)
790 << " in accordance with probability distribution";
791 return makeErrorResponse(&header, error_rcode_, response, response_len);
792 }
793 }
794
795 // Make the response. The query has been read into |header| which is used to build and return
796 // the response as well.
797 return makeResponse(&header, protocol, response, response_len);
798 }
799
// Appends to |answers| the records that |mappings_| holds for |question|, holding
// |mappings_mutex_| for the whole lookup.
//
// For A, AAAA and PTR questions, it first follows the CNAME chain starting at the question name,
// adding one CNAME record per hop. It then looks up the question's own type for the last name
// reached, adding at most one record. Each record uses class IN and the TTL from
// |answer_record_ttl_sec_|.
//
// Returns false only if fillRdata() cannot encode a mapped value. Finding no mapping is not an
// error; if |answers| is still empty it is just logged.
bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
                                    std::vector<DNSRecord>* answers) const {
    std::lock_guard guard(mappings_mutex_);
    std::string rname = question.qname.name;
    std::vector<int> rtypes;

    // Record types to look up, in order: CNAME first (only for A/AAAA/PTR), then the question type.
    if (question.qtype == ns_type::ns_t_a || question.qtype == ns_type::ns_t_aaaa ||
        question.qtype == ns_type::ns_t_ptr)
        rtypes.push_back(ns_type::ns_t_cname);
    rtypes.push_back(question.qtype);
    for (int rtype : rtypes) {
        std::set<std::string> cnames_Loop;
        std::unordered_map<QueryKey, std::string, QueryKeyHash>::const_iterator it;
        while ((it = mappings_.find(QueryKey(rname, rtype))) != mappings_.end()) {
            if (rtype == ns_type::ns_t_cname) {
                // Detect CNAME loops: an owner name already seen is not added again. For example,
                // with the mappings a.xxx.com -> b.xxx.com and b.xxx.com -> a.xxx.com:
                //   hop 1: {"a.xxx.com", CNAME, "b.xxx.com"} is added to the answers;
                //   hop 2: {"b.xxx.com", CNAME, "a.xxx.com"} is added to the answers;
                //   hop 3: "a.xxx.com" is already in cnames_Loop, so the chain stops here.
                if (cnames_Loop.find(it->first.name) != cnames_Loop.end()) break;
                cnames_Loop.insert(it->first.name);
            }
            DNSRecord record{
                    .name = {.name = it->first.name},
                    .rtype = it->first.type,
                    .rclass = ns_class::ns_c_in,
                    .ttl = answer_record_ttl_sec_,  // seconds
            };
            if (!fillRdata(it->second, record)) return false;
            answers->push_back(std::move(record));
            // Non-CNAME types yield a single record; a CNAME hop continues from its target.
            if (rtype != ns_type::ns_t_cname) break;
            rname = it->second;
        }
    }

    if (answers->size() == 0) {
        // TODO(imaipi): handle correctly
        LOG(INFO) << "no mapping found for " << question.qname.name << " "
                  << dnstype2str(question.qtype) << ", lazily refusing to add an answer";
    }

    return true;
}
845
fillRdata(const std::string & rdatastr,DNSRecord & record)846 bool DNSResponder::fillRdata(const std::string& rdatastr, DNSRecord& record) {
847 if (record.rtype == ns_type::ns_t_a) {
848 record.rdata.resize(4);
849 if (inet_pton(AF_INET, rdatastr.c_str(), record.rdata.data()) != 1) {
850 LOG(ERROR) << "inet_pton(AF_INET, " << rdatastr << ") failed";
851 return false;
852 }
853 } else if (record.rtype == ns_type::ns_t_aaaa) {
854 record.rdata.resize(16);
855 if (inet_pton(AF_INET6, rdatastr.c_str(), record.rdata.data()) != 1) {
856 LOG(ERROR) << "inet_pton(AF_INET6, " << rdatastr << ") failed";
857 return false;
858 }
859 } else if ((record.rtype == ns_type::ns_t_ptr) || (record.rtype == ns_type::ns_t_cname) ||
860 (record.rtype == ns_type::ns_t_ns)) {
861 constexpr char delimiter = '.';
862 std::string name = rdatastr;
863 std::vector<char> rdata;
864
865 // Generating PTRDNAME field(section 3.3.12) or CNAME field(section 3.3.1) in rfc1035.
866 // The "name" should be an absolute domain name which ends in a dot.
867 if (name.back() != delimiter) {
868 LOG(ERROR) << "invalid absolute domain name";
869 return false;
870 }
871 name.pop_back(); // remove the dot in tail
872 for (const std::string& label : android::base::Split(name, {delimiter})) {
873 // The length of label is limited to 63 octets or less. See RFC 1035 section 3.1.
874 if (label.length() == 0 || label.length() > 63) {
875 LOG(ERROR) << "invalid label length";
876 return false;
877 }
878
879 rdata.push_back(label.length());
880 rdata.insert(rdata.end(), label.begin(), label.end());
881 }
882 rdata.push_back(0); // Length byte of zero terminates the label list
883
884 // The length of domain name is limited to 255 octets or less. See RFC 1035 section 3.1.
885 if (rdata.size() > 255) {
886 LOG(ERROR) << "invalid name length";
887 return false;
888 }
889 record.rdata = std::move(rdata);
890 } else {
891 LOG(ERROR) << "unhandled qtype " << dnstype2str(record.rtype);
892 return false;
893 }
894 return true;
895 }
896
writePacket(const DNSHeader * header,char * response,size_t * response_len) const897 bool DNSResponder::writePacket(const DNSHeader* header, char* response,
898 size_t* response_len) const {
899 char* response_cur = header->write(response, response + *response_len);
900 if (response_cur == nullptr) {
901 return false;
902 }
903 *response_len = response_cur - response;
904 return true;
905 }
906
makeErrorResponse(DNSHeader * header,ns_rcode rcode,char * response,size_t * response_len) const907 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode, char* response,
908 size_t* response_len) const {
909 header->answers.clear();
910 header->authorities.clear();
911 header->additionals.clear();
912 header->rcode = rcode;
913 header->qr = true;
914 return writePacket(header, response, response_len);
915 }
916
makeTruncatedResponse(DNSHeader * header,char * response,size_t * response_len) const917 bool DNSResponder::makeTruncatedResponse(DNSHeader* header, char* response,
918 size_t* response_len) const {
919 // Build a minimal response for non-EDNS response over UDP. Truncate all stub RRs in answer,
920 // authority and additional section. EDNS response truncation has not supported here yet
921 // because the EDNS response must have an OPT record. See RFC 6891 section 7.
922 header->answers.clear();
923 header->authorities.clear();
924 header->additionals.clear();
925 header->qr = true;
926 header->tr = true;
927 return writePacket(header, response, response_len);
928 }
929
makeResponse(DNSHeader * header,int protocol,char * response,size_t * response_len) const930 bool DNSResponder::makeResponse(DNSHeader* header, int protocol, char* response,
931 size_t* response_len) const {
932 char buffer[16384];
933 size_t buffer_len = sizeof(buffer);
934 bool ret;
935
936 switch (mapping_type_) {
937 case MappingType::DNS_HEADER:
938 ret = makeResponseFromDnsHeader(header, buffer, &buffer_len);
939 break;
940 case MappingType::BINARY_PACKET:
941 ret = makeResponseFromBinaryPacket(header, buffer, &buffer_len);
942 break;
943 case MappingType::ADDRESS_OR_HOSTNAME:
944 default:
945 ret = makeResponseFromAddressOrHostname(header, buffer, &buffer_len);
946 }
947
948 if (!ret) return false;
949
950 // Return truncated response if the built non-EDNS response size which is larger than 512 bytes
951 // will be responded over UDP. The truncated response implementation here just simply set up
952 // the TC bit and truncate all stub RRs in answer, authority and additional section. It is
953 // because the resolver will retry DNS query over TCP and use the full TCP response. See also
954 // RFC 1035 section 4.2.1 for UDP response truncation and RFC 6891 section 4.3 for EDNS larger
955 // response size capability.
956 // TODO: Perhaps keep the stub RRs as possible.
957 // TODO: Perhaps truncate the EDNS based response over UDP. See also RFC 6891 section 4.3,
958 // section 6.2.5 and section 7.
959 if (protocol == IPPROTO_UDP && buffer_len > kMaximumUdpSize &&
960 !hasOptPseudoRR(header) /* non-EDNS */) {
961 LOG(INFO) << "Return truncated response because original response length " << buffer_len
962 << " is larger than " << kMaximumUdpSize << " bytes.";
963 return makeTruncatedResponse(header, response, response_len);
964 }
965
966 if (buffer_len > *response_len) {
967 LOG(ERROR) << "buffer overflow on line " << __LINE__;
968 return false;
969 }
970 memcpy(response, buffer, buffer_len);
971 *response_len = buffer_len;
972 return true;
973 }
974
makeResponseFromAddressOrHostname(DNSHeader * header,char * response,size_t * response_len) const975 bool DNSResponder::makeResponseFromAddressOrHostname(DNSHeader* header, char* response,
976 size_t* response_len) const {
977 for (const DNSQuestion& question : header->questions) {
978 if (question.qclass != ns_class::ns_c_in && question.qclass != ns_class::ns_c_any) {
979 LOG(INFO) << "unsupported question class " << question.qclass;
980 return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
981 }
982
983 if (!addAnswerRecords(question, &header->answers)) {
984 return makeErrorResponse(header, ns_rcode::ns_r_servfail, response, response_len);
985 }
986 }
987 header->qr = true;
988 return writePacket(header, response, response_len);
989 }
990
makeResponseFromDnsHeader(DNSHeader * header,char * response,size_t * response_len) const991 bool DNSResponder::makeResponseFromDnsHeader(DNSHeader* header, char* response,
992 size_t* response_len) const {
993 std::lock_guard guard(mappings_mutex_);
994
995 // Support single question record only. It should be okay because res_mkquery() sets "qdcount"
996 // as one for the operation QUERY and handleDNSRequest() checks ns_opcode::ns_o_query before
997 // making a response. In other words, only need to handle the query which has single question
998 // section. See also res_mkquery() in system/netd/resolv/res_mkquery.cpp.
999 // TODO: Perhaps add support for multi-question records.
1000 const std::vector<DNSQuestion>& questions = header->questions;
1001 if (questions.size() != 1) {
1002 LOG(INFO) << "unsupported question count " << questions.size();
1003 return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
1004 }
1005
1006 if (questions[0].qclass != ns_class::ns_c_in && questions[0].qclass != ns_class::ns_c_any) {
1007 LOG(INFO) << "unsupported question class " << questions[0].qclass;
1008 return makeErrorResponse(header, ns_rcode::ns_r_notimpl, response, response_len);
1009 }
1010
1011 const std::string name = questions[0].qname.name;
1012 const int qtype = questions[0].qtype;
1013 const auto it = dnsheader_mappings_.find(QueryKey(name, qtype));
1014 if (it != dnsheader_mappings_.end()) {
1015 // Store both "id" and "rd" which comes from query.
1016 const unsigned id = header->id;
1017 const bool rd = header->rd;
1018
1019 // Build a response from the registered DNSHeader mapping.
1020 *header = it->second;
1021 // Assign both "ID" and "RD" fields from query to response. See RFC 1035 section 4.1.1.
1022 header->id = id;
1023 header->rd = rd;
1024 } else {
1025 // TODO: handle correctly. See also TODO in addAnswerRecords().
1026 LOG(INFO) << "no mapping found for " << name << " " << dnstype2str(qtype)
1027 << ", couldn't build a response from DNSHeader mapping";
1028
1029 // Note that do nothing as makeResponseFromAddressOrHostname() if no mapping is found. It
1030 // just changes the QR flag from query (0) to response (1) in the query. Then, send the
1031 // modified query back as a response.
1032 header->qr = true;
1033 }
1034 return writePacket(header, response, response_len);
1035 }
1036
makeResponseFromBinaryPacket(DNSHeader * header,char * response,size_t * response_len) const1037 bool DNSResponder::makeResponseFromBinaryPacket(DNSHeader* header, char* response,
1038 size_t* response_len) const {
1039 std::lock_guard guard(mappings_mutex_);
1040
1041 // Build a search key of mapping from the query.
1042 // TODO: Perhaps pass the query packet buffer directly from the caller.
1043 std::vector<uint8_t> queryKey;
1044 if (!header->write(&queryKey)) return false;
1045 // Clear ID field (byte 0-1) because it is not required by the mapping key.
1046 queryKey[0] = 0;
1047 queryKey[1] = 0;
1048
1049 const auto it = packet_mappings_.find(queryKey);
1050 if (it != packet_mappings_.end()) {
1051 if (it->second.size() > *response_len) {
1052 LOG(ERROR) << "buffer overflow on line " << __LINE__;
1053 return false;
1054 } else {
1055 std::copy(it->second.begin(), it->second.end(), response);
1056 // Leave the "RD" flag assignment for testing. The "RD" flag of the response keep
1057 // using the one from the raw packet mapping but the received query.
1058 // Assign "ID" field from query to response. See RFC 1035 section 4.1.1.
1059 reinterpret_cast<uint16_t*>(response)[0] = htons(header->id); // bytes 0-1: id
1060 *response_len = it->second.size();
1061 return true;
1062 }
1063 } else {
1064 // TODO: handle correctly. See also TODO in addAnswerRecords().
1065 // TODO: Perhaps dump packet content to indicate which query failed.
1066 LOG(INFO) << "no mapping found, couldn't build a response from BinaryPacket mapping";
1067 // Note that do nothing as makeResponseFromAddressOrHostname() if no mapping is found. It
1068 // just changes the QR flag from query (0) to response (1) in the query. Then, send the
1069 // modified query back as a response.
1070 header->qr = true;
1071 return writePacket(header, response, response_len);
1072 }
1073 }
1074
setDeferredResp(bool deferred_resp)1075 void DNSResponder::setDeferredResp(bool deferred_resp) {
1076 std::lock_guard<std::mutex> guard(cv_mutex_for_deferred_resp_);
1077 deferred_resp_ = deferred_resp;
1078 if (!deferred_resp_) {
1079 cv_for_deferred_resp_.notify_one();
1080 }
1081 }
1082
addFd(int fd,uint32_t events)1083 bool DNSResponder::addFd(int fd, uint32_t events) {
1084 epoll_event ev;
1085 ev.events = events;
1086 ev.data.fd = fd;
1087 if (epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, fd, &ev) < 0) {
1088 PLOG(ERROR) << "epoll_ctl() for socket " << fd << " failed";
1089 return false;
1090 }
1091 return true;
1092 }
1093
handleQuery(int protocol)1094 void DNSResponder::handleQuery(int protocol) {
1095 char buffer[16384];
1096 sockaddr_storage sa;
1097 socklen_t sa_len = sizeof(sa);
1098 ssize_t len = 0;
1099 unique_fd tcpFd;
1100 switch (protocol) {
1101 case IPPROTO_UDP:
1102 do {
1103 len = recvfrom(udp_socket_.get(), buffer, sizeof(buffer), 0, (sockaddr*)&sa,
1104 &sa_len);
1105 } while (len < 0 && (errno == EAGAIN || errno == EINTR));
1106 if (len <= 0) {
1107 PLOG(ERROR) << "recvfrom() failed, len=" << len;
1108 return;
1109 }
1110 break;
1111 case IPPROTO_TCP:
1112 tcpFd.reset(accept4(tcp_socket_.get(), reinterpret_cast<sockaddr*>(&sa), &sa_len,
1113 SOCK_CLOEXEC));
1114 if (tcpFd.get() < 0) {
1115 PLOG(ERROR) << "failed to accept client socket";
1116 return;
1117 }
1118 // Get the message length from two byte length field.
1119 // See also RFC 1035, section 4.2.2 and RFC 7766, section 8
1120 uint8_t queryMessageLengthField[2];
1121 if (read(tcpFd.get(), &queryMessageLengthField, 2) != 2) {
1122 PLOG(ERROR) << "Not enough length field bytes";
1123 return;
1124 }
1125
1126 const uint16_t qlen = (queryMessageLengthField[0] << 8) | queryMessageLengthField[1];
1127 while (len < qlen) {
1128 ssize_t ret = read(tcpFd.get(), buffer + len, qlen - len);
1129 if (ret <= 0) {
1130 PLOG(ERROR) << "Error while reading query";
1131 return;
1132 }
1133 len += ret;
1134 }
1135 break;
1136 }
1137 LOG(DEBUG) << "read " << len << " bytes on " << dnsproto2str(protocol);
1138 std::lock_guard lock(cv_mutex_);
1139 char response[16384];
1140 size_t response_len = sizeof(response);
1141 // TODO: check whether sending malformed packets to DnsResponder
1142 if (handleDNSRequest(buffer, len, protocol, response, &response_len) && response_len > 0) {
1143 std::this_thread::sleep_for(std::chrono::milliseconds(response_delayed_ms_));
1144 // place wait_for after handleDNSRequest() so we can check the number of queries in
1145 // test case before it got responded.
1146 std::unique_lock guard(cv_mutex_for_deferred_resp_);
1147 cv_for_deferred_resp_.wait(
1148 guard, [this]() REQUIRES(cv_mutex_for_deferred_resp_) { return !deferred_resp_; });
1149 len = 0;
1150
1151 switch (protocol) {
1152 case IPPROTO_UDP:
1153 len = sendto(udp_socket_.get(), response, response_len, 0,
1154 reinterpret_cast<const sockaddr*>(&sa), sa_len);
1155 if (len < 0) {
1156 PLOG(ERROR) << "Failed to send response";
1157 }
1158 break;
1159 case IPPROTO_TCP:
1160 // Get the message length from two byte length field.
1161 // See also RFC 1035, section 4.2.2 and RFC 7766, section 8
1162 uint8_t responseMessageLengthField[2];
1163 responseMessageLengthField[0] = response_len >> 8;
1164 responseMessageLengthField[1] = response_len;
1165 if (write(tcpFd.get(), responseMessageLengthField, 2) != 2) {
1166 PLOG(ERROR) << "Failed to write response length field";
1167 break;
1168 }
1169 if (write(tcpFd.get(), response, response_len) !=
1170 static_cast<ssize_t>(response_len)) {
1171 PLOG(ERROR) << "Failed to write response";
1172 break;
1173 }
1174 len = response_len;
1175 break;
1176 }
1177 const std::string host_str = addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
1178 if (len > 0) {
1179 LOG(DEBUG) << "sent " << len << " bytes to " << host_str;
1180 } else {
1181 const char* method_str = (protocol == IPPROTO_TCP) ? "write()" : "sendto()";
1182 LOG(ERROR) << method_str << " failed for " << host_str;
1183 }
1184 // Test that the response is actually a correct DNS message.
1185 // TODO: Perhaps make DNS message validation to support name compression. Or it throws
1186 // a warning for a valid DNS message with name compression while the binary packet mapping
1187 // is used.
1188 const char* response_end = response + len;
1189 DNSHeader header;
1190 const char* cur = header.read(response, response_end);
1191 if (cur == nullptr) LOG(WARNING) << "response is flawed";
1192 } else {
1193 LOG(WARNING) << "not responding";
1194 }
1195 cv.notify_one();
1196 return;
1197 }
1198
sendToEventFd()1199 bool DNSResponder::sendToEventFd() {
1200 const uint64_t data = 1;
1201 if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
1202 PLOG(ERROR) << "failed to write eventfd, rt=" << rt;
1203 return false;
1204 }
1205 return true;
1206 }
1207
handleEventFd()1208 void DNSResponder::handleEventFd() {
1209 int64_t data;
1210 if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
1211 PLOG(INFO) << "ignore reading eventfd failed, rt=" << rt;
1212 }
1213 }
1214
createListeningSocket(int socket_type)1215 unique_fd DNSResponder::createListeningSocket(int socket_type) {
1216 addrinfo ai_hints{
1217 .ai_flags = AI_PASSIVE,
1218 .ai_family = AF_UNSPEC,
1219 .ai_socktype = socket_type,
1220 };
1221 addrinfo* ai_res = nullptr;
1222 const int rv =
1223 getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &ai_hints, &ai_res);
1224 ScopedAddrinfo ai_res_cleanup(ai_res);
1225 if (rv) {
1226 LOG(ERROR) << "getaddrinfo(" << listen_address_ << ", " << listen_service_
1227 << ") failed: " << gai_strerror(rv);
1228 return {};
1229 }
1230 for (const addrinfo* ai = ai_res; ai; ai = ai->ai_next) {
1231 unique_fd fd(socket(ai->ai_family, ai->ai_socktype | SOCK_NONBLOCK, ai->ai_protocol));
1232 if (fd.get() < 0) {
1233 PLOG(ERROR) << "ignore creating socket failed";
1234 continue;
1235 }
1236
1237 enableSockopt(fd.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
1238 const std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
1239 if ((listen_service_ == kDefaultMdnsListenService) && (socket_type == SOCK_DGRAM)) {
1240 const int mdns_port = 5353;
1241 const char mdns_multiaddrv4[] = "224.0.0.251";
1242 const char mdns_multiaddrv6[] = "ff02::fb";
1243 if (ai_res->ai_family == AF_INET) {
1244 // Join the MDNS IPV4 multicast group
1245 struct ip_mreq mreq;
1246 mreq.imr_multiaddr.s_addr = inet_addr(mdns_multiaddrv4);
1247 mreq.imr_interface.s_addr = inet_addr(host_str.c_str());
1248 if (setsockopt(fd.get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &mreq,
1249 sizeof(struct ip_mreq)) == -1) {
1250 LOG(ERROR) << "Error set setsockopt for IP_ADD_MEMBERSHIP ";
1251 return {};
1252 }
1253 struct sockaddr_in addr = {.sin_family = AF_INET,
1254 .sin_port = htons(mdns_port),
1255 .sin_addr = {INADDR_ANY}};
1256 if (auto result = bindSocket(fd.get(), (struct sockaddr*)&addr, sizeof(addr));
1257 !result.ok()) {
1258 LOG(ERROR) << "failed to bind. MDNS IPv4: " << result.error().message();
1259 return {};
1260 }
1261 } else if (ai_res->ai_family == AF_INET6) {
1262 // Join the MDNS IPV6 multicast group
1263 struct ipv6_mreq mreqv6;
1264 inet_pton(AF_INET6, mdns_multiaddrv6, &mreqv6.ipv6mr_multiaddr.s6_addr);
1265 mreqv6.ipv6mr_interface = 0;
1266 if (setsockopt(fd.get(), IPPROTO_IPV6, IPV6_JOIN_GROUP, &mreqv6, sizeof(mreqv6)) ==
1267 -1) {
1268 LOG(ERROR) << "Error set setsockopt for IPV6_JOIN_GROUP ";
1269 return {};
1270 }
1271 struct sockaddr_in6 addr = {
1272 .sin6_family = AF_INET6,
1273 .sin6_port = htons(mdns_port),
1274 .sin6_addr = IN6ADDR_ANY_INIT,
1275 };
1276 if (auto result = bindSocket(fd.get(), (struct sockaddr*)&addr, sizeof(addr));
1277 !result.ok()) {
1278 LOG(ERROR) << "failed to bind. MDNS IPV6: " << result.error().message();
1279 return {};
1280 }
1281 }
1282 return fd;
1283 } else {
1284 const char* socket_str = (socket_type == SOCK_STREAM) ? "TCP" : "UDP";
1285 if (auto result = bindSocket(fd.get(), ai->ai_addr, ai->ai_addrlen); !result.ok()) {
1286 LOG(ERROR) << "failed to bind " << socket_str << " " << host_str << ":"
1287 << listen_service_ << " " << result.error().message();
1288 continue;
1289 }
1290 LOG(INFO) << "bound to " << socket_str << " " << host_str << ":" << listen_service_;
1291 return fd;
1292 }
1293 }
1294 return {};
1295 }
1296
1297 } // namespace test
1298