1 use std::{ 2 collections::{hash_map::Entry, HashMap}, 3 future::{Future, IntoFuture}, 4 }; 5 6 use tokio::sync::oneshot; 7 8 use crate::core::address::AddressWithType; 9 10 use super::{ 11 le_manager::ErrorCode, CancelConnectFailure, ConnectionFailure, ConnectionManagerClient, 12 CreateConnectionFailure, LeConnection, 13 }; 14 15 #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] 16 pub enum ConnectionMode { 17 Background, 18 Direct, 19 } 20 21 #[derive(Debug)] 22 struct ConnectionAttemptData { 23 id: AttemptId, 24 conn_tx: Option<oneshot::Sender<Result<LeConnection, ErrorCode>>>, 25 } 26 27 #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] 28 pub struct ConnectionAttempt { 29 pub client: ConnectionManagerClient, 30 pub mode: ConnectionMode, 31 pub remote_address: AddressWithType, 32 } 33 34 #[derive(Debug, Copy, Clone, PartialEq, Eq)] 35 pub struct AttemptId(u64); 36 37 #[derive(Debug)] 38 pub struct ConnectionAttempts { 39 attempt_id: AttemptId, 40 attempts: HashMap<ConnectionAttempt, ConnectionAttemptData>, 41 } 42 43 #[derive(Debug)] 44 pub struct PendingConnectionAttempt<F> { 45 pub id: AttemptId, 46 f: F, 47 } 48 49 impl<F> IntoFuture for PendingConnectionAttempt<F> 50 where 51 F: Future<Output = Result<LeConnection, ConnectionFailure>>, 52 { 53 type Output = F::Output; 54 type IntoFuture = F; 55 into_future(self) -> Self::IntoFuture56 fn into_future(self) -> Self::IntoFuture { 57 self.f 58 } 59 } 60 61 impl ConnectionAttempts { 62 /// Constructor new() -> Self63 pub fn new() -> Self { 64 Self { attempt_id: AttemptId(0), attempts: HashMap::new() } 65 } 66 new_attempt_id(&mut self) -> AttemptId67 fn new_attempt_id(&mut self) -> AttemptId { 68 let AttemptId(id) = self.attempt_id; 69 self.attempt_id = AttemptId(id.wrapping_add(1)); 70 AttemptId(id) 71 } 72 73 /// Register a pending direct connection to the peer. Note that the peer MUST NOT be connected at this point. 74 /// Returns the AttemptId of this attempt, as well as a future resolving with the connection (once created) or an 75 /// error. 76 /// 77 /// Note that only one connection attempt from the same (client, address, mode) tuple can be pending at any time. 78 /// 79 /// # Cancellation Safety 80 /// If this future is cancelled, the attempt will NOT BE REMOVED! It must be cancelled explicitly. To avoid 81 /// cancelling the wrong future, the returned ID should be used. register_direct_connection( &mut self, client: ConnectionManagerClient, address: AddressWithType, ) -> Result< PendingConnectionAttempt<impl Future<Output = Result<LeConnection, ConnectionFailure>>>, CreateConnectionFailure, >82 pub fn register_direct_connection( 83 &mut self, 84 client: ConnectionManagerClient, 85 address: AddressWithType, 86 ) -> Result< 87 PendingConnectionAttempt<impl Future<Output = Result<LeConnection, ConnectionFailure>>>, 88 CreateConnectionFailure, 89 > { 90 let attempt = 91 ConnectionAttempt { client, mode: ConnectionMode::Direct, remote_address: address }; 92 93 let id = self.new_attempt_id(); 94 let Entry::Vacant(entry) = self.attempts.entry(attempt) else { 95 return Err(CreateConnectionFailure::ConnectionAlreadyPending); 96 }; 97 let (tx, rx) = oneshot::channel(); 98 entry.insert(ConnectionAttemptData { conn_tx: Some(tx), id }); 99 100 Ok(PendingConnectionAttempt { 101 id, 102 f: async move { 103 rx.await 104 .map_err(|_| ConnectionFailure::Cancelled)? 105 .map_err(ConnectionFailure::Error) 106 }, 107 }) 108 } 109 110 /// Register a pending background connection to the peer. Returns the AttemptId of this attempt. 111 /// 112 /// Note that only one connection attempt from the same (client, address, mode) tuple can be pending at any time. register_background_connection( &mut self, client: ConnectionManagerClient, address: AddressWithType, ) -> Result<AttemptId, CreateConnectionFailure>113 pub fn register_background_connection( 114 &mut self, 115 client: ConnectionManagerClient, 116 address: AddressWithType, 117 ) -> Result<AttemptId, CreateConnectionFailure> { 118 let attempt = 119 ConnectionAttempt { client, mode: ConnectionMode::Background, remote_address: address }; 120 121 let id = self.new_attempt_id(); 122 let Entry::Vacant(entry) = self.attempts.entry(attempt) else { 123 return Err(CreateConnectionFailure::ConnectionAlreadyPending); 124 }; 125 entry.insert(ConnectionAttemptData { conn_tx: None, id }); 126 127 Ok(id) 128 } 129 130 /// Cancel connection attempts with the specified mode from this client to the specified address. cancel_attempt( &mut self, client: ConnectionManagerClient, address: AddressWithType, mode: ConnectionMode, ) -> Result<(), CancelConnectFailure>131 pub fn cancel_attempt( 132 &mut self, 133 client: ConnectionManagerClient, 134 address: AddressWithType, 135 mode: ConnectionMode, 136 ) -> Result<(), CancelConnectFailure> { 137 let existing = 138 self.attempts.remove(&ConnectionAttempt { client, mode, remote_address: address }); 139 140 if existing.is_some() { 141 // note: dropping the ConnectionAttemptData is sufficient to close the channel and send a cancellation error 142 Ok(()) 143 } else { 144 Err(CancelConnectFailure::ConnectionNotPending) 145 } 146 } 147 148 /// Cancel the connection attempt with the given ID. cancel_attempt_with_id(&mut self, id: AttemptId)149 pub fn cancel_attempt_with_id(&mut self, id: AttemptId) { 150 self.attempts.retain(|_, attempt| attempt.id != id); 151 } 152 153 /// Cancel all connection attempts to this address remove_unconditionally(&mut self, address: AddressWithType)154 pub fn remove_unconditionally(&mut self, address: AddressWithType) { 155 self.attempts.retain(|attempt, _| attempt.remote_address != address); 156 } 157 158 /// Cancel all connection attempts from this client remove_client(&mut self, client: ConnectionManagerClient)159 pub fn remove_client(&mut self, client: ConnectionManagerClient) { 160 self.attempts.retain(|attempt, _| attempt.client != client); 161 } 162 163 /// List all active connection attempts. Note that we can have active background (but NOT) direct 164 /// connection attempts to connected devices, as we will resume the connection attempt when the 165 /// peer disconnects from us. active_attempts(&self) -> Vec<ConnectionAttempt>166 pub fn active_attempts(&self) -> Vec<ConnectionAttempt> { 167 self.attempts.keys().cloned().collect() 168 } 169 170 /// Handle a successful connection by notifying clients and resolving direct connect attempts process_connection( &mut self, address: AddressWithType, result: Result<LeConnection, ErrorCode>, )171 pub fn process_connection( 172 &mut self, 173 address: AddressWithType, 174 result: Result<LeConnection, ErrorCode>, 175 ) { 176 let interested_clients = self 177 .attempts 178 .keys() 179 .filter(|attempt| attempt.remote_address == address) 180 .copied() 181 .collect::<Vec<_>>(); 182 183 for attempt in interested_clients { 184 if attempt.mode == ConnectionMode::Direct { 185 // TODO(aryarahul): clean up these unwraps 186 let _ = self.attempts.remove(&attempt).unwrap().conn_tx.unwrap().send(result); 187 } else { 188 // TODO(aryarahul): inform background clients of the connection 189 } 190 } 191 } 192 } 193 194 #[cfg(test)] 195 mod test { 196 use crate::{ 197 core::address::AddressType, 198 utils::task::{block_on_locally, try_await}, 199 }; 200 201 use super::*; 202 203 const CLIENT_1: ConnectionManagerClient = ConnectionManagerClient::GattClient(1); 204 const CLIENT_2: ConnectionManagerClient = ConnectionManagerClient::GattClient(2); 205 206 const ADDRESS_1: AddressWithType = 207 AddressWithType { address: [1, 2, 3, 4, 5, 6], address_type: AddressType::Public }; 208 const ADDRESS_2: AddressWithType = 209 AddressWithType { address: [1, 2, 3, 4, 5, 6], address_type: AddressType::Random }; 210 211 const CONNECTION_1: LeConnection = LeConnection { remote_address: ADDRESS_1 }; 212 const CONNECTION_2: LeConnection = LeConnection { remote_address: ADDRESS_2 }; 213 214 #[test] test_direct_connection()215 fn test_direct_connection() { 216 block_on_locally(async { 217 // arrange 218 let mut attempts = ConnectionAttempts::new(); 219 220 // act: start a pending direct connection 221 let _ = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 222 223 // assert: this attempt is pending 224 assert_eq!(attempts.active_attempts().len(), 1); 225 assert_eq!(attempts.active_attempts()[0].client, CLIENT_1); 226 assert_eq!(attempts.active_attempts()[0].mode, ConnectionMode::Direct); 227 assert_eq!(attempts.active_attempts()[0].remote_address, ADDRESS_1); 228 }); 229 } 230 231 #[test] test_cancel_direct_connection()232 fn test_cancel_direct_connection() { 233 block_on_locally(async { 234 // arrange: one pending direct connection 235 let mut attempts = ConnectionAttempts::new(); 236 let pending_direct_connection = 237 attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 238 239 // act: cancel it 240 attempts.cancel_attempt(CLIENT_1, ADDRESS_1, ConnectionMode::Direct).unwrap(); 241 let resp = pending_direct_connection.await; 242 243 // assert: the original future resolved, and the attempt is cleared 244 assert_eq!(resp, Err(ConnectionFailure::Cancelled)); 245 assert!(attempts.active_attempts().is_empty()); 246 }); 247 } 248 249 #[test] test_multiple_direct_connections()250 fn test_multiple_direct_connections() { 251 block_on_locally(async { 252 // arrange 253 let mut attempts = ConnectionAttempts::new(); 254 255 // act: start two direct connections 256 attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 257 attempts.register_direct_connection(CLIENT_2, ADDRESS_1).unwrap(); 258 259 // assert: both attempts are pending 260 assert_eq!(attempts.active_attempts().len(), 2); 261 }); 262 } 263 264 #[test] test_two_direct_connection_cancel_one()265 fn test_two_direct_connection_cancel_one() { 266 block_on_locally(async { 267 // arrange: two pending direct connections 268 let mut attempts = ConnectionAttempts::new(); 269 attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 270 attempts.register_direct_connection(CLIENT_2, ADDRESS_1).unwrap(); 271 272 // act: cancel one 273 attempts.cancel_attempt(CLIENT_1, ADDRESS_1, ConnectionMode::Direct).unwrap(); 274 275 // assert: one attempt is still pending 276 assert_eq!(attempts.active_attempts().len(), 1); 277 assert_eq!(attempts.active_attempts()[0].client, CLIENT_2); 278 }); 279 } 280 281 #[test] test_drop_pending_connection_after_cancel_and_restart()282 fn test_drop_pending_connection_after_cancel_and_restart() { 283 // arrange 284 let mut attempts = ConnectionAttempts::new(); 285 286 // act: start one pending direct connection, cancel it, restart it, and then drop the first future 287 let pending_1 = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 288 attempts.cancel_attempt(CLIENT_1, ADDRESS_1, ConnectionMode::Direct).unwrap(); 289 let _pending_2 = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 290 drop(pending_1); 291 292 // assert: the restart is still pending 293 assert_eq!(attempts.active_attempts().len(), 1); 294 } 295 296 #[test] test_background_connection()297 fn test_background_connection() { 298 block_on_locally(async { 299 // arrange 300 let mut attempts = ConnectionAttempts::new(); 301 302 // act: start a pending background connection 303 attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap(); 304 305 // assert: this attempt is pending 306 assert_eq!(attempts.active_attempts().len(), 1); 307 assert_eq!(attempts.active_attempts()[0].client, CLIENT_1); 308 assert_eq!(attempts.active_attempts()[0].mode, ConnectionMode::Background); 309 assert_eq!(attempts.active_attempts()[0].remote_address, ADDRESS_1); 310 }); 311 } 312 313 #[test] test_reject_duplicate_direct_connection()314 fn test_reject_duplicate_direct_connection() { 315 block_on_locally(async { 316 // arrange 317 let mut attempts = ConnectionAttempts::new(); 318 319 // act: start two background connections with the same parameters 320 let _fut = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 321 let ret = attempts.register_direct_connection(CLIENT_1, ADDRESS_1); 322 323 // assert: this attempt is pending 324 assert!(matches!(ret, Err(CreateConnectionFailure::ConnectionAlreadyPending))); 325 }); 326 } 327 328 #[test] test_reject_duplicate_background_connection()329 fn test_reject_duplicate_background_connection() { 330 block_on_locally(async { 331 // arrange 332 let mut attempts = ConnectionAttempts::new(); 333 334 // act: start two background connections with the same parameters 335 attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap(); 336 let ret = attempts.register_background_connection(CLIENT_1, ADDRESS_1); 337 338 // assert: this attempt is pending 339 assert_eq!(ret, Err(CreateConnectionFailure::ConnectionAlreadyPending)); 340 }); 341 } 342 343 #[test] test_resolved_direct_connection()344 fn test_resolved_direct_connection() { 345 block_on_locally(async { 346 // arrange: one pending direct connection 347 let mut attempts = ConnectionAttempts::new(); 348 let pending_conn = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 349 350 // act: resolve with an incoming connection 351 attempts.process_connection(ADDRESS_1, Ok(CONNECTION_1)); 352 353 // assert: the attempt is resolved and is no longer active 354 assert_eq!(pending_conn.await.unwrap(), CONNECTION_1); 355 assert!(attempts.active_attempts().is_empty()); 356 }); 357 } 358 359 #[test] test_failed_direct_connection()360 fn test_failed_direct_connection() { 361 block_on_locally(async { 362 // arrange: one pending direct connection 363 let mut attempts = ConnectionAttempts::new(); 364 let pending_conn = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 365 366 // act: resolve with an incoming connection 367 attempts.process_connection(ADDRESS_1, Err(ErrorCode(1))); 368 369 // assert: the attempt is resolved and is no longer active 370 assert_eq!(pending_conn.await, Err(ConnectionFailure::Error(ErrorCode(1)))); 371 assert!(attempts.active_attempts().is_empty()); 372 }); 373 } 374 375 #[test] test_resolved_background_connection()376 fn test_resolved_background_connection() { 377 block_on_locally(async { 378 // arrange: one pending direct connection 379 let mut attempts = ConnectionAttempts::new(); 380 attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap(); 381 382 // act: resolve with an incoming connection 383 attempts.process_connection(ADDRESS_1, Ok(CONNECTION_1)); 384 385 // assert: the attempt is still active 386 assert_eq!(attempts.active_attempts().len(), 1); 387 }); 388 } 389 390 #[test] test_incoming_connection_while_another_is_pending()391 fn test_incoming_connection_while_another_is_pending() { 392 block_on_locally(async { 393 // arrange: one pending direct connection 394 let mut attempts = ConnectionAttempts::new(); 395 let pending_conn = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 396 397 // act: an incoming connection arrives to a different address 398 attempts.process_connection(ADDRESS_2, Ok(CONNECTION_2)); 399 400 // assert: the attempt is still pending 401 assert!(try_await(pending_conn).await.is_err()); 402 assert_eq!(attempts.active_attempts().len(), 1); 403 }); 404 } 405 406 #[test] test_incoming_connection_resolves_some_but_not_all()407 fn test_incoming_connection_resolves_some_but_not_all() { 408 block_on_locally(async { 409 // arrange: one pending direct connection and one background connection to each of two addresses 410 let mut attempts = ConnectionAttempts::new(); 411 let pending_conn_1 = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 412 let pending_conn_2 = attempts.register_direct_connection(CLIENT_1, ADDRESS_2).unwrap(); 413 attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap(); 414 attempts.register_background_connection(CLIENT_1, ADDRESS_2).unwrap(); 415 416 // act: an incoming connection arrives to the first address 417 attempts.process_connection(ADDRESS_1, Ok(CONNECTION_1)); 418 419 // assert: one direct attempt is completed, one is still pending 420 assert_eq!(pending_conn_1.await, Ok(CONNECTION_1)); 421 assert!(try_await(pending_conn_2).await.is_err()); 422 // three attempts remain (the unresolved direct, and both background attempts) 423 assert_eq!(attempts.active_attempts().len(), 3); 424 }); 425 } 426 427 #[test] test_remove_background_connection()428 fn test_remove_background_connection() { 429 block_on_locally(async { 430 // arrange: one pending background connection 431 let mut attempts = ConnectionAttempts::new(); 432 attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap(); 433 434 // act: remove it 435 attempts.cancel_attempt(CLIENT_1, ADDRESS_1, ConnectionMode::Background).unwrap(); 436 437 // assert: no pending attempts 438 assert!(attempts.active_attempts().is_empty()); 439 }); 440 } 441 442 #[test] test_cancel_nonexistent_connection()443 fn test_cancel_nonexistent_connection() { 444 block_on_locally(async { 445 // arrange 446 let mut attempts = ConnectionAttempts::new(); 447 448 // act: cancel a nonexistent direct connection 449 let resp = attempts.cancel_attempt(CLIENT_1, ADDRESS_1, ConnectionMode::Direct); 450 451 // assert: got an error 452 assert_eq!(resp, Err(CancelConnectFailure::ConnectionNotPending)); 453 }); 454 } 455 456 #[test] test_remove_unconditionally()457 fn test_remove_unconditionally() { 458 block_on_locally(async { 459 // arrange: one pending direct connection, and one background connection, to each address 460 let mut attempts = ConnectionAttempts::new(); 461 let pending_conn_1 = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 462 let pending_conn_2 = attempts.register_direct_connection(CLIENT_1, ADDRESS_2).unwrap(); 463 attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap(); 464 attempts.register_background_connection(CLIENT_1, ADDRESS_2).unwrap(); 465 466 // act: cancel all connections to the first address 467 attempts.remove_unconditionally(ADDRESS_1); 468 469 // assert: one direct attempt is completed, one is still pending 470 assert_eq!(pending_conn_1.await, Err(ConnectionFailure::Cancelled)); 471 assert!(try_await(pending_conn_2).await.is_err()); 472 // assert: two attempts remain, both to the other address 473 assert_eq!(attempts.active_attempts().len(), 2); 474 assert_eq!(attempts.active_attempts()[0].remote_address, ADDRESS_2); 475 assert_eq!(attempts.active_attempts()[1].remote_address, ADDRESS_2); 476 }); 477 } 478 479 #[test] test_remove_client()480 fn test_remove_client() { 481 block_on_locally(async { 482 // arrange: one pending direct connection, and one background connection, from each address 483 let mut attempts = ConnectionAttempts::new(); 484 let pending_conn_1 = attempts.register_direct_connection(CLIENT_1, ADDRESS_1).unwrap(); 485 let pending_conn_2 = attempts.register_direct_connection(CLIENT_2, ADDRESS_1).unwrap(); 486 attempts.register_background_connection(CLIENT_1, ADDRESS_1).unwrap(); 487 attempts.register_background_connection(CLIENT_2, ADDRESS_1).unwrap(); 488 489 // act: remove the first client 490 attempts.remove_client(CLIENT_1); 491 492 // assert: one direct attempt is completed, one is still pending 493 assert_eq!(pending_conn_1.await, Err(ConnectionFailure::Cancelled)); 494 assert!(try_await(pending_conn_2).await.is_err()); 495 // assert: two attempts remain, both from the second client 496 assert_eq!(attempts.active_attempts().len(), 2); 497 assert_eq!(attempts.active_attempts()[0].client, CLIENT_2); 498 assert_eq!(attempts.active_attempts()[1].client, CLIENT_2); 499 }); 500 } 501 } 502