1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "DnsTlsSocket"
18 //#define LOG_NDEBUG 0
19 
20 #include "dns/DnsTlsSocket.h"
21 
22 #include <algorithm>
23 #include <arpa/inet.h>
24 #include <arpa/nameser.h>
25 #include <errno.h>
26 #include <linux/tcp.h>
27 #include <openssl/err.h>
28 #include <sys/poll.h>
29 
30 #include "dns/DnsTlsSessionCache.h"
31 #include "dns/IDnsTlsSocketObserver.h"
32 
33 #include "log/log.h"
34 #include "netdutils/SocketOption.h"
35 #include "Fwmark.h"
36 #undef ADD  // already defined in nameser.h
37 #include "NetdConstants.h"
38 #include "Permission.h"
39 
40 
41 namespace android {
42 
43 using netdutils::enableSockopt;
44 using netdutils::enableTcpKeepAlives;
45 using netdutils::isOk;
46 using netdutils::Status;
47 
48 namespace net {
49 namespace {
50 
// Directory containing the system CA certificates; passed as the CA path to
// SSL_CTX_load_verify_locations() when the TLS context is set up.
constexpr char kCaCertDir[] = "/system/etc/security/cacerts";
52 
// Blocks, with no timeout, until poll() reports POLLIN (or an error condition)
// on |fd|.  Calls interrupted by a signal are retried.  Returns poll()'s
// result unchanged: 1 when |fd| has an event pending, -1 on failure with
// errno set.
int waitForReading(int fd) {
    struct pollfd readable = {};
    readable.fd = fd;
    readable.events = POLLIN;
    return TEMP_FAILURE_RETRY(poll(&readable, 1, /* timeout= */ -1));
}
58 
// Blocks, with no timeout, until poll() reports POLLOUT (or an error
// condition) on |fd|.  Calls interrupted by a signal are retried.  Returns
// poll()'s result unchanged: 1 when |fd| has an event pending, -1 on failure
// with errno set.
int waitForWriting(int fd) {
    struct pollfd writable = {};
    writable.fd = fd;
    writable.events = POLLOUT;
    return TEMP_FAILURE_RETRY(poll(&writable, 1, /* timeout= */ -1));
}
64 
65 }  // namespace
66 
tcpConnect()67 Status DnsTlsSocket::tcpConnect() {
68     ALOGV("%u connecting TCP socket", mMark);
69     int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
70     switch (mServer.protocol) {
71         case IPPROTO_TCP:
72             type |= SOCK_STREAM;
73             break;
74         default:
75             return Status(EPROTONOSUPPORT);
76     }
77 
78     mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
79     if (mSslFd.get() == -1) {
80         ALOGE("Failed to create socket");
81         return Status(errno);
82     }
83 
84     const socklen_t len = sizeof(mMark);
85     if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
86         ALOGE("Failed to set socket mark");
87         mSslFd.reset();
88         return Status(errno);
89     }
90 
91     const Status tfo = enableSockopt(mSslFd.get(), SOL_TCP, TCP_FASTOPEN_CONNECT);
92     if (!isOk(tfo) && tfo.code() != ENOPROTOOPT) {
93         ALOGI("Failed to enable TFO: %s", tfo.msg().c_str());
94     }
95 
96     // Send 5 keepalives, 3 seconds apart, after 15 seconds of inactivity.
97     enableTcpKeepAlives(mSslFd.get(), 15U, 5U, 3U);
98 
99     if (connect(mSslFd.get(), reinterpret_cast<const struct sockaddr *>(&mServer.ss),
100                 sizeof(mServer.ss)) != 0 &&
101             errno != EINPROGRESS) {
102         ALOGV("Socket failed to connect");
103         mSslFd.reset();
104         return Status(errno);
105     }
106 
107     return netdutils::status::ok;
108 }
109 
getSPKIDigest(const X509 * cert,std::vector<uint8_t> * out)110 bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
111     int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
112     unsigned char spki[spki_len];
113     unsigned char* temp = spki;
114     if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
115         ALOGW("SPKI length mismatch");
116         return false;
117     }
118     out->resize(SHA256_SIZE);
119     unsigned int digest_len = 0;
120     int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
121     if (ret != 1) {
122         ALOGW("Server cert digest extraction failed");
123         return false;
124     }
125     if (digest_len != out->size()) {
126         ALOGW("Wrong digest length: %d", digest_len);
127         return false;
128     }
129     return true;
130 }
131 
initialize()132 bool DnsTlsSocket::initialize() {
133     // This method should only be called once, at the beginning, so locking should be
134     // unnecessary.  This lock only serves to help catch bugs in code that calls this method.
135     std::lock_guard<std::mutex> guard(mLock);
136     if (mSslCtx) {
137         // This is a bug in the caller.
138         return false;
139     }
140     mSslCtx.reset(SSL_CTX_new(TLS_method()));
141     if (!mSslCtx) {
142         return false;
143     }
144 
145     // Load system CA certs for hostname verification.
146     //
147     // For discussion of alternative, sustainable approaches see b/71909242.
148     if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) {
149         ALOGE("Failed to load CA cert dir: %s", kCaCertDir);
150         return false;
151     }
152 
153     // Enable TLS false start
154     SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1);
155     SSL_CTX_set_mode(mSslCtx.get(), SSL_MODE_ENABLE_FALSE_START);
156 
157     // Enable session cache
158     mCache->prepareSslContext(mSslCtx.get());
159 
160     // Connect
161     Status status = tcpConnect();
162     if (!status.ok()) {
163         return false;
164     }
165     mSsl = sslConnect(mSslFd.get());
166     if (!mSsl) {
167         return false;
168     }
169     int sv[2];
170     if (socketpair(AF_LOCAL, SOCK_SEQPACKET, 0, sv)) {
171         return false;
172     }
173     // The two sockets are perfectly symmetrical, so the choice of which one is
174     // "in" and which one is "out" is arbitrary.
175     mIpcInFd.reset(sv[0]);
176     mIpcOutFd.reset(sv[1]);
177 
178     // Start the I/O loop.
179     mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this));
180 
181     return true;
182 }
183 
sslConnect(int fd)184 bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
185     if (!mSslCtx) {
186         ALOGE("Internal error: context is null in sslConnect");
187         return nullptr;
188     }
189     if (!SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
190         ALOGE("Failed to set minimum TLS version");
191         return nullptr;
192     }
193 
194     bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
195     // This file descriptor is owned by mSslFd, so don't let libssl close it.
196     bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
197     SSL_set_bio(ssl.get(), bio.get(), bio.get());
198     bio.release();
199 
200     if (!mCache->prepareSsl(ssl.get())) {
201         return nullptr;
202     }
203 
204     if (!mServer.name.empty()) {
205         if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
206             ALOGE("Failed to set SNI to %s", mServer.name.c_str());
207             return nullptr;
208         }
209         X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
210         if (X509_VERIFY_PARAM_set1_host(param, mServer.name.data(), mServer.name.size()) != 1) {
211             ALOGE("Failed to set verify host param to %s", mServer.name.c_str());
212             return nullptr;
213         }
214         // This will cause the handshake to fail if certificate verification fails.
215         SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
216     }
217 
218     bssl::UniquePtr<SSL_SESSION> session = mCache->getSession();
219     if (session) {
220         ALOGV("Setting session");
221         SSL_set_session(ssl.get(), session.get());
222     } else {
223         ALOGV("No session available");
224     }
225 
226     for (;;) {
227         ALOGV("%u Calling SSL_connect", mMark);
228         int ret = SSL_connect(ssl.get());
229         ALOGV("%u SSL_connect returned %d", mMark, ret);
230         if (ret == 1) break;  // SSL handshake complete;
231 
232         const int ssl_err = SSL_get_error(ssl.get(), ret);
233         switch (ssl_err) {
234             case SSL_ERROR_WANT_READ:
235                 if (waitForReading(fd) != 1) {
236                     ALOGW("SSL_connect read error: %d", errno);
237                     return nullptr;
238                 }
239                 break;
240             case SSL_ERROR_WANT_WRITE:
241                 if (waitForWriting(fd) != 1) {
242                     ALOGW("SSL_connect write error");
243                     return nullptr;
244                 }
245                 break;
246             default:
247                 ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
248                 return nullptr;
249         }
250     }
251 
252     // TODO: Call SSL_shutdown before discarding the session if validation fails.
253     if (!mServer.fingerprints.empty()) {
254         ALOGV("Checking DNS over TLS fingerprint");
255 
256         // We only care that the chain is internally self-consistent, not that
257         // it chains to a trusted root, so we can ignore some kinds of errors.
258         // TODO: Add a CA root verification mode that respects these errors.
259         int verify_result = SSL_get_verify_result(ssl.get());
260         switch (verify_result) {
261             case X509_V_OK:
262             case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
263             case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
264             case X509_V_ERR_CERT_UNTRUSTED:
265                 break;
266             default:
267                 ALOGW("Invalid certificate chain, error %d", verify_result);
268                 return nullptr;
269         }
270 
271         STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
272         if (!chain) {
273             ALOGW("Server has null certificate");
274             return nullptr;
275         }
276         // Chain and its contents are owned by ssl, so we don't need to free explicitly.
277         bool matched = false;
278         for (size_t i = 0; i < sk_X509_num(chain); ++i) {
279             // This appears to be O(N^2), but there doesn't seem to be a straightforward
280             // way to walk a STACK_OF nondestructively in linear time.
281             X509* cert = sk_X509_value(chain, i);
282             std::vector<uint8_t> digest;
283             if (!getSPKIDigest(cert, &digest)) {
284                 ALOGE("Digest computation failed");
285                 return nullptr;
286             }
287 
288             if (mServer.fingerprints.count(digest) > 0) {
289                 matched = true;
290                 break;
291             }
292         }
293 
294         if (!matched) {
295             ALOGW("No matching fingerprint");
296             return nullptr;
297         }
298 
299         ALOGV("DNS over TLS fingerprint is correct");
300     }
301 
302     ALOGV("%u handshake complete", mMark);
303 
304     return ssl;
305 }
306 
sslDisconnect()307 void DnsTlsSocket::sslDisconnect() {
308     if (mSsl) {
309         SSL_shutdown(mSsl.get());
310         mSsl.reset();
311     }
312     mSslFd.reset();
313 }
314 
sslWrite(const Slice buffer)315 bool DnsTlsSocket::sslWrite(const Slice buffer) {
316     ALOGV("%u Writing %zu bytes", mMark, buffer.size());
317     for (;;) {
318         int ret = SSL_write(mSsl.get(), buffer.base(), buffer.size());
319         if (ret == int(buffer.size())) break;  // SSL write complete;
320 
321         if (ret < 1) {
322             const int ssl_err = SSL_get_error(mSsl.get(), ret);
323             switch (ssl_err) {
324                 case SSL_ERROR_WANT_WRITE:
325                     if (waitForWriting(mSslFd.get()) != 1) {
326                         ALOGV("SSL_write error");
327                         return false;
328                     }
329                     continue;
330                 case 0:
331                     break;  // SSL write complete;
332                 default:
333                     ALOGV("SSL_write error %d", ssl_err);
334                     return false;
335             }
336         }
337     }
338     ALOGV("%u Wrote %zu bytes", mMark, buffer.size());
339     return true;
340 }
341 
loop()342 void DnsTlsSocket::loop() {
343     std::lock_guard<std::mutex> guard(mLock);
344     // Buffer at most one query.
345     Query q;
346 
347     const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000;
348     while (true) {
349         // poll() ignores negative fds
350         struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } };
351         enum { SSLFD = 0, IPCFD = 1 };
352 
353         // Always listen for a response from server.
354         fds[SSLFD].fd = mSslFd.get();
355         fds[SSLFD].events = POLLIN;
356 
357         // If we have a pending query, also wait for space
358         // to write it, otherwise listen for a new query.
359         if (!q.query.empty()) {
360             fds[SSLFD].events |= POLLOUT;
361         } else {
362             fds[IPCFD].fd = mIpcOutFd.get();
363             fds[IPCFD].events = POLLIN;
364         }
365 
366         const int s = TEMP_FAILURE_RETRY(poll(fds, ARRAY_SIZE(fds), timeout_msecs));
367         if (s == 0) {
368             ALOGV("Idle timeout");
369             break;
370         }
371         if (s < 0) {
372             ALOGV("Poll failed: %d", errno);
373             break;
374         }
375         if (fds[SSLFD].revents & (POLLIN | POLLERR)) {
376             if (!readResponse()) {
377                 ALOGV("SSL remote close or read error.");
378                 break;
379             }
380         }
381         if (fds[IPCFD].revents & (POLLIN | POLLERR)) {
382             int res = read(mIpcOutFd.get(), &q, sizeof(q));
383             if (res < 0) {
384                 ALOGW("Error during IPC read");
385                 break;
386             } else if (res == 0) {
387                 ALOGV("IPC channel closed; disconnecting");
388                 break;
389             } else if (res != sizeof(q)) {
390                 ALOGE("Struct size mismatch: %d != %zu", res, sizeof(q));
391                 break;
392             }
393         } else if (fds[SSLFD].revents & POLLOUT) {
394             // query cannot be null here.
395             if (!sendQuery(q)) {
396                 break;
397             }
398             q = Query();  // Reset q to empty
399         }
400     }
401     ALOGV("Closing IPC read FD");
402     mIpcOutFd.reset();
403     ALOGV("Disconnecting");
404     sslDisconnect();
405     ALOGV("Calling onClosed");
406     mObserver->onClosed();
407     ALOGV("Ending loop");
408 }
409 
~DnsTlsSocket()410 DnsTlsSocket::~DnsTlsSocket() {
411     ALOGV("Destructor");
412     // This will trigger an orderly shutdown in loop().
413     mIpcInFd.reset();
414     {
415         // Wait for the orderly shutdown to complete.
416         std::lock_guard<std::mutex> guard(mLock);
417         if (mLoopThread && std::this_thread::get_id() == mLoopThread->get_id()) {
418             ALOGE("Violation of re-entrance precondition");
419             return;
420         }
421     }
422     if (mLoopThread) {
423         ALOGV("Waiting for loop thread to terminate");
424         mLoopThread->join();
425         mLoopThread.reset();
426     }
427     ALOGV("Destructor completed");
428 }
429 
query(uint16_t id,const Slice query)430 bool DnsTlsSocket::query(uint16_t id, const Slice query) {
431     const Query q = { .id = id, .query = query };
432     if (!mIpcInFd) {
433         return false;
434     }
435     int written = write(mIpcInFd.get(), &q, sizeof(q));
436     return written == sizeof(q);
437 }
438 
439 // Read exactly len bytes into buffer or fail with an SSL error code
sslRead(const Slice buffer,bool wait)440 int DnsTlsSocket::sslRead(const Slice buffer, bool wait) {
441     size_t remaining = buffer.size();
442     while (remaining > 0) {
443         int ret = SSL_read(mSsl.get(), buffer.limit() - remaining, remaining);
444         if (ret == 0) {
445             ALOGW_IF(remaining < buffer.size(), "SSL closed with %zu of %zu bytes remaining",
446                      remaining, buffer.size());
447             return SSL_ERROR_ZERO_RETURN;
448         }
449 
450         if (ret < 0) {
451             const int ssl_err = SSL_get_error(mSsl.get(), ret);
452             if (wait && ssl_err == SSL_ERROR_WANT_READ) {
453                 if (waitForReading(mSslFd.get()) != 1) {
454                     ALOGV("Poll failed in sslRead: %d", errno);
455                     return SSL_ERROR_SYSCALL;
456                 }
457                 continue;
458             } else {
459                 ALOGV("SSL_read error %d", ssl_err);
460                 return ssl_err;
461             }
462         }
463 
464         remaining -= ret;
465         wait = true;  // Once a read is started, try to finish.
466     }
467     return SSL_ERROR_NONE;
468 }
469 
sendQuery(const Query & q)470 bool DnsTlsSocket::sendQuery(const Query& q) {
471     ALOGV("sending query");
472     // Compose the entire message in a single buffer, so that it can be
473     // sent as a single TLS record.
474     std::vector<uint8_t> buf(q.query.size() + 4);
475     // Write 2-byte length
476     uint16_t len = q.query.size() + 2; // + 2 for the ID.
477     buf[0] = len >> 8;
478     buf[1] = len;
479     // Write 2-byte ID
480     buf[2] = q.id >> 8;
481     buf[3] = q.id;
482     // Copy body
483     std::memcpy(buf.data() + 4, q.query.base(), q.query.size());
484     if (!sslWrite(netdutils::makeSlice(buf))) {
485         return false;
486     }
487     ALOGV("%u SSL_write complete", mMark);
488     return true;
489 }
490 
readResponse()491 bool DnsTlsSocket::readResponse() {
492     ALOGV("reading response");
493     uint8_t responseHeader[2];
494     int err = sslRead(Slice(responseHeader, 2), false);
495     if (err == SSL_ERROR_WANT_READ) {
496         ALOGV("Ignoring spurious wakeup from server");
497         return true;
498     }
499     if (err != SSL_ERROR_NONE) {
500         return false;
501     }
502     // Truncate responses larger than MAX_SIZE.  This is safe because a DNS packet is
503     // always invalid when truncated, so the response will be treated as an error.
504     constexpr uint16_t MAX_SIZE = 8192;
505     const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
506     ALOGV("%u Expecting response of size %i", mMark, responseSize);
507     std::vector<uint8_t> response(std::min(responseSize, MAX_SIZE));
508     if (sslRead(netdutils::makeSlice(response), true) != SSL_ERROR_NONE) {
509         ALOGV("%u Failed to read %zu bytes", mMark, response.size());
510         return false;
511     }
512     uint16_t remainingBytes = responseSize - response.size();
513     while (remainingBytes > 0) {
514         constexpr uint16_t CHUNK_SIZE = 2048;
515         std::vector<uint8_t> discard(std::min(remainingBytes, CHUNK_SIZE));
516         if (sslRead(netdutils::makeSlice(discard), true) != SSL_ERROR_NONE) {
517             ALOGV("%u Failed to discard %zu bytes", mMark, discard.size());
518             return false;
519         }
520         remainingBytes -= discard.size();
521     }
522     ALOGV("%u SSL_read complete", mMark);
523 
524     mObserver->onResponse(std::move(response));
525     return true;
526 }
527 
528 }  // end of namespace net
529 }  // end of namespace android
530