1 //! Core dependency injection objects
2 
3 use std::any::{Any, TypeId};
4 use std::collections::HashMap;
5 use std::future::Future;
6 use std::pin::Pin;
7 use std::sync::Arc;
8 use tokio::sync::Mutex;
9 
10 pub use gddi_macros::{module, part_out, provides, Stoppable};
11 
12 type InstanceBox = Box<dyn Any + Send + Sync>;
13 /// A box around a future for a provider that is safe to send between threads
14 pub type ProviderFutureBox = Box<dyn Future<Output = Box<dyn Any>> + Send + Sync>;
15 type ProviderFnBox = Box<dyn Fn(Arc<Registry>) -> Pin<ProviderFutureBox> + Send + Sync>;
16 
/// Shutdown hook invoked on an injected object.
///
/// The default `stop` does nothing, so a type with no resources to release can
/// implement this trait with an empty `impl` block.
pub trait Stoppable {
    /// Stop the object and release every resource it holds.
    fn stop(&self) {
        // Deliberately empty: no-op unless the implementor overrides it.
    }
}
22 
23 /// Builder for Registry
24 pub struct RegistryBuilder {
25     providers: HashMap<TypeId, Provider>,
26 }
27 
28 /// Keeps track of central injection state
29 pub struct Registry {
30     providers: Arc<Mutex<HashMap<TypeId, Provider>>>,
31     instances: Arc<Mutex<HashMap<TypeId, InstanceBox>>>,
32     start_order: Arc<Mutex<Vec<Box<dyn Stoppable + Send + Sync>>>>,
33 }
34 
35 #[derive(Clone)]
36 struct Provider {
37     f: Arc<ProviderFnBox>,
38 }
39 
40 impl Default for RegistryBuilder {
default() -> Self41     fn default() -> Self {
42         Self::new()
43     }
44 }
45 
46 impl RegistryBuilder {
47     /// Creates a new RegistryBuilder
new() -> Self48     pub fn new() -> Self {
49         RegistryBuilder { providers: HashMap::new() }
50     }
51 
52     /// Registers a module with this registry
register_module<F>(self, init: F) -> Self where F: Fn(Self) -> Self,53     pub fn register_module<F>(self, init: F) -> Self
54     where
55         F: Fn(Self) -> Self,
56     {
57         init(self)
58     }
59 
60     /// Registers a provider function with this registry
register_provider<T: 'static>(mut self, f: ProviderFnBox) -> Self61     pub fn register_provider<T: 'static>(mut self, f: ProviderFnBox) -> Self {
62         self.providers.insert(TypeId::of::<T>(), Provider { f: Arc::new(f) });
63 
64         self
65     }
66 
67     /// Construct the Registry from this builder
build(self) -> Registry68     pub fn build(self) -> Registry {
69         Registry {
70             providers: Arc::new(Mutex::new(self.providers)),
71             instances: Arc::new(Mutex::new(HashMap::new())),
72             start_order: Arc::new(Mutex::new(Vec::new())),
73         }
74     }
75 }
76 
77 impl Registry {
78     /// Gets an instance of a type, implicitly starting any dependencies if necessary
get<T: 'static + Clone + Send + Sync + Stoppable>(self: &Arc<Self>) -> T79     pub async fn get<T: 'static + Clone + Send + Sync + Stoppable>(self: &Arc<Self>) -> T {
80         let typeid = TypeId::of::<T>();
81         {
82             let instances = self.instances.lock().await;
83             if let Some(value) = instances.get(&typeid) {
84                 return value.downcast_ref::<T>().expect("was not correct type").clone();
85             }
86         }
87 
88         let casted = {
89             let provider = { self.providers.lock().await[&typeid].clone() };
90             let result = (provider.f)(self.clone()).await;
91             (*result.downcast::<T>().expect("was not correct type")).clone()
92         };
93 
94         let mut instances = self.instances.lock().await;
95         instances.insert(typeid, Box::new(casted.clone()));
96 
97         let mut start_order = self.start_order.lock().await;
98         start_order.push(Box::new(casted.clone()));
99 
100         casted
101     }
102 
103     /// Inject an already created instance of T. Useful for config.
inject<T: 'static + Clone + Send + Sync>(self: &Arc<Self>, obj: T)104     pub async fn inject<T: 'static + Clone + Send + Sync>(self: &Arc<Self>, obj: T) {
105         let mut instances = self.instances.lock().await;
106         instances.insert(TypeId::of::<T>(), Box::new(obj));
107     }
108 
109     /// Stop all instances, in reverse order of start.
stop_all(self: &Arc<Self>)110     pub async fn stop_all(self: &Arc<Self>) {
111         let mut start_order = self.start_order.lock().await;
112         while let Some(obj) = start_order.pop() {
113             obj.stop();
114         }
115         self.instances.lock().await.clear();
116     }
117 }
118 
119 impl<T> Stoppable for std::sync::Arc<T> {}
120