//! This module handles "arbitration" of ATT packets, to determine whether they
//! should be handled by the primary (legacy) stack or by the Rust stack.
3 
4 use std::sync::{Arc, Mutex};
5 
6 use log::{error, trace, warn};
7 use std::sync::RwLock;
8 
9 use crate::{
10     do_in_rust_thread,
11     packets::{AttOpcode, OwnedAttView, OwnedPacket},
12 };
13 
14 use super::{
15     ffi::{InterceptAction, StoreCallbacksFromRust},
16     ids::{AdvertiserId, TransportIndex},
17     mtu::MtuEvent,
18     opcode_types::{classify_opcode, OperationType},
19     server::isolation_manager::IsolationManager,
20 };
21 
/// Process-wide slot holding the isolation manager.
///
/// Starts as `None`, is set to `Some` by [`initialize_arbiter`], and is reset
/// to `None` by [`clean_arbiter`]. Callers normally go through
/// [`with_arbiter`] / [`has_arbiter`] rather than touching it directly.
static ARBITER: RwLock<Option<Arc<Mutex<IsolationManager>>>> = RwLock::new(None);
23 
24 /// Initialize the Arbiter
initialize_arbiter() -> Arc<Mutex<IsolationManager>>25 pub fn initialize_arbiter() -> Arc<Mutex<IsolationManager>> {
26     let arbiter = Arc::new(Mutex::new(IsolationManager::new()));
27     let mut lock = ARBITER.write().unwrap();
28     assert!(lock.is_none(), "Rust stack should only start up once");
29     *lock = Some(arbiter.clone());
30 
31     StoreCallbacksFromRust(
32         on_le_connect,
33         on_le_disconnect,
34         intercept_packet,
35         |tcb_idx| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::OutgoingRequest),
36         |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingResponse(mtu)),
37         |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingRequest(mtu)),
38     );
39 
40     arbiter
41 }
42 
43 /// Clean the Arbiter
clean_arbiter()44 pub fn clean_arbiter() {
45     let mut lock = ARBITER.write().unwrap();
46     *lock = None
47 }
48 
49 /// Acquire the mutex holding the Arbiter and provide a mutable reference to the
50 /// supplied closure
with_arbiter<T>(f: impl FnOnce(&mut IsolationManager) -> T) -> T51 pub fn with_arbiter<T>(f: impl FnOnce(&mut IsolationManager) -> T) -> T {
52     f(ARBITER.read().unwrap().as_ref().expect("Rust stack is not started").lock().as_mut().unwrap())
53 }
54 
55 /// Check if the Arbiter is initialized.
has_arbiter() -> bool56 pub fn has_arbiter() -> bool {
57     ARBITER.read().unwrap().is_some()
58 }
59 
60 /// Test to see if a buffer contains a valid ATT packet with an opcode we
61 /// are interested in intercepting (those intended for servers that are isolated)
try_parse_att_server_packet( isolation_manager: &IsolationManager, tcb_idx: TransportIndex, packet: Box<[u8]>, ) -> Option<OwnedAttView>62 fn try_parse_att_server_packet(
63     isolation_manager: &IsolationManager,
64     tcb_idx: TransportIndex,
65     packet: Box<[u8]>,
66 ) -> Option<OwnedAttView> {
67     isolation_manager.get_server_id(tcb_idx)?;
68 
69     let att = OwnedAttView::try_parse(packet).ok()?;
70 
71     if att.view().get_opcode() == AttOpcode::EXCHANGE_MTU_REQUEST {
72         // special case: this server opcode is handled by legacy stack, and we snoop
73         // on its handling, since the MTU is shared between the client + server
74         return None;
75     }
76 
77     match classify_opcode(att.view().get_opcode()) {
78         OperationType::Command | OperationType::Request | OperationType::Confirmation => Some(att),
79         _ => None,
80     }
81 }
82 
on_le_connect(tcb_idx: u8, advertiser: u8)83 fn on_le_connect(tcb_idx: u8, advertiser: u8) {
84     let tcb_idx = TransportIndex(tcb_idx);
85     let advertiser = AdvertiserId(advertiser);
86     let is_isolated = with_arbiter(|arbiter| arbiter.is_advertiser_isolated(advertiser));
87     if is_isolated {
88         do_in_rust_thread(move |modules| {
89             if let Err(err) = modules.gatt_module.on_le_connect(tcb_idx, Some(advertiser)) {
90                 error!("{err:?}")
91             }
92         })
93     }
94 }
95 
on_le_disconnect(tcb_idx: u8)96 fn on_le_disconnect(tcb_idx: u8) {
97     // Events may be received after a FactoryReset
98     // is initiated for Bluetooth and the rust arbiter is taken
99     // down.
100     if !has_arbiter() {
101         warn!("arbiter is not yet initialized");
102         return;
103     }
104 
105     let tcb_idx = TransportIndex(tcb_idx);
106     let was_isolated = with_arbiter(|arbiter| arbiter.is_connection_isolated(tcb_idx));
107     if was_isolated {
108         do_in_rust_thread(move |modules| {
109             if let Err(err) = modules.gatt_module.on_le_disconnect(tcb_idx) {
110                 error!("{err:?}")
111             }
112         })
113     }
114 }
115 
intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction116 fn intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction {
117     // Events may be received after a FactoryReset
118     // is initiated for Bluetooth and the rust arbiter is taken
119     // down.
120     if !has_arbiter() {
121         warn!("arbiter is not yet initialized");
122         return InterceptAction::Drop;
123     }
124 
125     let tcb_idx = TransportIndex(tcb_idx);
126     if let Some(att) = with_arbiter(|arbiter| {
127         try_parse_att_server_packet(arbiter, tcb_idx, packet.into_boxed_slice())
128     }) {
129         do_in_rust_thread(move |modules| {
130             trace!("pushing packet to GATT");
131             if let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) {
132                 bearer.handle_packet(att.view())
133             } else {
134                 error!("Bearer for {tcb_idx:?} not found");
135             }
136         });
137         InterceptAction::Drop
138     } else {
139         InterceptAction::Forward
140     }
141 }
142 
on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent)143 fn on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent) {
144     if with_arbiter(|arbiter| arbiter.is_connection_isolated(tcb_idx)) {
145         do_in_rust_thread(move |modules| {
146             let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) else {
147                 error!("Bearer for {tcb_idx:?} not found");
148                 return;
149             };
150             if let Err(err) = bearer.handle_mtu_event(event) {
151                 error!("{err:?}")
152             }
153         });
154     }
155 }
156 
#[cfg(test)]
mod test {
    use super::*;

    use crate::{
        gatt::ids::{AttHandle, ServerId},
        packets::{
            AttBuilder, AttExchangeMtuRequestBuilder, AttOpcode, AttReadRequestBuilder,
            Serializable,
        },
    };

    const TCB_IDX: TransportIndex = TransportIndex(1);
    const ADVERTISER_ID: AdvertiserId = AdvertiserId(3);
    const SERVER_ID: ServerId = ServerId(4);

    /// Builds a manager where `server_id` is bound to `ADVERTISER_ID` and
    /// `tcb_idx` has connected through that advertiser.
    fn create_manager_with_isolated_connection(
        tcb_idx: TransportIndex,
        server_id: ServerId,
    ) -> IsolationManager {
        let mut manager = IsolationManager::new();
        manager.associate_server_with_advertiser(server_id, ADVERTISER_ID);
        manager.on_le_connect(tcb_idx, Some(ADVERTISER_ID));
        manager
    }

    /// A read-request body for attribute handle 1, sent under `opcode`.
    fn read_request_with_opcode(opcode: AttOpcode) -> AttBuilder {
        AttBuilder {
            opcode,
            _child_: AttReadRequestBuilder { attribute_handle: AttHandle(1).into() }.into(),
        }
    }

    /// Serializes `packet` and reports whether it would be intercepted on `TCB_IDX`.
    fn is_intercepted(manager: &IsolationManager, packet: AttBuilder) -> bool {
        let bytes = packet.to_vec().unwrap();
        try_parse_att_server_packet(manager, TCB_IDX, bytes.into()).is_some()
    }

    #[test]
    fn test_packet_capture_when_isolated() {
        let manager = create_manager_with_isolated_connection(TCB_IDX, SERVER_ID);

        let intercepted =
            is_intercepted(&manager, read_request_with_opcode(AttOpcode::READ_REQUEST));

        assert!(intercepted);
    }

    #[test]
    fn test_packet_bypass_when_isolated() {
        let manager = create_manager_with_isolated_connection(TCB_IDX, SERVER_ID);

        let intercepted =
            is_intercepted(&manager, read_request_with_opcode(AttOpcode::ERROR_RESPONSE));

        assert!(!intercepted);
    }

    #[test]
    fn test_mtu_bypass() {
        let manager = create_manager_with_isolated_connection(TCB_IDX, SERVER_ID);
        let mtu_request = AttBuilder {
            opcode: AttOpcode::EXCHANGE_MTU_REQUEST,
            _child_: AttExchangeMtuRequestBuilder { mtu: 64 }.into(),
        };

        assert!(!is_intercepted(&manager, mtu_request));
    }

    #[test]
    fn test_packet_bypass_when_not_isolated() {
        let manager = IsolationManager::new();

        let intercepted =
            is_intercepted(&manager, read_request_with_opcode(AttOpcode::READ_REQUEST));

        assert!(!intercepted);
    }
}
251