1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "dns_tls_frontend.h"
18 
#include <errno.h>
#include <netdb.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <arpa/inet.h>

#include <algorithm>
#include <string>
#include <vector>

#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <unistd.h>
30 
31 #define LOG_TAG "DnsTlsFrontend"
32 #include <log/log.h>
33 #include <netdutils/SocketOption.h>
34 
35 using android::netdutils::enableSockopt;
36 
37 namespace {
38 
39 // Copied from DnsTlsTransport.
getSPKIDigest(const X509 * cert,std::vector<uint8_t> * out)40 bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
41     int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
42     unsigned char spki[spki_len];
43     unsigned char* temp = spki;
44     if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
45         ALOGE("SPKI length mismatch");
46         return false;
47     }
48     out->resize(test::SHA256_SIZE);
49     unsigned int digest_len = 0;
50     int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
51     if (ret != 1) {
52         ALOGE("Server cert digest extraction failed");
53         return false;
54     }
55     if (digest_len != out->size()) {
56         ALOGE("Wrong digest length: %d", digest_len);
57         return false;
58     }
59     return true;
60 }
61 
// strerror_r() comes in two incompatible flavours. The XSI one returns an int
// status and fills the caller's buffer. The GNU one is declared when
// _GNU_SOURCE is defined, which is typically the default for C++ with
// GCC/Clang on Linux and Android. It returns a char* that may point at a
// static string instead of the buffer. Overloading on the result type accepts
// whichever flavour the C library declared. Only one overload is ever used,
// hence [[maybe_unused]].
[[maybe_unused]] std::string strerrorResult(int rc, const char* buf) {
    return rc == 0 ? std::string(buf) : std::string();
}

[[maybe_unused]] std::string strerrorResult(const char* msg, const char* /* buf */) {
    return msg != nullptr ? std::string(msg) : std::string();
}

// Returns a human-readable description of the current errno value, or an empty
// string if none is available. errno is restored before returning, so callers
// may read errno and call this in either order (as APLOGI's arguments do).
std::string errno2str() {
    const int saved_errno = errno;
    char error_msg[512] = { 0 };
    std::string result = strerrorResult(
            strerror_r(saved_errno, error_msg, sizeof(error_msg)), error_msg);
    errno = saved_errno;
    return result;
}
68 
// Logs |fmt| and its arguments at info level, followed by ": [<errno>] <description>"
// for the current errno. The GNU `, ##__VA_ARGS__` extension drops the comma when no
// format arguments are given, so APLOGI("listen failed") also compiles.
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", ##__VA_ARGS__, errno, errno2str().c_str())
70 
// Renders |sa| (of length |sa_len|) as a numeric host string such as
// "127.0.0.1" or "::1"; NI_NUMERICHOST means no name lookup is attempted.
// Returns an empty string when getnameinfo() reports an error.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = {};
    const int status = getnameinfo(sa, sa_len, numeric_host, sizeof(numeric_host),
                                   /* serv= */ nullptr, /* servlen= */ 0, NI_NUMERICHOST);
    if (status != 0) {
        return {};
    }
    return numeric_host;
}
78 
make_private_key()79 bssl::UniquePtr<EVP_PKEY> make_private_key() {
80     bssl::UniquePtr<BIGNUM> e(BN_new());
81     if (!e) {
82         ALOGE("BN_new failed");
83         return nullptr;
84     }
85     if (!BN_set_word(e.get(), RSA_F4)) {
86         ALOGE("BN_set_word failed");
87         return nullptr;
88     }
89 
90     bssl::UniquePtr<RSA> rsa(RSA_new());
91     if (!rsa) {
92         ALOGE("RSA_new failed");
93         return nullptr;
94     }
95     if (!RSA_generate_key_ex(rsa.get(), 2048, e.get(), NULL)) {
96         ALOGE("RSA_generate_key_ex failed");
97         return nullptr;
98     }
99 
100     bssl::UniquePtr<EVP_PKEY> privkey(EVP_PKEY_new());
101     if (!privkey) {
102         ALOGE("EVP_PKEY_new failed");
103         return nullptr;
104     }
105     if(!EVP_PKEY_assign_RSA(privkey.get(), rsa.get())) {
106         ALOGE("EVP_PKEY_assign_RSA failed");
107         return nullptr;
108     }
109 
110     // |rsa| is now owned by |privkey|, so no need to free it.
111     rsa.release();
112     return privkey;
113 }
114 
make_cert(EVP_PKEY * privkey,EVP_PKEY * parent_key)115 bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey, EVP_PKEY* parent_key) {
116     bssl::UniquePtr<X509> cert(X509_new());
117     if (!cert) {
118         ALOGE("X509_new failed");
119         return nullptr;
120     }
121 
122     ASN1_INTEGER_set(X509_get_serialNumber(cert.get()), 1);
123 
124     // Set one hour expiration.
125     X509_gmtime_adj(X509_get_notBefore(cert.get()), 0);
126     X509_gmtime_adj(X509_get_notAfter(cert.get()), 60 * 60);
127 
128     X509_set_pubkey(cert.get(), privkey);
129 
130     if (!X509_sign(cert.get(), parent_key, EVP_sha256())) {
131         ALOGE("X509_sign failed");
132         return nullptr;
133     }
134 
135     return cert;
136 }
137 
138 }
139 
140 namespace test {
141 
namespace {

// Resolves |address|:|service| as a passive TCP endpoint, binds a socket to the
// first candidate address that accepts the bind, and starts listening on it.
// Returns the listening socket, or -1 on failure (the reason is logged).
int makeListeningSocket(const char* address, const char* service) {
    // Value-initialize, then assign fields: designated initializers are C++20-only
    // and must follow declaration order (glibc/bionic declare ai_flags first).
    addrinfo hints{};
    hints.ai_flags = AI_PASSIVE;
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    addrinfo* res = nullptr;
    const int rv = getaddrinfo(address, service, &hints, &res);
    if (rv) {
        ALOGE("frontend getaddrinfo(%s, %s) failed: %s", address, service, gai_strerror(rv));
        return -1;
    }

    int s = -1;
    for (const addrinfo* ai = res; ai; ai = ai->ai_next) {
        s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (s < 0) continue;
        enableSockopt(s, SOL_SOCKET, SO_REUSEPORT);
        enableSockopt(s, SOL_SOCKET, SO_REUSEADDR);
        if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("bind failed for socket %d", s);
            close(s);
            s = -1;
            continue;
        }
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        ALOGI("bound to TCP %s:%s", host_str.c_str(), service);
        break;
    }
    freeaddrinfo(res);
    if (s < 0) {
        ALOGE("server socket creation failed");
        return -1;
    }

    if (listen(s, 1) < 0) {
        ALOGE("listen failed");
        close(s);  // Don't leak the bound socket.
        return -1;
    }
    return s;
}

// Creates a UDP socket connected to the first address that |address|:|service|
// resolves to. Returns the socket, or -1 on failure (the reason is logged).
int makeBackendSocket(const char* address, const char* service) {
    addrinfo hints{};
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_DGRAM;
    addrinfo* res = nullptr;
    const int rv = getaddrinfo(address, service, &hints, &res);
    if (rv) {
        ALOGE("backend getaddrinfo(%s, %s) failed: %s", address, service, gai_strerror(rv));
        return -1;
    }

    int s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
    if (s < 0) {
        ALOGE("backend socket creation failed");
        freeaddrinfo(res);
        return -1;
    }
    // handleOneRequest() uses send()/recv() without an address, so the socket
    // must be connected; a failure here would make every query fail later.
    if (connect(s, res->ai_addr, res->ai_addrlen) < 0) {
        APLOGI("backend connect failed for socket %d", s);
        close(s);
        s = -1;
    }
    freeaddrinfo(res);
    return s;
}

}  // namespace

// Builds a TLS context with a freshly generated certificate chain of length
// chain_length_, opens the TCP listening socket and the UDP backend socket,
// and starts requestHandler() on a new thread.
// Returns false, after logging, on any failure; no sockets are left open then.
bool DnsTlsFrontend::startServer() {
    SSL_load_error_strings();
    OpenSSL_add_ssl_algorithms();

    // certs[0] is indexed unconditionally below.
    if (chain_length_ < 1) {
        ALOGE("Invalid certificate chain length %d", chain_length_);
        return false;
    }

    ctx_.reset(SSL_CTX_new(TLS_server_method()));
    if (!ctx_) {
        ALOGE("SSL context creation failed");
        return false;
    }

    SSL_CTX_set_ecdh_auto(ctx_.get(), 1);

    // Make certificate chain: certs[i] carries keys[i] and is signed by
    // keys[i + 1]; the last certificate is signed by its own key.
    std::vector<bssl::UniquePtr<EVP_PKEY>> keys(chain_length_);
    for (int i = 0; i < chain_length_; ++i) {
        keys[i] = make_private_key();
        if (!keys[i]) {
            ALOGE("make_private_key failed");
            return false;
        }
    }
    std::vector<bssl::UniquePtr<X509>> certs(chain_length_);
    for (int i = 0; i < chain_length_; ++i) {
        int next = std::min(i + 1, chain_length_ - 1);
        certs[i] = make_cert(keys[i].get(), keys[next].get());
        if (!certs[i]) {
            ALOGE("make_cert failed");
            return false;
        }
    }

    // Install certificate chain.
    if (SSL_CTX_use_certificate(ctx_.get(), certs[0].get()) <= 0) {
        ALOGE("SSL_CTX_use_certificate failed");
        return false;
    }
    if (SSL_CTX_use_PrivateKey(ctx_.get(), keys[0].get()) <= 0) {
        ALOGE("SSL_CTX_use_PrivateKey failed");
        return false;
    }
    for (int i = 1; i < chain_length_; ++i) {
        if (SSL_CTX_add1_chain_cert(ctx_.get(), certs[i].get()) != 1) {
            ALOGE("SSL_CTX_add1_chain_cert failed");
            return false;
        }
    }

    // Report the fingerprint of the "middle" cert.  For N = 2, this is the root.
    int fp_index = chain_length_ / 2;
    if (!getSPKIDigest(certs[fp_index].get(), &fingerprint_)) {
        ALOGE("getSPKIDigest failed");
        return false;
    }

    // Set up TCP server socket for clients.
    const int listen_fd = makeListeningSocket(listen_address_.c_str(), listen_service_.c_str());
    if (listen_fd < 0) {
        return false;
    }

    // Set up UDP client socket to backend.
    const int backend_fd =
            makeBackendSocket(backend_address_.c_str(), backend_service_.c_str());
    if (backend_fd < 0) {
        close(listen_fd);
        return false;
    }

    {
        std::lock_guard<std::mutex> lock(update_mutex_);
        // Publish the sockets before the handler thread starts reading them.
        socket_ = listen_fd;
        backend_socket_ = backend_fd;
        handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}
261 
requestHandler()262 void DnsTlsFrontend::requestHandler() {
263     ALOGD("Request handler started");
264     struct pollfd fds[1] = {{ .fd = socket_, .events = POLLIN }};
265 
266     while (!terminate_) {
267         int poll_code = poll(fds, 1, 10 /* ms */);
268         if (poll_code == 0) {
269             // Timeout.  Poll again.
270             continue;
271         } else if (poll_code < 0) {
272             ALOGW("Poll failed with error %d", poll_code);
273             // Error.
274             break;
275         }
276         sockaddr_storage addr;
277         socklen_t len = sizeof(addr);
278 
279         ALOGD("Trying to accept a client");
280         int client = accept(socket_, reinterpret_cast<sockaddr*>(&addr), &len);
281         ALOGD("Got client socket %d", client);
282         if (client < 0) {
283             // Stop
284             break;
285         }
286 
287         bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
288         SSL_set_fd(ssl.get(), client);
289 
290         ALOGD("Doing SSL handshake");
291         bool success = false;
292         if (SSL_accept(ssl.get()) <= 0) {
293             ALOGI("SSL negotiation failure");
294         } else {
295             ALOGD("SSL handshake complete");
296             success = handleOneRequest(ssl.get());
297         }
298 
299         close(client);
300 
301         if (success) {
302             // Increment queries_ as late as possible, because it represents
303             // a query that is fully processed, and the response returned to the
304             // client, including cleanup actions.
305             ++queries_;
306         }
307     }
308     ALOGD("Request handler terminating");
309 }
310 
handleOneRequest(SSL * ssl)311 bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
312     uint8_t queryHeader[2];
313     if (SSL_read(ssl, &queryHeader, 2) != 2) {
314         ALOGI("Not enough header bytes");
315         return false;
316     }
317     const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
318     uint8_t query[qlen];
319     size_t qbytes = 0;
320     while (qbytes < qlen) {
321         int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
322         if (ret <= 0) {
323             ALOGI("Error while reading query");
324             return false;
325         }
326         qbytes += ret;
327     }
328     int sent = send(backend_socket_, query, qlen, 0);
329     if (sent != qlen) {
330         ALOGI("Failed to send query");
331         return false;
332     }
333     const int max_size = 4096;
334     uint8_t recv_buffer[max_size];
335     int rlen = recv(backend_socket_, recv_buffer, max_size, 0);
336     if (rlen <= 0) {
337         ALOGI("Failed to receive response");
338         return false;
339     }
340     uint8_t responseHeader[2];
341     responseHeader[0] = rlen >> 8;
342     responseHeader[1] = rlen;
343     if (SSL_write(ssl, responseHeader, 2) != 2) {
344         ALOGI("Failed to write response header");
345         return false;
346     }
347     if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
348         ALOGI("Failed to write response body");
349         return false;
350     }
351     return true;
352 }
353 
stopServer()354 bool DnsTlsFrontend::stopServer() {
355     std::lock_guard<std::mutex> lock(update_mutex_);
356     if (!running()) {
357         ALOGI("server not running");
358         return false;
359     }
360     if (terminate_) {
361         ALOGI("LOGIC ERROR");
362         return false;
363     }
364     ALOGI("stopping frontend");
365     terminate_ = true;
366     handler_thread_.join();
367     close(socket_);
368     close(backend_socket_);
369     terminate_ = false;
370     socket_ = -1;
371     backend_socket_ = -1;
372     ctx_.reset();
373     fingerprint_.clear();
374     ALOGI("frontend stopped successfully");
375     return true;
376 }
377 
waitForQueries(int number,int timeoutMs) const378 bool DnsTlsFrontend::waitForQueries(int number, int timeoutMs) const {
379     constexpr int intervalMs = 20;
380     int limit = timeoutMs / intervalMs;
381     for (int count = 0; count <= limit; ++count) {
382         bool done = queries_ >= number;
383         // Always sleep at least one more interval after we are done, to wait for
384         // any immediate post-query actions that the client may take (such as
385         // marking this server as reachable during validation).
386         usleep(intervalMs * 1000);
387         if (done) {
388             // For ensuring that calls have sufficient headroom for slow machines
389             ALOGD("Query arrived in %d/%d of allotted time", count, limit);
390             return true;
391         }
392     }
393     return false;
394 }
395 
396 }  // namespace test
397