1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //! C API for the DoH backend for the Android DnsResolver module.
18 
19 use crate::boot_time::{timeout, BootTime, Duration};
20 use crate::dispatcher::{Command, Dispatcher, Response, ServerInfo};
21 use crate::network::{SocketTagger, ValidationReporter};
22 use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
23 use futures::FutureExt;
24 use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
25 use log::{error, warn};
26 use std::ffi::CString;
27 use std::net::{IpAddr, SocketAddr};
28 use std::ops::DerefMut;
29 use std::os::unix::io::RawFd;
30 use std::str::FromStr;
31 use std::sync::{Arc, Mutex};
32 use std::{ptr, slice};
33 use tokio::runtime::Builder;
34 use tokio::sync::oneshot;
35 use tokio::task;
36 use url::Url;
37 
38 pub type ValidationCallback = unsafe extern "C" fn(
39     net_id: uint32_t,
40     success: bool,
41     ip_addr: *const c_char,
42     host: *const c_char,
43 );
/// C callback that receives the raw file descriptor of a UDP socket so it can be tagged.
pub type TagSocketCallback = extern "C" fn(sock: RawFd);
45 
46 #[repr(C)]
47 pub struct FeatureFlags {
48     probe_timeout_ms: uint64_t,
49     idle_timeout_ms: uint64_t,
50     use_session_resumption: bool,
51     enable_early_data: bool,
52 }
53 
wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter54 fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter {
55     Arc::new(move |info: &ServerInfo, success: bool| {
56         async move {
57             let (ip_addr, domain) = match (
58                 CString::new(info.peer_addr.ip().to_string()),
59                 CString::new(info.domain.clone().unwrap_or_default()),
60             ) {
61                 (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
62                 _ => {
63                     error!("validation_callback bad input");
64                     return;
65                 }
66             };
67             let netd_id = info.net_id;
68             // SAFETY: The string pointers are obtained from `CString`, so they must be valid C
69             // strings.
70             task::spawn_blocking(move || unsafe {
71                 validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
72             })
73             .await
74             .unwrap_or_else(|e| warn!("Validation function task failed: {}", e))
75         }
76         .boxed()
77     })
78 }
79 
wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger80 fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger {
81     use std::os::unix::io::AsRawFd;
82     Arc::new(move |udp_socket: &std::net::UdpSocket| {
83         let fd = udp_socket.as_raw_fd();
84         async move {
85             task::spawn_blocking(move || {
86                 tag_socket_fn(fd);
87             })
88             .await
89             .unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e))
90         }
91         .boxed()
92     })
93 }
94 
95 pub struct DohDispatcher(Mutex<Dispatcher>);
96 
97 impl DohDispatcher {
lock(&self) -> impl DerefMut<Target = Dispatcher> + '_98     fn lock(&self) -> impl DerefMut<Target = Dispatcher> + '_ {
99         self.0.lock().unwrap()
100     }
101 }
102 
/// Certificate directory used when a domain is supplied without an explicit certificate path.
const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";
104 
105 /// The return code of doh_query means that there is no answer.
106 pub const DOH_RESULT_INTERNAL_ERROR: ssize_t = -1;
107 /// The return code of doh_query means that query can't be sent.
108 pub const DOH_RESULT_CAN_NOT_SEND: ssize_t = -2;
109 /// The return code of doh_query to indicate that the query timed out.
110 pub const DOH_RESULT_TIMEOUT: ssize_t = -255;
111 
/// Log level value selecting error messages only.
pub const DOH_LOG_LEVEL_ERROR: u32 = 0;
/// Log level value selecting warnings and above.
pub const DOH_LOG_LEVEL_WARN: u32 = 1;
/// Log level value selecting informational messages and above.
pub const DOH_LOG_LEVEL_INFO: u32 = 2;
/// Log level value selecting debug messages and above.
pub const DOH_LOG_LEVEL_DEBUG: u32 = 3;
/// Log level value selecting every message, including traces.
pub const DOH_LOG_LEVEL_TRACE: u32 = 4;
122 
/// Port paired with the server IP address to form the peer address of a probed DoH server.
const DOH_PORT: u16 = 443;
124 
level_from_u32(level: u32) -> Option<log::LevelFilter>125 fn level_from_u32(level: u32) -> Option<log::LevelFilter> {
126     use log::LevelFilter::*;
127     match level {
128         DOH_LOG_LEVEL_ERROR => Some(Error),
129         DOH_LOG_LEVEL_WARN => Some(Warn),
130         DOH_LOG_LEVEL_INFO => Some(Info),
131         DOH_LOG_LEVEL_DEBUG => Some(Debug),
132         DOH_LOG_LEVEL_TRACE => Some(Trace),
133         _ => None,
134     }
135 }
136 
137 /// Performs static initialization for android logger.
138 /// If an invalid level is passed, defaults to logging errors only.
139 /// If called more than once, it will have no effect on subsequent calls.
140 #[no_mangle]
doh_init_logger(level: u32)141 pub extern "C" fn doh_init_logger(level: u32) {
142     let log_level = level_from_u32(level).unwrap_or(log::LevelFilter::Error);
143     android_logger::init_once(android_logger::Config::default().with_max_level(log_level));
144 }
145 
146 /// Set the log level.
147 /// If an invalid level is passed, defaults to logging errors only.
148 #[no_mangle]
doh_set_log_level(level: u32)149 pub extern "C" fn doh_set_log_level(level: u32) {
150     let level_filter = level_from_u32(level).unwrap_or(log::LevelFilter::Error);
151     log::set_max_level(level_filter);
152 }
153 
154 /// Performs the initialization for the DoH engine.
155 /// Creates and returns a DoH engine instance.
156 #[no_mangle]
doh_dispatcher_new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> *mut DohDispatcher157 pub extern "C" fn doh_dispatcher_new(
158     validation_fn: ValidationCallback,
159     tag_socket_fn: TagSocketCallback,
160 ) -> *mut DohDispatcher {
161     match Dispatcher::new(
162         wrap_validation_callback(validation_fn),
163         wrap_tag_socket_callback(tag_socket_fn),
164     ) {
165         Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))),
166         Err(e) => {
167             error!("doh_dispatcher_new: failed: {:?}", e);
168             ptr::null_mut()
169         }
170     }
171 }
172 
173 /// Deletes a DoH engine created by doh_dispatcher_new().
174 ///
175 /// # Safety
176 ///
177 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
178 /// and not yet deleted by `doh_dispatcher_delete()`.
179 #[no_mangle]
doh_dispatcher_delete(doh: *mut DohDispatcher)180 pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
181     // SAFETY: The caller guarantees that `doh` was created by `doh_dispatcher_new` (which does so
182     // using `Box::into_raw`), and that it hasn't yet been deleted by this function.
183     unsafe { Box::from_raw(doh) }.lock().exit_handler()
184 }
185 
186 /// Probes and stores the DoH server with the given configurations.
187 /// Use the negative errno-style codes as the return value to represent the result.
188 /// # Safety
189 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
190 /// and not yet deleted by `doh_dispatcher_delete()`.
191 /// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
192 #[no_mangle]
doh_net_new( doh: &DohDispatcher, net_id: uint32_t, url: *const c_char, domain: *const c_char, ip_addr: *const c_char, sk_mark: libc::uint32_t, cert_path: *const c_char, flags: &FeatureFlags, network_type: uint32_t, private_dns_mode: uint32_t, ) -> int32_t193 pub unsafe extern "C" fn doh_net_new(
194     doh: &DohDispatcher,
195     net_id: uint32_t,
196     url: *const c_char,
197     domain: *const c_char,
198     ip_addr: *const c_char,
199     sk_mark: libc::uint32_t,
200     cert_path: *const c_char,
201     flags: &FeatureFlags,
202     network_type: uint32_t,
203     private_dns_mode: uint32_t,
204 ) -> int32_t {
205     // SAFETY: The caller guarantees that these are all valid nul-terminated C strings.
206     let (url, domain, ip_addr, cert_path) = match unsafe {
207         (
208             std::ffi::CStr::from_ptr(url).to_str(),
209             std::ffi::CStr::from_ptr(domain).to_str(),
210             std::ffi::CStr::from_ptr(ip_addr).to_str(),
211             std::ffi::CStr::from_ptr(cert_path).to_str(),
212         )
213     } {
214         (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => {
215             if domain.is_empty() {
216                 (url, None, ip_addr.to_string(), None)
217             } else if !cert_path.is_empty() {
218                 (url, Some(domain.to_string()), ip_addr.to_string(), Some(cert_path.to_string()))
219             } else {
220                 (
221                     url,
222                     Some(domain.to_string()),
223                     ip_addr.to_string(),
224                     Some(SYSTEM_CERT_PATH.to_string()),
225                 )
226             }
227         }
228         _ => {
229             error!("bad input"); // Should not happen
230             return -libc::EINVAL;
231         }
232     };
233 
234     let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) {
235         (Ok(url), Ok(ip_addr)) => (url, ip_addr),
236         _ => {
237             error!("bad ip or url"); // Should not happen
238             return -libc::EINVAL;
239         }
240     };
241     let cmd = Command::Probe {
242         info: ServerInfo {
243             net_id,
244             url,
245             peer_addr: SocketAddr::new(ip_addr, DOH_PORT),
246             domain,
247             sk_mark,
248             cert_path,
249             idle_timeout_ms: flags.idle_timeout_ms,
250             use_session_resumption: flags.use_session_resumption,
251             enable_early_data: flags.enable_early_data,
252             network_type,
253             private_dns_mode,
254         },
255         timeout: Duration::from_millis(flags.probe_timeout_ms),
256     };
257     if let Err(e) = doh.lock().send_cmd(cmd) {
258         error!("Failed to send the probe: {:?}", e);
259         return -libc::EPIPE;
260     }
261     0
262 }
263 
264 /// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
265 /// The return code should be either one of the public constant DOH_RESULT_* to indicate the error
266 /// or the size of the answer.
267 /// # Safety
268 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
269 /// and not yet deleted by `doh_dispatcher_delete()`.
270 /// `dns_query` must point to a buffer at least `dns_query_len` in size.
271 /// `response` must point to a buffer at least `response_len` in size.
272 #[no_mangle]
doh_query( doh: &DohDispatcher, net_id: uint32_t, dns_query: *mut u8, dns_query_len: size_t, response: *mut u8, response_len: size_t, timeout_ms: uint64_t, ) -> ssize_t273 pub unsafe extern "C" fn doh_query(
274     doh: &DohDispatcher,
275     net_id: uint32_t,
276     dns_query: *mut u8,
277     dns_query_len: size_t,
278     response: *mut u8,
279     response_len: size_t,
280     timeout_ms: uint64_t,
281 ) -> ssize_t {
282     // SAFETY: The caller guarantees that `dns_query` is a valid pointer to a buffer of at least
283     // `dns_query_len` items.
284     let q = unsafe { slice::from_raw_parts_mut(dns_query, dns_query_len) };
285 
286     let (resp_tx, resp_rx) = oneshot::channel();
287     let t = Duration::from_millis(timeout_ms);
288     if let Some(expired_time) = BootTime::now().checked_add(t) {
289         let cmd = Command::Query {
290             net_id,
291             base64_query: BASE64_URL_SAFE_NO_PAD.encode(q),
292             expired_time,
293             resp: resp_tx,
294         };
295 
296         if let Err(e) = doh.lock().send_cmd(cmd) {
297             error!("Failed to send the query: {:?}", e);
298             return DOH_RESULT_CAN_NOT_SEND;
299         }
300     } else {
301         error!("Bad timeout parameter: {}", timeout_ms);
302         return DOH_RESULT_CAN_NOT_SEND;
303     }
304 
305     if let Ok(rt) = Builder::new_current_thread().enable_all().build() {
306         let local = task::LocalSet::new();
307         match local.block_on(&rt, async { timeout(t, resp_rx).await }) {
308             Ok(v) => match v {
309                 Ok(v) => match v {
310                     Response::Success { answer } => {
311                         if answer.len() > response_len || answer.len() > isize::MAX as usize {
312                             return DOH_RESULT_INTERNAL_ERROR;
313                         }
314                         // SAFETY: The caller guarantees that response points to a valid buffer at
315                         // least `response_len` long, and we just checked that `answer.len()` is no
316                         // longer than `response_len`.
317                         let response = unsafe { slice::from_raw_parts_mut(response, answer.len()) };
318                         response.copy_from_slice(&answer);
319                         answer.len() as ssize_t
320                     }
321                     rsp => {
322                         error!("Non-successful response: {:?}", rsp);
323                         DOH_RESULT_CAN_NOT_SEND
324                     }
325                 },
326                 Err(e) => {
327                     error!("no result {}", e);
328                     DOH_RESULT_CAN_NOT_SEND
329                 }
330             },
331             Err(e) => {
332                 error!("timeout: {}", e);
333                 DOH_RESULT_TIMEOUT
334             }
335         }
336     } else {
337         DOH_RESULT_CAN_NOT_SEND
338     }
339 }
340 
341 /// Clears the DoH servers associated with the given |netid|.
342 /// # Safety
343 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
344 /// and not yet deleted by `doh_dispatcher_delete()`.
345 #[no_mangle]
doh_net_delete(doh: &DohDispatcher, net_id: uint32_t)346 pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) {
347     if let Err(e) = doh.lock().send_cmd(Command::Clear { net_id }) {
348         error!("Failed to send the query: {:?}", e);
349     }
350 }
351 
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::CStr;

    const TEST_NET_ID: u32 = 50;
    const LOOPBACK_ADDR: &str = "127.0.0.1:443";
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    /// Validation callback that expects a successful report for the test network.
    unsafe extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        // SAFETY: The caller guarantees that ip_addr and host are valid nul-terminated C strings.
        unsafe { assert_validation_info(net_id, ip_addr, host) };
    }

    /// Validation callback that expects a failed report for the test network.
    unsafe extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        // SAFETY: The caller guarantees that ip_addr and host are valid nul-terminated C strings.
        unsafe { assert_validation_info(net_id, ip_addr, host) };
    }

    // Checks that the reported network id, IP address and (empty) host match the test fixture.
    //
    // # Safety
    // `ip_addr`, `host` are null terminated strings
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);

        // SAFETY: The caller guarantees that `ip_addr` is a valid nul-terminated C string.
        let reported_ip = unsafe { CStr::from_ptr(ip_addr) };
        let reported_ip = reported_ip.to_str().unwrap();
        let expected_ip = LOOPBACK_ADDR.parse::<SocketAddr>().unwrap().ip().to_string();
        assert_eq!(reported_ip, expected_ip);

        // SAFETY: The caller guarantees that `host` is a valid nul-terminated C string.
        let reported_host = unsafe { CStr::from_ptr(host) };
        assert_eq!(reported_host.to_str().unwrap(), "");
    }

    #[tokio::test]
    async fn wrap_validation_callback_converts_correctly() {
        let info = ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
            idle_timeout_ms: 0,
            use_session_resumption: true,
            enable_early_data: true,
            network_type: 2,
            private_dns_mode: 3,
        };

        let report_success = wrap_validation_callback(success_cb);
        report_success(&info, true).await;

        let report_failure = wrap_validation_callback(fail_cb);
        report_failure(&info, false).await;
    }

    /// Socket-tagging callback that only checks it was handed a positive descriptor.
    extern "C" fn tag_socket_cb(raw_fd: RawFd) {
        assert!(raw_fd > 0)
    }

    #[tokio::test]
    async fn wrap_tag_socket_callback_converts_correctly() {
        let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
        let tagger = wrap_tag_socket_callback(tag_socket_cb);
        tagger(&sock).await;
    }
}
433