1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //! DoH backend for the Android DnsResolver module.
18 
19 use anyhow::{anyhow, Context, Result};
20 use lazy_static::lazy_static;
21 use libc::{c_char, size_t, ssize_t};
22 use log::{debug, error, info, warn};
23 use quiche::h3;
24 use ring::rand::SecureRandom;
25 use std::collections::HashMap;
26 use std::net::{IpAddr, SocketAddr};
27 use std::os::unix::io::{AsRawFd, RawFd};
28 use std::str::FromStr;
29 use std::sync::Arc;
30 use std::{ptr, slice};
31 use tokio::net::UdpSocket;
32 use tokio::runtime::{Builder, Runtime};
33 use tokio::sync::{mpsc, oneshot};
34 use tokio::task;
35 use tokio::time::Duration;
36 use url::Url;
37 
lazy_static! {
    /// Tokio runtime used to perform doh-handler tasks.
    ///
    /// Built on first access as a multi-threaded runtime with two worker
    /// threads (named "doh-handler") and at most one blocking thread, with all
    /// tokio drivers enabled. Failing to build the runtime panics.
    static ref RUNTIME_STATIC: Arc<Runtime> = Arc::new(
        Builder::new_multi_thread()
            .worker_threads(2)
            .max_blocking_threads(1)
            .enable_all()
            .thread_name("doh-handler")
            .build()
            .expect("Failed to create tokio runtime")
    );
}
50 
/// Capacity of the command channel between callers and the handler task.
const MAX_BUFFERED_CMD_SIZE: usize = 400;
/// Value passed to quiche as the connection-wide `initial_max_data`.
const MAX_INCOMING_BUFFER_SIZE_WHOLE: u64 = 10000000;
/// Value passed to quiche as the per-stream `initial_max_stream_data_*`
/// (bidi local, bidi remote and uni).
const MAX_INCOMING_BUFFER_SIZE_EACH: u64 = 1000000;
/// Value passed to quiche as `initial_max_streams_bidi` / `_uni`.
const MAX_CONCURRENT_STREAM_SIZE: u64 = 100;
/// Size in bytes of the scratch buffers used to read response bodies and to
/// build outgoing packets.
const MAX_DATAGRAM_SIZE: usize = 1350;
/// Same value as `MAX_DATAGRAM_SIZE`, typed as `u64` for
/// `set_max_udp_payload_size`.
const MAX_DATAGRAM_SIZE_U64: u64 = 1350;
/// Remote UDP port the DoH socket connects to.
const DOH_PORT: u16 = 443;
/// Idle timeout handed to quiche, in milliseconds; also used as the handler's
/// sleep duration when quiche reports no pending timeout.
const QUICHE_IDLE_TIMEOUT_MS: u64 = 180000;
/// Certificate directory used when the caller does not supply one.
const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";
60 
/// Connection ID bytes handed to `quiche::connect` (filled with random data).
type SCID = [u8; quiche::MAX_CONN_ID_LEN];
/// Bytes of a query as supplied by the caller.
type Query = Vec<u8>;
/// Bytes of a response body as read from the HTTP/3 stream.
type Response = Vec<u8>;
/// Sending half of the command channel (held by `DohDispatcher`).
type CmdSender = mpsc::Sender<Command>;
/// Receiving half of the command channel (owned by the handler task).
type CmdReceiver = mpsc::Receiver<Command>;
/// One-shot channel used to hand the outcome of a query back to its issuer:
/// `Some(response)` on success, `None` when no response is available.
type QueryResponder = oneshot::Sender<Option<Response>>;
67 
/// Requests that can be submitted to the DoH handler task.
#[derive(Debug)]
enum Command {
    /// Send `query` to the server and report the outcome through `resp`.
    ///
    /// The bytes in `query` are appended verbatim after `?dns=` in the request
    /// path, so they must be valid UTF-8.
    DohQuery { query: Query, resp: QueryResponder },
}
72 
/// Context for a running DoH engine.
///
/// Instances are handed across the C boundary as raw pointers by `doh_new()`
/// and reclaimed by `doh_delete()`.
pub struct DohDispatcher {
    /// Used to submit queries to the I/O thread.
    query_sender: CmdSender,

    /// Handle of the spawned handler task; used to abort it on teardown.
    join_handle: task::JoinHandle<Result<()>>,
}
80 
81 fn make_doh_udp_socket(ip_addr: &str, mark: u32) -> Result<std::net::UdpSocket> {
82     let sock_addr = SocketAddr::new(IpAddr::from_str(&ip_addr)?, DOH_PORT);
83     let bind_addr = match sock_addr {
84         std::net::SocketAddr::V4(_) => "0.0.0.0:0",
85         std::net::SocketAddr::V6(_) => "[::]:0",
86     };
87     let udp_sk = std::net::UdpSocket::bind(bind_addr)?;
88     udp_sk.set_nonblocking(true)?;
89     mark_socket(udp_sk.as_raw_fd(), mark)?;
90     udp_sk.connect(sock_addr)?;
91 
92     debug!("connecting to {:} from {:}", sock_addr, udp_sk.local_addr()?);
93     Ok(udp_sk)
94 }
95 
96 // DoH dispatcher
97 impl DohDispatcher {
98     fn new(
99         url: &str,
100         ip_addr: &str,
101         mark: u32,
102         cert_path: Option<&str>,
103     ) -> Result<Box<DohDispatcher>> {
104         // Setup socket
105         let udp_sk = make_doh_udp_socket(&ip_addr, mark)?;
106         DohDispatcher::new_with_socket(url, ip_addr, mark, cert_path, udp_sk)
107     }
108 
109     fn new_with_socket(
110         url: &str,
111         ip_addr: &str,
112         mark: u32,
113         cert_path: Option<&str>,
114         udp_sk: std::net::UdpSocket,
115     ) -> Result<Box<DohDispatcher>> {
116         let url = Url::parse(&url.to_string())?;
117         if url.domain().is_none() {
118             return Err(anyhow!("no domain"));
119         }
120         // Setup quiche config
121         let config = create_quiche_config(cert_path)?;
122         let h3_config = h3::Config::new()?;
123         let mut scid = [0; quiche::MAX_CONN_ID_LEN];
124         ring::rand::SystemRandom::new().fill(&mut scid[..]).context("failed to generate scid")?;
125 
126         let (cmd_sender, cmd_receiver) = mpsc::channel::<Command>(MAX_BUFFERED_CMD_SIZE);
127         debug!(
128             "Creating a doh handler task: url={}, ip_addr={}, mark={:#x}, scid {:x?}",
129             url, ip_addr, mark, &scid
130         );
131         let join_handle =
132             RUNTIME_STATIC.spawn(doh_handler(url, udp_sk, config, h3_config, scid, cmd_receiver));
133         Ok(Box::new(DohDispatcher { query_sender: cmd_sender, join_handle }))
134     }
135 
136     fn query(&self, cmd: Command) -> Result<()> {
137         self.query_sender.blocking_send(cmd)?;
138         Ok(())
139     }
140 
141     fn abort_handler(&self) {
142         self.join_handle.abort();
143     }
144 }
145 
146 async fn doh_handler(
147     url: url::Url,
148     udp_sk: std::net::UdpSocket,
149     mut config: quiche::Config,
150     h3_config: h3::Config,
151     scid: SCID,
152     mut rx: CmdReceiver,
153 ) -> Result<()> {
154     debug!("doh_handler: url={:?}", url);
155 
156     let sk = UdpSocket::from_std(udp_sk)?;
157     let mut conn = quiche::connect(url.domain(), &scid, &mut config)?;
158     let mut quic_conn_start = std::time::Instant::now();
159     let mut h3_conn: Option<h3::Connection> = None;
160     let mut is_idle = false;
161     let mut buf = [0; 65535];
162 
163     let mut query_map = HashMap::<u64, QueryResponder>::new();
164     let mut pending_cmds: Vec<Command> = Vec::new();
165 
166     let mut ts = Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS);
167     loop {
168         tokio::select! {
169             size = sk.recv(&mut buf) => {
170                 debug!("recv {:?} ", size);
171                 match size {
172                     Ok(size) => {
173                         let processed = match conn.recv(&mut buf[..size]) {
174                             Ok(l) => l,
175                             Err(e) => {
176                                 error!("quic recv failed: {:?}", e);
177                                 continue;
178                             }
179                         };
180                         debug!("processed {} bytes", processed);
181                     },
182                     Err(e) => {
183                         error!("socket recv failed: {:?}", e);
184                         continue;
185                     },
186                 };
187             }
188             Some(cmd) = rx.recv() => {
189                 debug!("recv {:?}", cmd);
190                 pending_cmds.push(cmd);
191             }
192             _ = tokio::time::sleep(ts) => {
193                 conn.on_timeout();
194                 debug!("quic connection timeout");
195             }
196         }
197         if conn.is_closed() {
198             // Show connection statistics after it's closed
199             if !is_idle {
200                 info!("connection closed, {:?}, {:?}", quic_conn_start.elapsed(), conn.stats());
201                 is_idle = true;
202                 if !conn.is_established() {
203                     error!("connection handshake timed out after {:?}", quic_conn_start.elapsed());
204                 }
205             }
206 
207             // If there is any pending query, resume the quic connection.
208             if !pending_cmds.is_empty() {
209                 info!("still some pending queries but connection is not avaiable, resume it");
210                 conn = quiche::connect(url.domain(), &scid, &mut config)?;
211                 quic_conn_start = std::time::Instant::now();
212                 h3_conn = None;
213                 is_idle = false;
214             }
215         }
216 
217         // Create a new HTTP/3 connection once the QUIC connection is established.
218         if conn.is_established() && h3_conn.is_none() {
219             info!("quic ready, creating h3 conn");
220             h3_conn = Some(quiche::h3::Connection::with_transport(&mut conn, &h3_config)?);
221         }
222         // Try to receive query answers from h3 connection.
223         if let Some(h3) = h3_conn.as_mut() {
224             recv_query(h3, &mut conn, &mut query_map).await;
225         }
226 
227         // Update the next timeout of quic connection.
228         ts = conn.timeout().unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));
229         info!("next connection timouts  {:?}", ts);
230 
231         // Process the pending queries
232         while !pending_cmds.is_empty() && conn.is_established() {
233             if let Some(cmd) = pending_cmds.pop() {
234                 match cmd {
235                     Command::DohQuery { query, resp } => {
236                         match send_dns_query(&query, &url, &mut h3_conn, &mut conn) {
237                             Ok(stream_id) => {
238                                 query_map.insert(stream_id, resp);
239                             }
240                             Err(e) => {
241                                 info!("failed to send query {}", e);
242                                 pending_cmds.push(Command::DohQuery { query, resp });
243                             }
244                         }
245                     }
246                 }
247             }
248         }
249         flush_tx(&sk, &mut conn).await.unwrap_or_else(|e| {
250             error!("flush error {:?} ", e);
251         });
252     }
253 }
254 
255 fn send_dns_query(
256     query: &[u8],
257     url: &url::Url,
258     h3_conn: &mut Option<quiche::h3::Connection>,
259     mut conn: &mut quiche::Connection,
260 ) -> Result<u64> {
261     let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
262 
263     let mut path = String::from(url.path());
264     path.push_str("?dns=");
265     path.push_str(std::str::from_utf8(&query)?);
266     let _req = vec![
267         quiche::h3::Header::new(":method", "GET"),
268         quiche::h3::Header::new(":scheme", "https"),
269         quiche::h3::Header::new(
270             ":authority",
271             url.host_str().ok_or_else(|| anyhow!("failed to get host"))?,
272         ),
273         quiche::h3::Header::new(":path", &path),
274         quiche::h3::Header::new("user-agent", "quiche"),
275         quiche::h3::Header::new("accept", "application/dns-message"),
276         // TODO: is content-length required?
277     ];
278 
279     Ok(h3_conn.send_request(&mut conn, &_req, false /*fin*/)?)
280 }
281 
282 async fn recv_query(
283     h3_conn: &mut h3::Connection,
284     mut conn: &mut quiche::Connection,
285     map: &mut HashMap<u64, QueryResponder>,
286 ) {
287     // Process HTTP/3 events.
288     let mut buf = [0; MAX_DATAGRAM_SIZE];
289     loop {
290         match h3_conn.poll(&mut conn) {
291             Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => {
292                 info!(
293                     "got response headers {:?} on stream id {} has_body {}",
294                     list, stream_id, has_body
295                 );
296             }
297             Ok((stream_id, quiche::h3::Event::Data)) => {
298                 debug!("quiche::h3::Event::Data");
299                 if let Ok(read) = h3_conn.recv_body(&mut conn, stream_id, &mut buf) {
300                     info!(
301                         "got {} bytes of response data on stream {}: {:x?}",
302                         read,
303                         stream_id,
304                         &buf[..read]
305                     );
306                     if let Some(resp) = map.remove(&stream_id) {
307                         resp.send(Some(buf[..read].to_vec())).unwrap_or_else(|e| {
308                             warn!("the receiver dropped {:?}", e);
309                         });
310                     }
311                 }
312             }
313             Ok((_stream_id, quiche::h3::Event::Finished)) => {
314                 debug!("quiche::h3::Event::Finished");
315             }
316             Ok((_stream_id, quiche::h3::Event::Datagram)) => {
317                 debug!("quiche::h3::Event::Datagram");
318             }
319             Ok((_stream_id, quiche::h3::Event::GoAway)) => {
320                 debug!("quiche::h3::Event::GoAway");
321             }
322             Err(quiche::h3::Error::Done) => {
323                 debug!("quiche::h3::Error::Done");
324                 break;
325             }
326             Err(e) => {
327                 error!("HTTP/3 processing failed: {:?}", e);
328                 break;
329             }
330         }
331     }
332 }
333 
334 async fn flush_tx(sk: &UdpSocket, conn: &mut quiche::Connection) -> Result<()> {
335     let mut out = [0; MAX_DATAGRAM_SIZE];
336     loop {
337         let write = match conn.send(&mut out) {
338             Ok(v) => v,
339             Err(quiche::Error::Done) => {
340                 debug!("done writing");
341                 break;
342             }
343             Err(e) => {
344                 conn.close(false, 0x1, b"fail").ok();
345                 return Err(anyhow::Error::new(e));
346             }
347         };
348         sk.send(&out[..write]).await?;
349         debug!("written {}", write);
350     }
351     Ok(())
352 }
353 
354 fn create_quiche_config(cert_path: Option<&str>) -> Result<quiche::Config> {
355     let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
356     config.set_application_protos(h3::APPLICATION_PROTOCOL)?;
357     config.verify_peer(true);
358     config.load_verify_locations_from_directory(cert_path.unwrap_or(SYSTEM_CERT_PATH))?;
359     // Some of these configs are necessary, or the server can't respond the HTTP/3 request.
360     config.set_max_idle_timeout(QUICHE_IDLE_TIMEOUT_MS);
361     config.set_max_udp_payload_size(MAX_DATAGRAM_SIZE_U64);
362     config.set_initial_max_data(MAX_INCOMING_BUFFER_SIZE_WHOLE);
363     config.set_initial_max_stream_data_bidi_local(MAX_INCOMING_BUFFER_SIZE_EACH);
364     config.set_initial_max_stream_data_bidi_remote(MAX_INCOMING_BUFFER_SIZE_EACH);
365     config.set_initial_max_stream_data_uni(MAX_INCOMING_BUFFER_SIZE_EACH);
366     config.set_initial_max_streams_bidi(MAX_CONCURRENT_STREAM_SIZE);
367     config.set_initial_max_streams_uni(MAX_CONCURRENT_STREAM_SIZE);
368     config.set_disable_active_migration(true);
369     Ok(config)
370 }
371 
372 fn mark_socket(fd: RawFd, mark: u32) -> Result<()> {
373     // libc::setsockopt is a wrapper function calling into bionic setsockopt.
374     // Both fd and mark are valid, which makes the function call mostly safe.
375     if unsafe {
376         libc::setsockopt(
377             fd,
378             libc::SOL_SOCKET,
379             libc::SO_MARK,
380             &mark as *const _ as *const libc::c_void,
381             std::mem::size_of::<u32>() as libc::socklen_t,
382         )
383     } == 0
384     {
385         Ok(())
386     } else {
387         Err(anyhow::Error::new(std::io::Error::last_os_error()))
388     }
389 }
390 
391 /// Performs static initialization fo the DoH engine.
392 #[no_mangle]
393 pub extern "C" fn doh_init() -> *const c_char {
394     android_logger::init_once(android_logger::Config::default().with_min_level(log::Level::Trace));
395     static VERSION: &str = "1.0\0";
396     VERSION.as_ptr() as *const c_char
397 }
398 
399 /// Creates and returns a DoH engine instance.
400 /// The returned object must be freed with doh_delete().
401 /// # Safety
402 /// All the pointer args are null terminated strings.
403 #[no_mangle]
404 pub unsafe extern "C" fn doh_new(
405     url: *const c_char,
406     ip_addr: *const c_char,
407     mark: libc::uint32_t,
408     cert_path: *const c_char,
409 ) -> *mut DohDispatcher {
410     let (url, ip_addr, cert_path) = match (
411         std::ffi::CStr::from_ptr(url).to_str(),
412         std::ffi::CStr::from_ptr(ip_addr).to_str(),
413         std::ffi::CStr::from_ptr(cert_path).to_str(),
414     ) {
415         (Ok(url), Ok(ip_addr), Ok(cert_path)) => {
416             if !cert_path.is_empty() {
417                 (url, ip_addr, Some(cert_path))
418             } else {
419                 (url, ip_addr, None)
420             }
421         }
422         _ => {
423             error!("bad input");
424             return ptr::null_mut();
425         }
426     };
427     match DohDispatcher::new(url, ip_addr, mark, cert_path) {
428         Ok(c) => Box::into_raw(c),
429         Err(e) => {
430             error!("doh_new: failed: {:?}", e);
431             ptr::null_mut()
432         }
433     }
434 }
435 
436 /// Deletes a DoH engine created by doh_new().
437 /// # Safety
438 /// `doh` must be a non-null pointer previously created by `doh_new()`
439 /// and not yet deleted by `doh_delete()`.
440 #[no_mangle]
441 pub unsafe extern "C" fn doh_delete(doh: *mut DohDispatcher) {
442     Box::from_raw(doh).abort_handler()
443 }
444 
445 /// Sends a DNS query and waits for the response.
446 /// # Safety
447 /// `doh` must be a non-null pointer previously created by `doh_new()`
448 /// and not yet deleted by `doh_delete()`.
449 /// `query` must point to a buffer at least `query_len` in size.
450 /// `response` must point to a buffer at least `response_len` in size.
451 #[no_mangle]
452 pub unsafe extern "C" fn doh_query(
453     doh: &mut DohDispatcher,
454     query: *mut u8,
455     query_len: size_t,
456     response: *mut u8,
457     response_len: size_t,
458 ) -> ssize_t {
459     let q = slice::from_raw_parts_mut(query, query_len);
460     let (resp_tx, resp_rx) = oneshot::channel();
461     let cmd = Command::DohQuery { query: q.to_vec(), resp: resp_tx };
462     if let Err(e) = doh.query(cmd) {
463         error!("Failed to send the query: {:?}", e);
464         return -1;
465     }
466     match RUNTIME_STATIC.block_on(resp_rx) {
467         Ok(value) => {
468             if let Some(resp) = value {
469                 if resp.len() > response_len || resp.len() > isize::MAX as usize {
470                     return -1;
471                 }
472                 let response = slice::from_raw_parts_mut(response, resp.len());
473                 response.copy_from_slice(&resp);
474                 return resp.len() as ssize_t;
475             }
476             -1
477         }
478         Err(e) => {
479             error!("no result {}", e);
480             -1
481         }
482     }
483 }
484 
/// Unit tests for the DoH backend.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

    // Socket mark applied to sockets created by the tests.
    const TEST_MARK: u32 = 0xD0033;
    // Server address used when a real socket has to be created.
    const LOOPBACK_ADDR: &str = "127.0.0.1";

    /// `DohDispatcher::new` must fail for invalid url / ip combinations.
    #[test]
    fn dohdispatcher_invalid_args() {
        // NOTE(review): `DohDispatcher::new` creates the socket before parsing
        // the url, so the "Bad url" entries (which also carry the invalid ip
        // "bar") fail on the ip first; they do not reach url validation.
        let test_args = [
            // Bad url
            ("foo", "bar"),
            ("https://1", "bar"),
            ("https:/", "bar"),
            // Bad ip
            ("https://dns.google", "bar"),
            ("https://dns.google", "256.256.256.256"),
        ];
        for args in &test_args {
            assert!(
                DohDispatcher::new(args.0, args.1, 0, None).is_err(),
                "doh dispatcher should not be created"
            )
        }
    }

    /// Checks argument validation of `make_doh_udp_socket` and the peer
    /// address, mark and non-blocking flag of the socket it returns.
    #[test]
    fn make_doh_udp_socket() {
        // Bad ip
        for ip in &["foo", "1", "333.333.333.333"] {
            assert!(super::make_doh_udp_socket(ip, 0).is_err(), "udp socket should not be created");
        }
        // Make a socket connecting to loopback with a test mark.
        let sk = super::make_doh_udp_socket(LOOPBACK_ADDR, TEST_MARK).unwrap();
        // Check if the socket is connected to loopback.
        assert_eq!(
            sk.peer_addr().unwrap(),
            SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT))
        );

        // Check if the socket mark is correct.
        let fd: RawFd = sk.as_raw_fd();

        // Start from a value different from TEST_MARK so that a getsockopt
        // call that writes nothing cannot make the assertion below pass.
        let mut mark: u32 = 50;
        let mut size = std::mem::size_of::<u32>() as libc::socklen_t;
        unsafe {
            // Safety: fd must be valid.
            assert_eq!(
                libc::getsockopt(
                    fd,
                    libc::SOL_SOCKET,
                    libc::SO_MARK,
                    &mut mark as *mut _ as *mut libc::c_void,
                    &mut size as *mut _ as *mut libc::socklen_t,
                ),
                0
            );
        }
        assert_eq!(mark, TEST_MARK);

        // Check if the socket is non-blocking.
        unsafe {
            // Safety: fd must be valid.
            assert_eq!(libc::fcntl(fd, libc::F_GETFL, 0) & libc::O_NONBLOCK, libc::O_NONBLOCK);
        }
    }

    /// The quiche config must be creatable with the default certificate
    /// directory and with an explicitly supplied one.
    #[test]
    fn create_quiche_config() {
        assert!(
            super::create_quiche_config(None).is_ok(),
            "quiche config without cert creating failed"
        );
        // NOTE(review): this path is relative (no leading '/'), so the result
        // depends on the working directory of the test - confirm intended.
        assert!(
            super::create_quiche_config(Some("data/local/tmp/")).is_ok(),
            "quiche config with cert creating failed"
        );
    }

    const GOOGLE_DNS_URL: &str = "https://dns.google/dns-query";
    const GOOGLE_DNS_IP: &str = "8.8.8.8";
    // qtype: A, qname: www.example.com
    const SAMPLE_QUERY: &str = "q80BAAABAAAAAAAAA3d3dwdleGFtcGxlA2NvbQAAAQAB";
    /// Aborting the handler must close the response channel of a query that is
    /// still outstanding.
    #[test]
    fn close_doh() {
        // The socket is connected to loopback; GOOGLE_DNS_IP is only passed
        // along to `new_with_socket`.
        let udp_sk = super::make_doh_udp_socket(LOOPBACK_ADDR, TEST_MARK).unwrap();
        let doh =
            DohDispatcher::new_with_socket(GOOGLE_DNS_URL, GOOGLE_DNS_IP, 0, None, udp_sk).unwrap();
        let (resp_tx, resp_rx) = oneshot::channel();
        let cmd = Command::DohQuery { query: SAMPLE_QUERY.as_bytes().to_vec(), resp: resp_tx };
        assert!(doh.query(cmd).is_ok(), "Send query failed");
        doh.abort_handler();
        // The responder is dropped together with the aborted task, so waiting
        // on the receiver yields an error instead of a response.
        assert!(RUNTIME_STATIC.block_on(resp_rx).is_err(), "channel should already be closed");
    }

    /// `doh_init` must return the version as a C string.
    #[test]
    fn doh_init() {
        unsafe {
            // Safety: the returned pointer of doh_init() must be a null terminated string.
            assert_eq!(std::ffi::CStr::from_ptr(super::doh_init()).to_str().unwrap(), "1.0");
        }
    }
}
589