1 // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
2 
3 #pragma once
4 
5 #if !defined(RXCPP_RX_SUBJECT_HPP)
6 #define RXCPP_RX_SUBJECT_HPP
7 
8 #include "../rx-includes.hpp"
9 
10 namespace rxcpp {
11 
12 namespace subjects {
13 
14 namespace detail {
15 
16 template<class T>
17 class multicast_observer
18 {
19     typedef subscriber<T> observer_type;
20     typedef std::vector<observer_type> list_type;
21 
22     struct mode
23     {
24         enum type {
25             Invalid = 0,
26             Casting,
27             Disposed,
28             Completed,
29             Errored
30         };
31     };
32 
33     struct state_type
34         : public std::enable_shared_from_this<state_type>
35     {
state_typerxcpp::subjects::detail::multicast_observer::state_type36         explicit state_type(composite_subscription cs)
37             : current(mode::Casting)
38             , lifetime(cs)
39         {
40         }
41         std::mutex lock;
42         typename mode::type current;
43         rxu::error_ptr error;
44         composite_subscription lifetime;
45     };
46 
47     struct completer_type
48         : public std::enable_shared_from_this<completer_type>
49     {
~completer_typerxcpp::subjects::detail::multicast_observer::completer_type50         ~completer_type()
51         {
52         }
completer_typerxcpp::subjects::detail::multicast_observer::completer_type53         completer_type(std::shared_ptr<state_type> s, const std::shared_ptr<completer_type>& old, observer_type o)
54             : state(s)
55         {
56             retain(old);
57             observers.push_back(o);
58         }
completer_typerxcpp::subjects::detail::multicast_observer::completer_type59         completer_type(std::shared_ptr<state_type> s, const std::shared_ptr<completer_type>& old)
60             : state(s)
61         {
62             retain(old);
63         }
retainrxcpp::subjects::detail::multicast_observer::completer_type64         void retain(const std::shared_ptr<completer_type>& old) {
65             if (old) {
66                 observers.reserve(old->observers.size() + 1);
67                 std::copy_if(
68                     old->observers.begin(), old->observers.end(),
69                     std::inserter(observers, observers.end()),
70                     [](const observer_type& o){
71                         return o.is_subscribed();
72                     });
73             }
74         }
75         std::shared_ptr<state_type> state;
76         list_type observers;
77     };
78 
79     // this type prevents a circular ref between state and completer
80     struct binder_type
81         : public std::enable_shared_from_this<binder_type>
82     {
binder_typerxcpp::subjects::detail::multicast_observer::binder_type83         explicit binder_type(composite_subscription cs)
84             : state(std::make_shared<state_type>(cs))
85             , id(trace_id::make_next_id_subscriber())
86         {
87         }
88 
89         std::shared_ptr<state_type> state;
90 
91         trace_id id;
92 
93         // used to avoid taking lock in on_next
94         mutable std::weak_ptr<completer_type> current_completer;
95 
96         // must only be accessed under state->lock
97         mutable std::shared_ptr<completer_type> completer;
98     };
99 
100     std::shared_ptr<binder_type> b;
101 
102 public:
103     typedef subscriber<T, observer<T, detail::multicast_observer<T>>> input_subscriber_type;
104 
multicast_observer(composite_subscription cs)105     explicit multicast_observer(composite_subscription cs)
106         : b(std::make_shared<binder_type>(cs))
107     {
108         std::weak_ptr<binder_type> binder = b;
109         b->state->lifetime.add([binder](){
110             auto b = binder.lock();
111             if (b && b->state->current == mode::Casting){
112                 b->state->current = mode::Disposed;
113                 b->current_completer.reset();
114                 b->completer.reset();
115             }
116         });
117     }
get_id() const118     trace_id get_id() const {
119         return b->id;
120     }
get_subscription() const121     composite_subscription get_subscription() const {
122         return b->state->lifetime;
123     }
get_subscriber() const124     input_subscriber_type get_subscriber() const {
125         return make_subscriber<T>(get_id(), get_subscription(), observer<T, detail::multicast_observer<T>>(*this));
126     }
has_observers() const127     bool has_observers() const {
128         std::unique_lock<std::mutex> guard(b->state->lock);
129         return b->completer && !b->completer->observers.empty();
130     }
131     template<class SubscriberFrom>
add(const SubscriberFrom & sf,observer_type o) const132     void add(const SubscriberFrom& sf, observer_type o) const {
133         trace_activity().connect(sf, o);
134         std::unique_lock<std::mutex> guard(b->state->lock);
135         switch (b->state->current) {
136         case mode::Casting:
137             {
138                 if (o.is_subscribed()) {
139                     std::weak_ptr<binder_type> binder = b;
140                     o.add([=](){
141                         auto b = binder.lock();
142                         if (b) {
143                             std::unique_lock<std::mutex> guard(b->state->lock);
144                             b->completer = std::make_shared<completer_type>(b->state, b->completer);
145                         }
146                     });
147                     b->completer = std::make_shared<completer_type>(b->state, b->completer, o);
148                 }
149             }
150             break;
151         case mode::Completed:
152             {
153                 guard.unlock();
154                 o.on_completed();
155                 return;
156             }
157             break;
158         case mode::Errored:
159             {
160                 auto e = b->state->error;
161                 guard.unlock();
162                 o.on_error(e);
163                 return;
164             }
165             break;
166         case mode::Disposed:
167             {
168                 guard.unlock();
169                 o.unsubscribe();
170                 return;
171             }
172             break;
173         default:
174             std::terminate();
175         }
176     }
177     template<class V>
on_next(V v) const178     void on_next(V v) const {
179         auto current_completer = b->current_completer.lock();
180         if (!current_completer) {
181             std::unique_lock<std::mutex> guard(b->state->lock);
182             b->current_completer = b->completer;
183             current_completer = b->current_completer.lock();
184         }
185         if (!current_completer || current_completer->observers.empty()) {
186             return;
187         }
188         for (auto& o : current_completer->observers) {
189             if (o.is_subscribed()) {
190                 o.on_next(v);
191             }
192         }
193     }
on_error(rxu::error_ptr e) const194     void on_error(rxu::error_ptr e) const {
195         std::unique_lock<std::mutex> guard(b->state->lock);
196         if (b->state->current == mode::Casting) {
197             b->state->error = e;
198             b->state->current = mode::Errored;
199             auto s = b->state->lifetime;
200             auto c = std::move(b->completer);
201             b->current_completer.reset();
202             guard.unlock();
203             if (c) {
204                 for (auto& o : c->observers) {
205                     if (o.is_subscribed()) {
206                         o.on_error(e);
207                     }
208                 }
209             }
210             s.unsubscribe();
211         }
212     }
on_completed() const213     void on_completed() const {
214         std::unique_lock<std::mutex> guard(b->state->lock);
215         if (b->state->current == mode::Casting) {
216             b->state->current = mode::Completed;
217             auto s = b->state->lifetime;
218             auto c = std::move(b->completer);
219             b->current_completer.reset();
220             guard.unlock();
221             if (c) {
222                 for (auto& o : c->observers) {
223                     if (o.is_subscribed()) {
224                         o.on_completed();
225                     }
226                 }
227             }
228             s.unsubscribe();
229         }
230     }
231 };
232 
233 
234 }
235 
236 template<class T>
237 class subject
238 {
239     detail::multicast_observer<T> s;
240 
241 public:
242     typedef subscriber<T, observer<T, detail::multicast_observer<T>>> subscriber_type;
243     typedef observable<T> observable_type;
subject()244     subject()
245         : s(composite_subscription())
246     {
247     }
subject(composite_subscription cs)248     explicit subject(composite_subscription cs)
249         : s(cs)
250     {
251     }
252 
has_observers() const253     bool has_observers() const {
254         return s.has_observers();
255     }
256 
get_subscription() const257     composite_subscription get_subscription() const {
258         return s.get_subscription();
259     }
260 
get_subscriber() const261     subscriber_type get_subscriber() const {
262         return s.get_subscriber();
263     }
264 
get_observable() const265     observable<T> get_observable() const {
266         auto keepAlive = s;
267         return make_observable_dynamic<T>([=](subscriber<T> o){
268             keepAlive.add(keepAlive.get_subscriber(), std::move(o));
269         });
270     }
271 };
272 
273 }
274 
275 }
276 
277 #endif
278