1 use std::{
2     collections::{hash_map::Entry, HashMap},
3     future::{Future, IntoFuture},
4 };
5 
6 use tokio::sync::oneshot;
7 
8 use crate::core::address::AddressWithType;
9 
10 use super::{
11     le_manager::ErrorCode, CancelConnectFailure, ConnectionFailure, ConnectionManagerClient,
12     CreateConnectionFailure, LeConnection,
13 };
14 
/// How a connection attempt is carried out.
///
/// Within this module, a `Direct` attempt is resolved and removed by the next connection result
/// for its peer (see `ConnectionAttempts::process_connection`), while a `Background` attempt
/// stays registered afterwards.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum ConnectionMode {
    /// Remains registered when a connection result for the peer is processed.
    Background,
    /// Resolved (through a oneshot channel) and removed by the first connection result for the peer.
    Direct,
}
20 
/// Bookkeeping stored for each registered [`ConnectionAttempt`].
#[derive(Debug)]
struct ConnectionAttemptData {
    /// ID handed back to the registering caller; matched by `cancel_attempt_with_id`.
    id: AttemptId,
    /// Resolves the future of a direct attempt: `Some` for direct attempts, `None` for background
    /// ones. Dropping the sender (e.g. by removing this entry) makes the waiting future resolve
    /// with `ConnectionFailure::Cancelled`.
    conn_tx: Option<oneshot::Sender<Result<LeConnection, ErrorCode>>>,
}
26 
/// Key identifying a connection attempt. At most one attempt per (client, mode, address) tuple
/// can be registered in `ConnectionAttempts` at a time.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct ConnectionAttempt {
    /// The client that requested the connection.
    pub client: ConnectionManagerClient,
    /// Whether this is a direct or background attempt.
    pub mode: ConnectionMode,
    /// The peer being connected to.
    pub remote_address: AddressWithType,
}
33 
/// Opaque identifier of a registered connection attempt. IDs come from a wrapping `u64` counter
/// in `ConnectionAttempts`, so they are unique until that counter wraps around.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct AttemptId(u64);
36 
/// Registry of pending connection attempts, keyed by (client, mode, address).
#[derive(Debug)]
pub struct ConnectionAttempts {
    /// Counter from which the next attempt ID is taken.
    attempt_id: AttemptId,
    /// All currently registered attempts and their bookkeeping data.
    attempts: HashMap<ConnectionAttempt, ConnectionAttemptData>,
}
42 
/// A registered direct connection attempt: its ID plus a future resolving with the connection
/// result. Implements `IntoFuture`, so it can be `.await`ed directly.
#[derive(Debug)]
pub struct PendingConnectionAttempt<F> {
    /// ID of the registered attempt, usable with `ConnectionAttempts::cancel_attempt_with_id`.
    pub id: AttemptId,
    /// Future that resolves once the attempt completes, fails, or is cancelled.
    f: F,
}
48 
49 impl<F> IntoFuture for PendingConnectionAttempt<F>
50 where
51     F: Future<Output = Result<LeConnection, ConnectionFailure>>,
52 {
53     type Output = F::Output;
54     type IntoFuture = F;
55 
into_future(self) -> Self::IntoFuture56     fn into_future(self) -> Self::IntoFuture {
57         self.f
58     }
59 }
60 
61 impl ConnectionAttempts {
62     /// Constructor
new() -> Self63     pub fn new() -> Self {
64         Self { attempt_id: AttemptId(0), attempts: HashMap::new() }
65     }
66 
new_attempt_id(&mut self) -> AttemptId67     fn new_attempt_id(&mut self) -> AttemptId {
68         let AttemptId(id) = self.attempt_id;
69         self.attempt_id = AttemptId(id.wrapping_add(1));
70         AttemptId(id)
71     }
72 
73     /// Register a pending direct connection to the peer. Note that the peer MUST NOT be connected at this point.
74     /// Returns the AttemptId of this attempt, as well as a future resolving with the connection (once created) or an
75     /// error.
76     ///
77     /// Note that only one connection attempt from the same (client, address, mode) tuple can be pending at any time.
78     ///
79     /// # Cancellation Safety
80     /// If this future is cancelled, the attempt will NOT BE REMOVED! It must be cancelled explicitly. To avoid
81     /// cancelling the wrong future, the returned ID should be used.
register_direct_connection( &mut self, client: ConnectionManagerClient, address: AddressWithType, ) -> Result< PendingConnectionAttempt<impl Future<Output = Result<LeConnection, ConnectionFailure>>>, CreateConnectionFailure, >82     pub fn register_direct_connection(
83         &mut self,
84         client: ConnectionManagerClient,
85         address: AddressWithType,
86     ) -> Result<
87         PendingConnectionAttempt<impl Future<Output = Result<LeConnection, ConnectionFailure>>>,
88         CreateConnectionFailure,
89     > {
90         let attempt =
91             ConnectionAttempt { client, mode: ConnectionMode::Direct, remote_address: address };
92 
93         let id = self.new_attempt_id();
94         let Entry::Vacant(entry) = self.attempts.entry(attempt) else {
95             return Err(CreateConnectionFailure::ConnectionAlreadyPending);
96         };
97         let (tx, rx) = oneshot::channel();
98         entry.insert(ConnectionAttemptData { conn_tx: Some(tx), id });
99 
100         Ok(PendingConnectionAttempt {
101             id,
102             f: async move {
103                 rx.await
104                     .map_err(|_| ConnectionFailure::Cancelled)?
105                     .map_err(ConnectionFailure::Error)
106             },
107         })
108     }
109 
110     /// Register a pending background connection to the peer. Returns the AttemptId of this attempt.
111     ///
112     /// Note that only one connection attempt from the same (client, address, mode) tuple can be pending at any time.
register_background_connection( &mut self, client: ConnectionManagerClient, address: AddressWithType, ) -> Result<AttemptId, CreateConnectionFailure>113     pub fn register_background_connection(
114         &mut self,
115         client: ConnectionManagerClient,
116         address: AddressWithType,
117     ) -> Result<AttemptId, CreateConnectionFailure> {
118         let attempt =
119             ConnectionAttempt { client, mode: ConnectionMode::Background, remote_address: address };
120 
121         let id = self.new_attempt_id();
122         let Entry::Vacant(entry) = self.attempts.entry(attempt) else {
123             return Err(CreateConnectionFailure::ConnectionAlreadyPending);
124         };
125         entry.insert(ConnectionAttemptData { conn_tx: None, id });
126 
127         Ok(id)
128     }
129 
130     /// Cancel connection attempts with the specified mode from this client to the specified address.
cancel_attempt( &mut self, client: ConnectionManagerClient, address: AddressWithType, mode: ConnectionMode, ) -> Result<(), CancelConnectFailure>131     pub fn cancel_attempt(
132         &mut self,
133         client: ConnectionManagerClient,
134         address: AddressWithType,
135         mode: ConnectionMode,
136     ) -> Result<(), CancelConnectFailure> {
137         let existing =
138             self.attempts.remove(&ConnectionAttempt { client, mode, remote_address: address });
139 
140         if existing.is_some() {
141             // note: dropping the ConnectionAttemptData is sufficient to close the channel and send a cancellation error
142             Ok(())
143         } else {
144             Err(CancelConnectFailure::ConnectionNotPending)
145         }
146     }
147 
148     /// Cancel the connection attempt with the given ID.
cancel_attempt_with_id(&mut self, id: AttemptId)149     pub fn cancel_attempt_with_id(&mut self, id: AttemptId) {
150         self.attempts.retain(|_, attempt| attempt.id != id);
151     }
152 
153     /// Cancel all connection attempts to this address
remove_unconditionally(&mut self, address: AddressWithType)154     pub fn remove_unconditionally(&mut self, address: AddressWithType) {
155         self.attempts.retain(|attempt, _| attempt.remote_address != address);
156     }
157 
158     /// Cancel all connection attempts from this client
remove_client(&mut self, client: ConnectionManagerClient)159     pub fn remove_client(&mut self, client: ConnectionManagerClient) {
160         self.attempts.retain(|attempt, _| attempt.client != client);
161     }
162 
163     /// List all active connection attempts. Note that we can have active background (but NOT) direct
164     /// connection attempts to connected devices, as we will resume the connection attempt when the
165     /// peer disconnects from us.
active_attempts(&self) -> Vec<ConnectionAttempt>166     pub fn active_attempts(&self) -> Vec<ConnectionAttempt> {
167         self.attempts.keys().cloned().collect()
168     }
169 
170     /// Handle a successful connection by notifying clients and resolving direct connect attempts
process_connection( &mut self, address: AddressWithType, result: Result<LeConnection, ErrorCode>, )171     pub fn process_connection(
172         &mut self,
173         address: AddressWithType,
174         result: Result<LeConnection, ErrorCode>,
175     ) {
176         let interested_clients = self
177             .attempts
178             .keys()
179             .filter(|attempt| attempt.remote_address == address)
180             .copied()
181             .collect::<Vec<_>>();
182 
183         for attempt in interested_clients {
184             if attempt.mode == ConnectionMode::Direct {
185                 // TODO(aryarahul): clean up these unwraps
186                 let _ = self.attempts.remove(&attempt).unwrap().conn_tx.unwrap().send(result);
187             } else {
188                 // TODO(aryarahul): inform background clients of the connection
189             }
190         }
191     }
192 }
193 
#[cfg(test)]
mod test {
    // Unit tests for ConnectionAttempts: registration, cancellation, and resolution of attempts.
    use crate::{
        core::address::AddressType,
        utils::task::{block_on_locally, try_await},
    };

    use super::*;

    const CLIENT_1: ConnectionManagerClient = ConnectionManagerClient::GattClient(1);
    const CLIENT_2: ConnectionManagerClient = ConnectionManagerClient::GattClient(2);

    // Same address bytes but different address types, so these are distinct peers.
    const ADDRESS_1: AddressWithType =
        AddressWithType { address: [1, 2, 3, 4, 5, 6], address_type: AddressType::Public };
    const ADDRESS_2: AddressWithType =
        AddressWithType { address: [1, 2, 3, 4, 5, 6], address_type: AddressType::Random };

    const CONNECTION_1: LeConnection = LeConnection { remote_address: ADDRESS_1 };
    const CONNECTION_2: LeConnection = LeConnection { remote_address: ADDRESS_2 };

    #[test]
    fn test_direct_connection() {
        block_on_locally(async {
            // arrange
            let mut attempts = ConnectionAttempts::new();

            // act: start a pending direct connection
            let _ = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();

            // assert: this attempt is pending
            assert_eq!(attempts.active_attempts().len(), 1);
            assert_eq!(attempts.active_attempts()[0].client, CLIENT_1);
            assert_eq!(attempts.active_attempts()[0].mode, ConnectionMode::Direct);
            assert_eq!(attempts.active_attempts()[0].remote_address, ADDRESS_1);
        });
    }

    #[test]
    fn test_cancel_direct_connection() {
        block_on_locally(async {
            // arrange: one pending direct connection
            let mut attempts = ConnectionAttempts::new();
            let pending_direct_connection =
                attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();

            // act: cancel it
            attempts.cancel_attempt(CLIENT_1, ADDRESS_1, ConnectionMode::Direct).unwrap();
            let resp = pending_direct_connection.await;

            // assert: the original future resolved, and the attempt is cleared
            assert_eq!(resp, Err(ConnectionFailure::Cancelled));
            assert!(attempts.active_attempts().is_empty());
        });
    }

    #[test]
    fn test_multiple_direct_connections() {
        block_on_locally(async {
            // arrange
            let mut attempts = ConnectionAttempts::new();

            // act: start two direct connections
            attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();
            attempts.register_direct_connection(CLIENT_2, ADDRESS_1).unwrap();

            // assert: both attempts are pending
            assert_eq!(attempts.active_attempts().len(), 2);
        });
    }

    #[test]
    fn test_two_direct_connection_cancel_one() {
        block_on_locally(async {
            // arrange: two pending direct connections
            let mut attempts = ConnectionAttempts::new();
            attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();
            attempts.register_direct_connection(CLIENT_2, ADDRESS_1).unwrap();

            // act: cancel one
            attempts.cancel_attempt(CLIENT_1, ADDRESS_1, ConnectionMode::Direct).unwrap();

            // assert: one attempt is still pending
            assert_eq!(attempts.active_attempts().len(), 1);
            assert_eq!(attempts.active_attempts()[0].client, CLIENT_2);
        });
    }

    #[test]
    fn test_drop_pending_connection_after_cancel_and_restart() {
        // arrange
        let mut attempts = ConnectionAttempts::new();

        // act: start one pending direct connection, cancel it, restart it, and then drop the first future
        let pending_1 = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();
        attempts.cancel_attempt(CLIENT_1, ADDRESS_1, ConnectionMode::Direct).unwrap();
        let _pending_2 = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();
        drop(pending_1);

        // assert: the restart is still pending (dropping the stale future does not remove it)
        assert_eq!(attempts.active_attempts().len(), 1);
    }

    #[test]
    fn test_background_connection() {
        block_on_locally(async {
            // arrange
            let mut attempts = ConnectionAttempts::new();

            // act: start a pending background connection
            attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap();

            // assert: this attempt is pending
            assert_eq!(attempts.active_attempts().len(), 1);
            assert_eq!(attempts.active_attempts()[0].client, CLIENT_1);
            assert_eq!(attempts.active_attempts()[0].mode, ConnectionMode::Background);
            assert_eq!(attempts.active_attempts()[0].remote_address, ADDRESS_1);
        });
    }

    #[test]
    fn test_reject_duplicate_direct_connection() {
        block_on_locally(async {
            // arrange
            let mut attempts = ConnectionAttempts::new();

            // act: start two direct connections with the same parameters
            let _fut = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();
            let ret = attempts.register_direct_connection(CLIENT_1, ADDRESS_1);

            // assert: the second attempt is rejected as a duplicate
            assert!(matches!(ret, Err(CreateConnectionFailure::ConnectionAlreadyPending)));
        });
    }

    #[test]
    fn test_reject_duplicate_background_connection() {
        block_on_locally(async {
            // arrange
            let mut attempts = ConnectionAttempts::new();

            // act: start two background connections with the same parameters
            attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap();
            let ret = attempts.register_background_connection(CLIENT_1, ADDRESS_1);

            // assert: the second attempt is rejected as a duplicate
            assert_eq!(ret, Err(CreateConnectionFailure::ConnectionAlreadyPending));
        });
    }

    #[test]
    fn test_resolved_direct_connection() {
        block_on_locally(async {
            // arrange: one pending direct connection
            let mut attempts = ConnectionAttempts::new();
            let pending_conn = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();

            // act: resolve with an incoming connection
            attempts.process_connection(ADDRESS_1, Ok(CONNECTION_1));

            // assert: the attempt is resolved and is no longer active
            assert_eq!(pending_conn.await.unwrap(), CONNECTION_1);
            assert!(attempts.active_attempts().is_empty());
        });
    }

    #[test]
    fn test_failed_direct_connection() {
        block_on_locally(async {
            // arrange: one pending direct connection
            let mut attempts = ConnectionAttempts::new();
            let pending_conn = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();

            // act: resolve with a failed connection
            attempts.process_connection(ADDRESS_1, Err(ErrorCode(1)));

            // assert: the attempt is resolved with the error and is no longer active
            assert_eq!(pending_conn.await, Err(ConnectionFailure::Error(ErrorCode(1))));
            assert!(attempts.active_attempts().is_empty());
        });
    }

    #[test]
    fn test_resolved_background_connection() {
        block_on_locally(async {
            // arrange: one pending background connection
            let mut attempts = ConnectionAttempts::new();
            attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap();

            // act: resolve with an incoming connection
            attempts.process_connection(ADDRESS_1, Ok(CONNECTION_1));

            // assert: the attempt is still active
            assert_eq!(attempts.active_attempts().len(), 1);
        });
    }

    #[test]
    fn test_incoming_connection_while_another_is_pending() {
        block_on_locally(async {
            // arrange: one pending direct connection
            let mut attempts = ConnectionAttempts::new();
            let pending_conn = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();

            // act: an incoming connection arrives to a different address
            attempts.process_connection(ADDRESS_2, Ok(CONNECTION_2));

            // assert: the attempt is still pending
            assert!(try_await(pending_conn).await.is_err());
            assert_eq!(attempts.active_attempts().len(), 1);
        });
    }

    #[test]
    fn test_incoming_connection_resolves_some_but_not_all() {
        block_on_locally(async {
            // arrange: one pending direct connection and one background connection to each of two addresses
            let mut attempts = ConnectionAttempts::new();
            let pending_conn_1 = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();
            let pending_conn_2 = attempts.register_direct_connection(CLIENT_1, ADDRESS_2).unwrap();
            attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap();
            attempts.register_background_connection(CLIENT_1, ADDRESS_2).unwrap();

            // act: an incoming connection arrives to the first address
            attempts.process_connection(ADDRESS_1, Ok(CONNECTION_1));

            // assert: one direct attempt is completed, one is still pending
            assert_eq!(pending_conn_1.await, Ok(CONNECTION_1));
            assert!(try_await(pending_conn_2).await.is_err());
            // three attempts remain (the unresolved direct, and both background attempts)
            assert_eq!(attempts.active_attempts().len(), 3);
        });
    }

    #[test]
    fn test_remove_background_connection() {
        block_on_locally(async {
            // arrange: one pending background connection
            let mut attempts = ConnectionAttempts::new();
            attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap();

            // act: remove it
            attempts.cancel_attempt(CLIENT_1, ADDRESS_1, ConnectionMode::Background).unwrap();

            // assert: no pending attempts
            assert!(attempts.active_attempts().is_empty());
        });
    }

    #[test]
    fn test_cancel_nonexistent_connection() {
        block_on_locally(async {
            // arrange
            let mut attempts = ConnectionAttempts::new();

            // act: cancel a nonexistent direct connection
            let resp = attempts.cancel_attempt(CLIENT_1, ADDRESS_1, ConnectionMode::Direct);

            // assert: got an error
            assert_eq!(resp, Err(CancelConnectFailure::ConnectionNotPending));
        });
    }

    #[test]
    fn test_remove_unconditionally() {
        block_on_locally(async {
            // arrange: one pending direct connection, and one background connection, to each address
            let mut attempts = ConnectionAttempts::new();
            let pending_conn_1 = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();
            let pending_conn_2 = attempts.register_direct_connection(CLIENT_1, ADDRESS_2).unwrap();
            attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap();
            attempts.register_background_connection(CLIENT_1, ADDRESS_2).unwrap();

            // act: cancel all connections to the first address
            attempts.remove_unconditionally(ADDRESS_1);

            // assert: the direct attempt to the first address is cancelled, the other is still pending
            assert_eq!(pending_conn_1.await, Err(ConnectionFailure::Cancelled));
            assert!(try_await(pending_conn_2).await.is_err());
            // assert: two attempts remain, both to the other address
            assert_eq!(attempts.active_attempts().len(), 2);
            assert_eq!(attempts.active_attempts()[0].remote_address, ADDRESS_2);
            assert_eq!(attempts.active_attempts()[1].remote_address, ADDRESS_2);
        });
    }

    #[test]
    fn test_remove_client() {
        block_on_locally(async {
            // arrange: one pending direct connection, and one background connection, from each client
            let mut attempts = ConnectionAttempts::new();
            let pending_conn_1 = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap();
            let pending_conn_2 = attempts.register_direct_connection(CLIENT_2, ADDRESS_1).unwrap();
            attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap();
            attempts.register_background_connection(CLIENT_2, ADDRESS_1).unwrap();

            // act: remove the first client
            attempts.remove_client(CLIENT_1);

            // assert: the first client's direct attempt is cancelled, the second client's is still pending
            assert_eq!(pending_conn_1.await, Err(ConnectionFailure::Cancelled));
            assert!(try_await(pending_conn_2).await.is_err());
            // assert: two attempts remain, both from the second client
            assert_eq!(attempts.active_attempts().len(), 2);
            assert_eq!(attempts.active_attempts()[0].client, CLIENT_2);
            assert_eq!(attempts.active_attempts()[1].client, CLIENT_2);
        });
    }
}
502