1 // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
2 
3 #pragma once
4 
5 /*! \file rx-ref_count.hpp
6 
7     \brief  Make some \c connectable_observable behave like an ordinary \c observable.
8             Uses a reference count of the subscribers to control the connection to the published observable.
9 
10             The first subscription will cause a call to \c connect(), and the last \c unsubscribe will unsubscribe the connection.
11 
12             There are 2 variants of the operator:
13             \li \c ref_count(): calls \c connect on the \c source \c connectable_observable.
14             \li \c ref_count(other): calls \c connect on the \c other \c connectable_observable.
15 
16     \tparam ConnectableObservable the type of the \c other \c connectable_observable (optional)
17     \param  other \c connectable_observable to call \c connect on (optional)
18 
19     If \c other is omitted, then \c source is used instead (which must be a \c connectable_observable).
20     Otherwise, \c source can be a regular \c observable.
21 
22     \return An \c observable that emits the items from its \c source.
23 
24     \sample
25     \snippet ref_count.cpp ref_count other diamond sample
26     \snippet output.txt ref_count other diamond sample
27  */
28 
29 #if !defined(RXCPP_OPERATORS_RX_REF_COUNT_HPP)
30 #define RXCPP_OPERATORS_RX_REF_COUNT_HPP
31 
32 #include "../rx-includes.hpp"
33 
34 namespace rxcpp {
35 
36 namespace operators {
37 
38 namespace detail {
39 
// Empty tag type that carries the argument types of a ref_count() call.
// It is used as the value type of ref_count_invalid.
template<class... Args>
struct ref_count_invalid_arguments
{
};
42 
43 template<class... AN>
44 struct ref_count_invalid : public rxo::operator_base<ref_count_invalid_arguments<AN...>> {
45     using type = observable<ref_count_invalid_arguments<AN...>, ref_count_invalid<AN...>>;
46 };
47 template<class... AN>
48 using ref_count_invalid_t = typename ref_count_invalid<AN...>::type;
49 
// ref_count(other) takes a regular observable source, not a connectable_observable.
// Use template specialization to avoid instantiating 'subscribe' for two different types,
// which would cause a compilation error.
// General case, used for ref_count(other): stores the connectable that will be
// connected and, separately, the observable that subscribers are attached to.
template <class connectable_type, class observable_type>
struct ref_count_state_base
{
    connectable_type connectable;  // connect() is called on this; this struct never subscribes to it.
    observable_type subscribable;  // every subscribe() call is forwarded to this.

    // Takes ownership of both observables by moving them into the members.
    ref_count_state_base(connectable_type other, observable_type source)
        : connectable(std::move(other)), subscribable(std::move(source))
    {
    }

    // Perfect-forwards the subscriber to 'subscribable'.
    template <class Subscriber>
    void subscribe(Subscriber&& subscriber)
    {
        subscribable.subscribe(std::forward<Subscriber>(subscriber));
    }
};
67 
68 // Note: explicit specializations have to be at namespace scope prior to C++17.
69 template <typename connectable_type>
70 struct ref_count_state_base<connectable_type, void> {
ref_count_state_baserxcpp::operators::detail::ref_count_state_base71     explicit ref_count_state_base(connectable_type c)
72         : connectable(std::move(c)) {}
73 
74     connectable_type connectable; // connects to this. subscribes to this if subscribable empty.
75 
76     template <typename Subscriber>
subscriberxcpp::operators::detail::ref_count_state_base77     void subscribe(Subscriber&& o) {
78         connectable.subscribe(std::forward<Subscriber>(o));
79     }
80 };
81 
82 template<class T,
83          class ConnectableObservable,
84          class Observable = void> // note: type order flipped versus the operator.
85 struct ref_count : public operator_base<T>
86 {
87     typedef rxu::decay_t<Observable> observable_type;
88     typedef rxu::decay_t<ConnectableObservable> connectable_type;
89 
90     // ref_count() == false
91     // ref_count(other) == true
92     using has_observable_t = rxu::negation<std::is_same<void, Observable>>;
93     static constexpr bool has_observable_v = has_observable_t::value;
94 
95     struct ref_count_state : public std::enable_shared_from_this<ref_count_state>,
96                              public ref_count_state_base<ConnectableObservable, Observable>
97     {
98         template <class HasObservable = has_observable_t,
99                   class Enabled = rxu::enable_if_all_true_type_t<
100                       rxu::negation<HasObservable>>>
ref_count_staterxcpp::operators::detail::ref_count::ref_count_state101         explicit ref_count_state(connectable_type source)
102             : ref_count_state_base<ConnectableObservable, Observable>(std::move(source))
103             , subscribers(0)
104         {
105         }
106 
107         template <bool HasObservableV = has_observable_v>
ref_count_staterxcpp::operators::detail::ref_count::ref_count_state108         ref_count_state(connectable_type other,
109                         typename std::enable_if<HasObservableV, observable_type>::type source)
110             : ref_count_state_base<ConnectableObservable, Observable>(std::move(other),
111                                                                       std::move(source))
112             , subscribers(0)
113         {
114         }
115 
116         std::mutex lock;
117         long subscribers;
118         composite_subscription connection;
119     };
120     std::shared_ptr<ref_count_state> state;
121 
122     // connectable_observable<T> source = ...;
123     // source.ref_count();
124     //
125     // calls connect on source after the subscribe on source.
126     template <class HasObservable = has_observable_t,
127               class Enabled = rxu::enable_if_all_true_type_t<
128                   rxu::negation<HasObservable>>>
ref_countrxcpp::operators::detail::ref_count129     explicit ref_count(connectable_type source)
130         : state(std::make_shared<ref_count_state>(std::move(source)))
131     {
132     }
133 
134     // connectable_observable<?> other = ...;
135     // observable<T> source = ...;
136     // source.ref_count(other);
137     //
138     // calls connect on 'other' after the subscribe on 'source'.
139     template <bool HasObservableV = has_observable_v>
ref_countrxcpp::operators::detail::ref_count140     ref_count(connectable_type other,
141               typename std::enable_if<HasObservableV, observable_type>::type source)
142         : state(std::make_shared<ref_count_state>(std::move(other), std::move(source)))
143     {
144     }
145 
146     template<class Subscriber>
on_subscriberxcpp::operators::detail::ref_count147     void on_subscribe(Subscriber&& o) const {
148         std::unique_lock<std::mutex> guard(state->lock);
149         auto needConnect = ++state->subscribers == 1;
150         auto keepAlive = state;
151         guard.unlock();
152         o.add(
153             [keepAlive](){
154                 std::unique_lock<std::mutex> guard_unsubscribe(keepAlive->lock);
155                 if (--keepAlive->subscribers == 0) {
156                     keepAlive->connection.unsubscribe();
157                     keepAlive->connection = composite_subscription();
158                 }
159             });
160         keepAlive->subscribe(std::forward<Subscriber>(o));
161         if (needConnect) {
162             keepAlive->connectable.connect(keepAlive->connection);
163         }
164     }
165 };
166 
167 }
168 
169 /*! @copydoc rx-ref_count.hpp
170 */
171 template<class... AN>
ref_count(AN &&...an)172 auto ref_count(AN&&... an)
173     ->     operator_factory<ref_count_tag, AN...> {
174     return operator_factory<ref_count_tag, AN...>(std::make_tuple(std::forward<AN>(an)...));
175 }
176 
177 }
178 
179 template<>
180 struct member_overload<ref_count_tag>
181 {
182     template<class ConnectableObservable,
183         class Enabled = rxu::enable_if_all_true_type_t<
184             is_connectable_observable<ConnectableObservable>>,
185         class SourceValue = rxu::value_type_t<ConnectableObservable>,
186         class RefCount = rxo::detail::ref_count<SourceValue, rxu::decay_t<ConnectableObservable>>,
187         class Value = rxu::value_type_t<RefCount>,
188         class Result = observable<Value, RefCount>
189         >
memberrxcpp::member_overload190     static Result member(ConnectableObservable&& o) {
191         return Result(RefCount(std::forward<ConnectableObservable>(o)));
192     }
193 
194     template<class Observable,
195         class ConnectableObservable,
196         class Enabled = rxu::enable_if_all_true_type_t<
197             is_observable<Observable>,
198             is_connectable_observable<ConnectableObservable>>,
199         class SourceValue = rxu::value_type_t<Observable>,
200         class RefCount = rxo::detail::ref_count<SourceValue,
201             rxu::decay_t<ConnectableObservable>,
202             rxu::decay_t<Observable>>,
203         class Value = rxu::value_type_t<RefCount>,
204         class Result = observable<Value, RefCount>
205         >
memberrxcpp::member_overload206     static Result member(Observable&& o, ConnectableObservable&& other) {
207         return Result(RefCount(std::forward<ConnectableObservable>(other),
208                                std::forward<Observable>(o)));
209     }
210 
211     template<class... AN>
memberrxcpp::member_overload212     static operators::detail::ref_count_invalid_t<AN...> member(AN...) {
213         std::terminate();
214         return {};
215         static_assert(sizeof...(AN) == 10000, "ref_count takes (optional ConnectableObservable)");
216     }
217 };
218 
219 }
220 
221 #endif
222