1 //! Provides utilities for managing callbacks. 2 3 use std::collections::HashMap; 4 use tokio::sync::mpsc::Sender; 5 6 use crate::{Message, RPCProxy}; 7 8 /// Utility for managing callbacks conveniently. 9 pub struct Callbacks<T: Send + ?Sized> { 10 callbacks: HashMap<u32, Box<T>>, 11 object_id_to_cbid: HashMap<String, u32>, 12 tx: Sender<Message>, 13 disconnected_message: fn(u32) -> Message, 14 } 15 16 impl<T: RPCProxy + Send + ?Sized> Callbacks<T> { 17 /// Creates new Callbacks. 18 /// 19 /// Parameters: 20 /// `tx`: Sender to use when notifying callback disconnect events. 21 /// `disconnected_message`: Constructor of the message to be sent on callback disconnection. new(tx: Sender<Message>, disconnected_message: fn(u32) -> Message) -> Self22 pub fn new(tx: Sender<Message>, disconnected_message: fn(u32) -> Message) -> Self { 23 Self { 24 callbacks: HashMap::new(), 25 object_id_to_cbid: HashMap::new(), 26 tx, 27 disconnected_message, 28 } 29 } 30 31 /// Stores a new callback and monitors for callback disconnect. If the callback object id 32 /// already exists, return the callback ID previously added. 33 /// 34 /// When the callback disconnects, a message is sent. This message should be handled and then 35 /// the `remove_callback` function can be used. 36 /// 37 /// Returns the id of the callback. add_callback(&mut self, mut callback: Box<T>) -> u3238 pub fn add_callback(&mut self, mut callback: Box<T>) -> u32 { 39 if let Some(cbid) = self.object_id_to_cbid.get(&callback.get_object_id()) { 40 return *cbid; 41 } 42 43 let tx = self.tx.clone(); 44 let disconnected_message = self.disconnected_message; 45 let id = callback.register_disconnect(Box::new(move |cb_id| { 46 let tx = tx.clone(); 47 tokio::spawn(async move { 48 let _result = tx.send(disconnected_message(cb_id)).await; 49 }); 50 })); 51 52 self.object_id_to_cbid.insert(callback.get_object_id(), id); 53 self.callbacks.insert(id, callback); 54 id 55 } 56 57 /// Removes the callback given the id. 58 /// 59 /// When a callback is removed, disconnect monitoring is stopped and the proxy object is 60 /// removed. 61 /// 62 /// Returns true if callback is removed, false if there is no such id. remove_callback(&mut self, id: u32) -> bool63 pub fn remove_callback(&mut self, id: u32) -> bool { 64 match self.callbacks.get_mut(&id) { 65 Some(callback) => { 66 // Stop watching for disconnect. 67 callback.unregister(id); 68 // Remove the proxy object. 69 self.object_id_to_cbid.remove(&callback.get_object_id()); 70 self.callbacks.remove(&id); 71 true 72 } 73 None => false, 74 } 75 } 76 77 /// Returns the callback object based on the given id. get_by_id(&self, id: u32) -> Option<&Box<T>>78 pub fn get_by_id(&self, id: u32) -> Option<&Box<T>> { 79 self.callbacks.get(&id) 80 } 81 82 /// Returns the mut callback object based on the given id. get_by_id_mut(&mut self, id: u32) -> Option<&mut Box<T>>83 pub fn get_by_id_mut(&mut self, id: u32) -> Option<&mut Box<T>> { 84 self.callbacks.get_mut(&id) 85 } 86 87 /// Applies the given function on all active callbacks. for_all_callbacks<F: Fn(&mut Box<T>)>(&mut self, f: F)88 pub fn for_all_callbacks<F: Fn(&mut Box<T>)>(&mut self, f: F) { 89 for (_, ref mut callback) in self.callbacks.iter_mut() { 90 f(callback); 91 } 92 } 93 } 94 95 #[cfg(test)] 96 mod tests { 97 use std::sync::atomic::{AtomicU32, Ordering}; 98 99 static CBID: AtomicU32 = AtomicU32::new(0); 100 101 struct TestCallback { 102 id: String, 103 } 104 105 impl TestCallback { new(id: String) -> TestCallback106 fn new(id: String) -> TestCallback { 107 TestCallback { id } 108 } 109 } 110 111 impl RPCProxy for TestCallback { get_object_id(&self) -> String112 fn get_object_id(&self) -> String { 113 self.id.clone() 114 } register_disconnect(&mut self, _f: Box<dyn Fn(u32) + Send>) -> u32115 fn register_disconnect(&mut self, _f: Box<dyn Fn(u32) + Send>) -> u32 { 116 CBID.fetch_add(1, Ordering::SeqCst) 117 } 118 } 119 120 use super::*; 121 122 #[test] test_add_and_remove()123 fn test_add_and_remove() { 124 let (tx, _rx) = crate::Stack::create_channel(); 125 let mut callbacks = Callbacks::new(tx.clone(), Message::AdapterCallbackDisconnected); 126 127 let cb_string = String::from("Test Callback"); 128 129 // Test add 130 let cbid = callbacks.add_callback(Box::new(TestCallback::new(cb_string.clone()))); 131 let found = callbacks.get_by_id(cbid); 132 assert!(found.is_some()); 133 assert_eq!( 134 cb_string, 135 match found { 136 Some(c) => c.get_object_id(), 137 None => String::new(), 138 } 139 ); 140 141 // Attempting to add another callback with same object id should return the same cbid 142 let cbid1 = callbacks.add_callback(Box::new(TestCallback::new(cb_string.clone()))); 143 assert_eq!(cbid, cbid1); 144 145 // Test remove 146 let success = callbacks.remove_callback(cbid); 147 assert!(success); 148 let found = callbacks.get_by_id(cbid); 149 assert!(found.is_none()); 150 151 // Attempting to add another callback with same object id should now return a new cbid 152 let cbid2 = callbacks.add_callback(Box::new(TestCallback::new(cb_string.clone()))); 153 assert_ne!(cbid, cbid2); 154 } 155 } 156