1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 //! C API for the DoH backend for the Android DnsResolver module.
18
19 use crate::boot_time::{timeout, BootTime, Duration};
20 use crate::dispatcher::{Command, Dispatcher, Response, ServerInfo};
21 use crate::network::{SocketTagger, ValidationReporter};
22 use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
23 use futures::FutureExt;
24 use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
25 use log::{error, warn};
26 use std::ffi::CString;
27 use std::net::{IpAddr, SocketAddr};
28 use std::ops::DerefMut;
29 use std::os::unix::io::RawFd;
30 use std::str::FromStr;
31 use std::sync::{Arc, Mutex};
32 use std::{ptr, slice};
33 use tokio::runtime::Builder;
34 use tokio::sync::oneshot;
35 use tokio::task;
36 use url::Url;
37
38 pub type ValidationCallback = unsafe extern "C" fn(
39 net_id: uint32_t,
40 success: bool,
41 ip_addr: *const c_char,
42 host: *const c_char,
43 );
44 pub type TagSocketCallback = extern "C" fn(sock: RawFd);
45
46 #[repr(C)]
47 pub struct FeatureFlags {
48 probe_timeout_ms: uint64_t,
49 idle_timeout_ms: uint64_t,
50 use_session_resumption: bool,
51 enable_early_data: bool,
52 }
53
wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter54 fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter {
55 Arc::new(move |info: &ServerInfo, success: bool| {
56 async move {
57 let (ip_addr, domain) = match (
58 CString::new(info.peer_addr.ip().to_string()),
59 CString::new(info.domain.clone().unwrap_or_default()),
60 ) {
61 (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
62 _ => {
63 error!("validation_callback bad input");
64 return;
65 }
66 };
67 let netd_id = info.net_id;
68 // SAFETY: The string pointers are obtained from `CString`, so they must be valid C
69 // strings.
70 task::spawn_blocking(move || unsafe {
71 validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
72 })
73 .await
74 .unwrap_or_else(|e| warn!("Validation function task failed: {}", e))
75 }
76 .boxed()
77 })
78 }
79
wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger80 fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger {
81 use std::os::unix::io::AsRawFd;
82 Arc::new(move |udp_socket: &std::net::UdpSocket| {
83 let fd = udp_socket.as_raw_fd();
84 async move {
85 task::spawn_blocking(move || {
86 tag_socket_fn(fd);
87 })
88 .await
89 .unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e))
90 }
91 .boxed()
92 })
93 }
94
95 pub struct DohDispatcher(Mutex<Dispatcher>);
96
97 impl DohDispatcher {
lock(&self) -> impl DerefMut<Target = Dispatcher> + '_98 fn lock(&self) -> impl DerefMut<Target = Dispatcher> + '_ {
99 self.0.lock().unwrap()
100 }
101 }
102
103 const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";
104
105 /// The return code of doh_query means that there is no answer.
106 pub const DOH_RESULT_INTERNAL_ERROR: ssize_t = -1;
107 /// The return code of doh_query means that query can't be sent.
108 pub const DOH_RESULT_CAN_NOT_SEND: ssize_t = -2;
109 /// The return code of doh_query to indicate that the query timed out.
110 pub const DOH_RESULT_TIMEOUT: ssize_t = -255;
111
112 /// The error log level.
113 pub const DOH_LOG_LEVEL_ERROR: u32 = 0;
114 /// The warning log level.
115 pub const DOH_LOG_LEVEL_WARN: u32 = 1;
116 /// The info log level.
117 pub const DOH_LOG_LEVEL_INFO: u32 = 2;
118 /// The debug log level.
119 pub const DOH_LOG_LEVEL_DEBUG: u32 = 3;
120 /// The trace log level.
121 pub const DOH_LOG_LEVEL_TRACE: u32 = 4;
122
123 const DOH_PORT: u16 = 443;
124
level_from_u32(level: u32) -> Option<log::LevelFilter>125 fn level_from_u32(level: u32) -> Option<log::LevelFilter> {
126 use log::LevelFilter::*;
127 match level {
128 DOH_LOG_LEVEL_ERROR => Some(Error),
129 DOH_LOG_LEVEL_WARN => Some(Warn),
130 DOH_LOG_LEVEL_INFO => Some(Info),
131 DOH_LOG_LEVEL_DEBUG => Some(Debug),
132 DOH_LOG_LEVEL_TRACE => Some(Trace),
133 _ => None,
134 }
135 }
136
137 /// Performs static initialization for android logger.
138 /// If an invalid level is passed, defaults to logging errors only.
139 /// If called more than once, it will have no effect on subsequent calls.
140 #[no_mangle]
doh_init_logger(level: u32)141 pub extern "C" fn doh_init_logger(level: u32) {
142 let log_level = level_from_u32(level).unwrap_or(log::LevelFilter::Error);
143 android_logger::init_once(android_logger::Config::default().with_max_level(log_level));
144 }
145
146 /// Set the log level.
147 /// If an invalid level is passed, defaults to logging errors only.
148 #[no_mangle]
doh_set_log_level(level: u32)149 pub extern "C" fn doh_set_log_level(level: u32) {
150 let level_filter = level_from_u32(level).unwrap_or(log::LevelFilter::Error);
151 log::set_max_level(level_filter);
152 }
153
154 /// Performs the initialization for the DoH engine.
155 /// Creates and returns a DoH engine instance.
156 #[no_mangle]
doh_dispatcher_new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> *mut DohDispatcher157 pub extern "C" fn doh_dispatcher_new(
158 validation_fn: ValidationCallback,
159 tag_socket_fn: TagSocketCallback,
160 ) -> *mut DohDispatcher {
161 match Dispatcher::new(
162 wrap_validation_callback(validation_fn),
163 wrap_tag_socket_callback(tag_socket_fn),
164 ) {
165 Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))),
166 Err(e) => {
167 error!("doh_dispatcher_new: failed: {:?}", e);
168 ptr::null_mut()
169 }
170 }
171 }
172
173 /// Deletes a DoH engine created by doh_dispatcher_new().
174 ///
175 /// # Safety
176 ///
177 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
178 /// and not yet deleted by `doh_dispatcher_delete()`.
179 #[no_mangle]
doh_dispatcher_delete(doh: *mut DohDispatcher)180 pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
181 // SAFETY: The caller guarantees that `doh` was created by `doh_dispatcher_new` (which does so
182 // using `Box::into_raw`), and that it hasn't yet been deleted by this function.
183 unsafe { Box::from_raw(doh) }.lock().exit_handler()
184 }
185
186 /// Probes and stores the DoH server with the given configurations.
187 /// Use the negative errno-style codes as the return value to represent the result.
188 /// # Safety
189 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
190 /// and not yet deleted by `doh_dispatcher_delete()`.
191 /// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
192 #[no_mangle]
doh_net_new( doh: &DohDispatcher, net_id: uint32_t, url: *const c_char, domain: *const c_char, ip_addr: *const c_char, sk_mark: libc::uint32_t, cert_path: *const c_char, flags: &FeatureFlags, network_type: uint32_t, private_dns_mode: uint32_t, ) -> int32_t193 pub unsafe extern "C" fn doh_net_new(
194 doh: &DohDispatcher,
195 net_id: uint32_t,
196 url: *const c_char,
197 domain: *const c_char,
198 ip_addr: *const c_char,
199 sk_mark: libc::uint32_t,
200 cert_path: *const c_char,
201 flags: &FeatureFlags,
202 network_type: uint32_t,
203 private_dns_mode: uint32_t,
204 ) -> int32_t {
205 // SAFETY: The caller guarantees that these are all valid nul-terminated C strings.
206 let (url, domain, ip_addr, cert_path) = match unsafe {
207 (
208 std::ffi::CStr::from_ptr(url).to_str(),
209 std::ffi::CStr::from_ptr(domain).to_str(),
210 std::ffi::CStr::from_ptr(ip_addr).to_str(),
211 std::ffi::CStr::from_ptr(cert_path).to_str(),
212 )
213 } {
214 (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => {
215 if domain.is_empty() {
216 (url, None, ip_addr.to_string(), None)
217 } else if !cert_path.is_empty() {
218 (url, Some(domain.to_string()), ip_addr.to_string(), Some(cert_path.to_string()))
219 } else {
220 (
221 url,
222 Some(domain.to_string()),
223 ip_addr.to_string(),
224 Some(SYSTEM_CERT_PATH.to_string()),
225 )
226 }
227 }
228 _ => {
229 error!("bad input"); // Should not happen
230 return -libc::EINVAL;
231 }
232 };
233
234 let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) {
235 (Ok(url), Ok(ip_addr)) => (url, ip_addr),
236 _ => {
237 error!("bad ip or url"); // Should not happen
238 return -libc::EINVAL;
239 }
240 };
241 let cmd = Command::Probe {
242 info: ServerInfo {
243 net_id,
244 url,
245 peer_addr: SocketAddr::new(ip_addr, DOH_PORT),
246 domain,
247 sk_mark,
248 cert_path,
249 idle_timeout_ms: flags.idle_timeout_ms,
250 use_session_resumption: flags.use_session_resumption,
251 enable_early_data: flags.enable_early_data,
252 network_type,
253 private_dns_mode,
254 },
255 timeout: Duration::from_millis(flags.probe_timeout_ms),
256 };
257 if let Err(e) = doh.lock().send_cmd(cmd) {
258 error!("Failed to send the probe: {:?}", e);
259 return -libc::EPIPE;
260 }
261 0
262 }
263
264 /// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
265 /// The return code should be either one of the public constant DOH_RESULT_* to indicate the error
266 /// or the size of the answer.
267 /// # Safety
268 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
269 /// and not yet deleted by `doh_dispatcher_delete()`.
270 /// `dns_query` must point to a buffer at least `dns_query_len` in size.
271 /// `response` must point to a buffer at least `response_len` in size.
272 #[no_mangle]
doh_query( doh: &DohDispatcher, net_id: uint32_t, dns_query: *mut u8, dns_query_len: size_t, response: *mut u8, response_len: size_t, timeout_ms: uint64_t, ) -> ssize_t273 pub unsafe extern "C" fn doh_query(
274 doh: &DohDispatcher,
275 net_id: uint32_t,
276 dns_query: *mut u8,
277 dns_query_len: size_t,
278 response: *mut u8,
279 response_len: size_t,
280 timeout_ms: uint64_t,
281 ) -> ssize_t {
282 // SAFETY: The caller guarantees that `dns_query` is a valid pointer to a buffer of at least
283 // `dns_query_len` items.
284 let q = unsafe { slice::from_raw_parts_mut(dns_query, dns_query_len) };
285
286 let (resp_tx, resp_rx) = oneshot::channel();
287 let t = Duration::from_millis(timeout_ms);
288 if let Some(expired_time) = BootTime::now().checked_add(t) {
289 let cmd = Command::Query {
290 net_id,
291 base64_query: BASE64_URL_SAFE_NO_PAD.encode(q),
292 expired_time,
293 resp: resp_tx,
294 };
295
296 if let Err(e) = doh.lock().send_cmd(cmd) {
297 error!("Failed to send the query: {:?}", e);
298 return DOH_RESULT_CAN_NOT_SEND;
299 }
300 } else {
301 error!("Bad timeout parameter: {}", timeout_ms);
302 return DOH_RESULT_CAN_NOT_SEND;
303 }
304
305 if let Ok(rt) = Builder::new_current_thread().enable_all().build() {
306 let local = task::LocalSet::new();
307 match local.block_on(&rt, async { timeout(t, resp_rx).await }) {
308 Ok(v) => match v {
309 Ok(v) => match v {
310 Response::Success { answer } => {
311 if answer.len() > response_len || answer.len() > isize::MAX as usize {
312 return DOH_RESULT_INTERNAL_ERROR;
313 }
314 // SAFETY: The caller guarantees that response points to a valid buffer at
315 // least `response_len` long, and we just checked that `answer.len()` is no
316 // longer than `response_len`.
317 let response = unsafe { slice::from_raw_parts_mut(response, answer.len()) };
318 response.copy_from_slice(&answer);
319 answer.len() as ssize_t
320 }
321 rsp => {
322 error!("Non-successful response: {:?}", rsp);
323 DOH_RESULT_CAN_NOT_SEND
324 }
325 },
326 Err(e) => {
327 error!("no result {}", e);
328 DOH_RESULT_CAN_NOT_SEND
329 }
330 },
331 Err(e) => {
332 error!("timeout: {}", e);
333 DOH_RESULT_TIMEOUT
334 }
335 }
336 } else {
337 DOH_RESULT_CAN_NOT_SEND
338 }
339 }
340
341 /// Clears the DoH servers associated with the given |netid|.
342 /// # Safety
343 /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
344 /// and not yet deleted by `doh_dispatcher_delete()`.
345 #[no_mangle]
doh_net_delete(doh: &DohDispatcher, net_id: uint32_t)346 pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) {
347 if let Err(e) = doh.lock().send_cmd(Command::Clear { net_id }) {
348 error!("Failed to send the query: {:?}", e);
349 }
350 }
351
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_NET_ID: u32 = 50;
    const LOOPBACK_ADDR: &str = "127.0.0.1:443";
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    /// Validation callback expecting a successful validation of the test server.
    unsafe extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        // SAFETY: The caller guarantees that ip_addr and host are valid nul-terminated C strings.
        unsafe { assert_validation_info(net_id, ip_addr, host) }
    }

    /// Validation callback expecting a failed validation of the test server.
    unsafe extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        // SAFETY: The caller guarantees that ip_addr and host are valid nul-terminated C strings.
        unsafe { assert_validation_info(net_id, ip_addr, host) }
    }

    /// Borrows a nul-terminated C string as UTF-8, panicking if it is not valid UTF-8.
    ///
    /// # Safety
    ///
    /// `ptr` must be a valid nul-terminated C string that outlives the returned `&str`.
    unsafe fn borrow_c_str<'a>(ptr: *const c_char) -> &'a str {
        // SAFETY: Forwarded from this function's own safety contract.
        unsafe { std::ffi::CStr::from_ptr(ptr) }.to_str().unwrap()
    }

    /// Asserts that the validation callback arguments describe the test server.
    ///
    /// # Safety
    ///
    /// `ip_addr` and `host` must be valid nul-terminated C strings.
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);
        let expected_ip = LOOPBACK_ADDR.parse::<SocketAddr>().unwrap().ip().to_string();
        // SAFETY: The caller guarantees that `ip_addr` is a valid nul-terminated C string.
        assert_eq!(unsafe { borrow_c_str(ip_addr) }, expected_ip);
        // SAFETY: The caller guarantees that `host` is a valid nul-terminated C string.
        assert_eq!(unsafe { borrow_c_str(host) }, "");
    }

    /// Builds the server description used by the validation-callback test.
    fn test_server_info() -> ServerInfo {
        ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
            idle_timeout_ms: 0,
            use_session_resumption: true,
            enable_early_data: true,
            network_type: 2,
            private_dns_mode: 3,
        }
    }

    #[tokio::test]
    async fn wrap_validation_callback_converts_correctly() {
        let info = test_server_info();
        wrap_validation_callback(success_cb)(&info, true).await;
        wrap_validation_callback(fail_cb)(&info, false).await;
    }

    /// Socket-tagging callback that only checks it received a positive descriptor.
    extern "C" fn tag_socket_cb(raw_fd: RawFd) {
        assert!(raw_fd > 0)
    }

    #[tokio::test]
    async fn wrap_tag_socket_callback_converts_correctly() {
        let udp_socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
        wrap_tag_socket_callback(tag_socket_cb)(&udp_socket).await;
    }
}
433