1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 use crate::boot_time::{BootTime, Duration};
18 use anyhow::Result;
19 use log::error;
20 use tokio::runtime::{Builder, Runtime};
21 use tokio::sync::{mpsc, oneshot};
22 use tokio::task;
23 
24 pub use crate::network::{ServerInfo, SocketTagger, ValidationReporter};
25 
/// Capacity of the bounded command channel between a `Dispatcher` and its driver task;
/// senders block once this many commands are queued.
const MAX_BUFFERED_CMD_COUNT: usize = 400;
27 
28 mod driver;
29 use driver::Driver;
30 
/// Reason a query could not be answered.
#[derive(Debug, PartialEq, Eq)]
pub enum QueryError {
    /// Probing of the network's server failed.
    BrokenServer,
    /// The HTTP/3 connection died.
    ConnectionError,
    /// The network has not been probed yet.
    ServerNotReady,
    /// The server reset the HTTP/3 stream (the `u64` is presumably the reset error code).
    Reset(u64),
    /// A query targeted a network that does not exist.
    Unexpected,
}
45 
46 #[derive(Eq, PartialEq, Debug)]
47 pub enum Response {
48     Error { error: QueryError },
49     Success { answer: Vec<u8> },
50 }
51 
52 #[derive(Debug)]
53 pub enum Command {
54     Probe {
55         info: ServerInfo,
56         timeout: Duration,
57     },
58     Query {
59         net_id: u32,
60         base64_query: String,
61         expired_time: BootTime,
62         resp: oneshot::Sender<Response>,
63     },
64     Clear {
65         net_id: u32,
66     },
67     Exit,
68 }
69 
70 /// Context for a running DoH engine.
71 pub struct Dispatcher {
72     /// Used to submit cmds to the I/O task.
73     cmd_sender: mpsc::Sender<Command>,
74     join_handle: task::JoinHandle<Result<()>>,
75     runtime: Runtime,
76 }
77 
78 impl Dispatcher {
79     const DOH_THREADS: usize = 1;
80 
new(validation: ValidationReporter, tagger: SocketTagger) -> Result<Dispatcher>81     pub fn new(validation: ValidationReporter, tagger: SocketTagger) -> Result<Dispatcher> {
82         let (cmd_sender, cmd_receiver) = mpsc::channel::<Command>(MAX_BUFFERED_CMD_COUNT);
83         let runtime = Builder::new_multi_thread()
84             .worker_threads(Self::DOH_THREADS)
85             .enable_all()
86             .thread_name("doh-handler")
87             .build()?;
88         let join_handle = runtime.spawn(async {
89             let result = Driver::new(cmd_receiver, validation, tagger).drive().await;
90             if let Err(ref e) = result {
91                 error!("Dispatcher driver exited due to {:?}", e)
92             }
93             result
94         });
95         Ok(Dispatcher { cmd_sender, join_handle, runtime })
96     }
97 
send_cmd(&self, cmd: Command) -> Result<()>98     pub fn send_cmd(&self, cmd: Command) -> Result<()> {
99         self.cmd_sender.blocking_send(cmd)?;
100         Ok(())
101     }
102 
exit_handler(&mut self)103     pub fn exit_handler(&mut self) {
104         if self.cmd_sender.blocking_send(Command::Exit).is_err() {
105             return;
106         }
107         let _ = self.runtime.block_on(&mut self.join_handle);
108     }
109 }
110