1 //! Provides utilities for managing callbacks.
2 
3 use std::collections::HashMap;
4 use tokio::sync::mpsc::Sender;
5 
6 use crate::{Message, RPCProxy};
7 
8 /// Utility for managing callbacks conveniently.
9 pub struct Callbacks<T: Send + ?Sized> {
10     callbacks: HashMap<u32, Box<T>>,
11     object_id_to_cbid: HashMap<String, u32>,
12     tx: Sender<Message>,
13     disconnected_message: fn(u32) -> Message,
14 }
15 
16 impl<T: RPCProxy + Send + ?Sized> Callbacks<T> {
17     /// Creates new Callbacks.
18     ///
19     /// Parameters:
20     /// `tx`: Sender to use when notifying callback disconnect events.
21     /// `disconnected_message`: Constructor of the message to be sent on callback disconnection.
new(tx: Sender<Message>, disconnected_message: fn(u32) -> Message) -> Self22     pub fn new(tx: Sender<Message>, disconnected_message: fn(u32) -> Message) -> Self {
23         Self {
24             callbacks: HashMap::new(),
25             object_id_to_cbid: HashMap::new(),
26             tx,
27             disconnected_message,
28         }
29     }
30 
31     /// Stores a new callback and monitors for callback disconnect. If the callback object id
32     /// already exists, return the callback ID previously added.
33     ///
34     /// When the callback disconnects, a message is sent. This message should be handled and then
35     /// the `remove_callback` function can be used.
36     ///
37     /// Returns the id of the callback.
add_callback(&mut self, mut callback: Box<T>) -> u3238     pub fn add_callback(&mut self, mut callback: Box<T>) -> u32 {
39         if let Some(cbid) = self.object_id_to_cbid.get(&callback.get_object_id()) {
40             return *cbid;
41         }
42 
43         let tx = self.tx.clone();
44         let disconnected_message = self.disconnected_message;
45         let id = callback.register_disconnect(Box::new(move |cb_id| {
46             let tx = tx.clone();
47             tokio::spawn(async move {
48                 let _result = tx.send(disconnected_message(cb_id)).await;
49             });
50         }));
51 
52         self.object_id_to_cbid.insert(callback.get_object_id(), id);
53         self.callbacks.insert(id, callback);
54         id
55     }
56 
57     /// Removes the callback given the id.
58     ///
59     /// When a callback is removed, disconnect monitoring is stopped and the proxy object is
60     /// removed.
61     ///
62     /// Returns true if callback is removed, false if there is no such id.
remove_callback(&mut self, id: u32) -> bool63     pub fn remove_callback(&mut self, id: u32) -> bool {
64         match self.callbacks.get_mut(&id) {
65             Some(callback) => {
66                 // Stop watching for disconnect.
67                 callback.unregister(id);
68                 // Remove the proxy object.
69                 self.object_id_to_cbid.remove(&callback.get_object_id());
70                 self.callbacks.remove(&id);
71                 true
72             }
73             None => false,
74         }
75     }
76 
77     /// Returns the callback object based on the given id.
get_by_id(&self, id: u32) -> Option<&Box<T>>78     pub fn get_by_id(&self, id: u32) -> Option<&Box<T>> {
79         self.callbacks.get(&id)
80     }
81 
82     /// Returns the mut callback object based on the given id.
get_by_id_mut(&mut self, id: u32) -> Option<&mut Box<T>>83     pub fn get_by_id_mut(&mut self, id: u32) -> Option<&mut Box<T>> {
84         self.callbacks.get_mut(&id)
85     }
86 
87     /// Applies the given function on all active callbacks.
for_all_callbacks<F: Fn(&mut Box<T>)>(&mut self, f: F)88     pub fn for_all_callbacks<F: Fn(&mut Box<T>)>(&mut self, f: F) {
89         for (_, ref mut callback) in self.callbacks.iter_mut() {
90             f(callback);
91         }
92     }
93 }
94 
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Source of callback ids handed out by `TestCallback::register_disconnect`.
    static CBID: AtomicU32 = AtomicU32::new(0);

    /// Minimal `RPCProxy`: its object id is a fixed string, and disconnect registration just
    /// returns the next id from `CBID` (the disconnect observer is discarded, never invoked).
    struct TestCallback {
        id: String,
    }

    impl TestCallback {
        fn new(id: String) -> TestCallback {
            TestCallback { id }
        }
    }

    impl RPCProxy for TestCallback {
        fn get_object_id(&self) -> String {
            self.id.clone()
        }

        fn register_disconnect(&mut self, _f: Box<dyn Fn(u32) + Send>) -> u32 {
            CBID.fetch_add(1, Ordering::SeqCst)
        }
    }

    /// Counts the stored callbacks through `for_all_callbacks`.
    fn count_callbacks(callbacks: &mut Callbacks<TestCallback>) -> usize {
        // `Cell` lets a plain `Fn` closure update the counter.
        let count = Cell::new(0);
        callbacks.for_all_callbacks(|_| count.set(count.get() + 1));
        count.get()
    }

    #[test]
    fn test_add_and_remove() {
        let (tx, _rx) = crate::Stack::create_channel();
        let mut callbacks = Callbacks::new(tx, Message::AdapterCallbackDisconnected);

        let cb_string = String::from("Test Callback");

        // Test add
        let cbid = callbacks.add_callback(Box::new(TestCallback::new(cb_string.clone())));
        assert_eq!(callbacks.get_by_id(cbid).map(|c| c.get_object_id()), Some(cb_string.clone()));

        // Adding another callback with the same object id returns the same cbid and does not
        // store a second entry.
        let cbid1 = callbacks.add_callback(Box::new(TestCallback::new(cb_string.clone())));
        assert_eq!(cbid, cbid1);
        assert_eq!(count_callbacks(&mut callbacks), 1);

        // Test remove
        assert!(callbacks.remove_callback(cbid));
        assert!(callbacks.get_by_id(cbid).is_none());
        assert_eq!(count_callbacks(&mut callbacks), 0);

        // Removing an id that is no longer stored reports failure.
        assert!(!callbacks.remove_callback(cbid));

        // Adding another callback with the same object id should now return a new cbid
        let cbid2 = callbacks.add_callback(Box::new(TestCallback::new(cb_string)));
        assert_ne!(cbid, cbid2);
    }
}
156