1 // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
2 
3 #pragma once
4 
5 /*! \file rx-group_by.hpp
6 
7     \brief Return an observable that emits grouped_observables, each of which corresponds to a unique key value and each of which emits those items from the source observable that share that key value.
8 
9     \tparam KeySelector      the type of the key extracting function
10     \tparam MarbleSelector   the type of the element extracting function
11     \tparam BinaryPredicate  the type of the key comparing function
12     \tparam DurationSelector the type of the duration observable function
13 
14     \param  ks  a function that extracts the key for each item (optional)
15     \param  ms  a function that extracts the return element for each item (optional)
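    \param  p   a function that implements comparison of two keys (optional)
    \param  ds  a function that is called with each newly created grouped_observable and returns an observable; the group is completed and removed when that observable emits its first item, errors, or completes (optional)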
17 
18     \return  Observable that emits values of grouped_observable type, each of which corresponds to a unique key value and each of which emits those items from the source observable that share that key value.
19 
20     \sample
21     \snippet group_by.cpp group_by full intro
22     \snippet group_by.cpp group_by full sample
23     \snippet output.txt group_by full sample
24 
25     \sample
26     \snippet group_by.cpp group_by sample
27     \snippet output.txt group_by sample
28 */
29 
30 #if !defined(RXCPP_OPERATORS_RX_GROUP_BY_HPP)
31 #define RXCPP_OPERATORS_RX_GROUP_BY_HPP
32 
33 #include "../rx-includes.hpp"
34 
35 namespace rxcpp {
36 
37 namespace operators {
38 
39 namespace detail {
40 
/*! Empty tag type that records, in its template arguments, the argument
    types of a group_by call that matched no valid overload.
*/
template<typename... AN>
struct group_by_invalid_arguments {};
43 
44 template<class... AN>
45 struct group_by_invalid : public rxo::operator_base<group_by_invalid_arguments<AN...>> {
46     using type = observable<group_by_invalid_arguments<AN...>, group_by_invalid<AN...>>;
47 };
48 template<class... AN>
49 using group_by_invalid_t = typename group_by_invalid<AN...>::type;
50 
51 template<class T, class Selector>
52 struct is_group_by_selector_for {
53 
54     typedef rxu::decay_t<Selector> selector_type;
55     typedef T source_value_type;
56 
57     struct tag_not_valid {};
58     template<class CV, class CS>
59     static auto check(int) -> decltype((*(CS*)nullptr)(*(CV*)nullptr));
60     template<class CV, class CS>
61     static tag_not_valid check(...);
62 
63     typedef decltype(check<source_value_type, selector_type>(0)) type;
64     static const bool value = !std::is_same<type, tag_not_valid>::value;
65 };
66 
67 template<class T, class Observable, class KeySelector, class MarbleSelector, class BinaryPredicate, class DurationSelector>
68 struct group_by_traits
69 {
70     typedef T source_value_type;
71     typedef rxu::decay_t<Observable> source_type;
72     typedef rxu::decay_t<KeySelector> key_selector_type;
73     typedef rxu::decay_t<MarbleSelector> marble_selector_type;
74     typedef rxu::decay_t<BinaryPredicate> predicate_type;
75     typedef rxu::decay_t<DurationSelector> duration_selector_type;
76 
77     static_assert(is_group_by_selector_for<source_value_type, key_selector_type>::value, "group_by KeySelector must be a function with the signature key_type(source_value_type)");
78 
79     typedef typename is_group_by_selector_for<source_value_type, key_selector_type>::type key_type;
80 
81     static_assert(is_group_by_selector_for<source_value_type, marble_selector_type>::value, "group_by MarbleSelector must be a function with the signature marble_type(source_value_type)");
82 
83     typedef typename is_group_by_selector_for<source_value_type, marble_selector_type>::type marble_type;
84 
85     typedef rxsub::subject<marble_type> subject_type;
86 
87     typedef std::map<key_type, typename subject_type::subscriber_type, predicate_type> key_subscriber_map_type;
88 
89     typedef grouped_observable<key_type, marble_type> grouped_observable_type;
90 };
91 
92 template<class T, class Observable, class KeySelector, class MarbleSelector, class BinaryPredicate, class DurationSelector>
93 struct group_by
94 {
95     typedef group_by_traits<T, Observable, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector> traits_type;
96     typedef typename traits_type::key_selector_type key_selector_type;
97     typedef typename traits_type::marble_selector_type marble_selector_type;
98     typedef typename traits_type::marble_type marble_type;
99     typedef typename traits_type::predicate_type predicate_type;
100     typedef typename traits_type::duration_selector_type duration_selector_type;
101     typedef typename traits_type::subject_type subject_type;
102     typedef typename traits_type::key_type key_type;
103 
104     typedef typename traits_type::key_subscriber_map_type group_map_type;
105     typedef std::vector<typename composite_subscription::weak_subscription> bindings_type;
106 
107     struct group_by_state_type
108     {
group_by_state_typerxcpp::operators::detail::group_by::group_by_state_type109         group_by_state_type(composite_subscription sl, predicate_type p)
110             : source_lifetime(sl)
111             , groups(p)
112             , observers(0)
113         {}
114         composite_subscription source_lifetime;
115         rxsc::worker worker;
116         group_map_type groups;
117         std::atomic<int> observers;
118     };
119 
120     template<class Subscriber>
stopsourcerxcpp::operators::detail::group_by121     static void stopsource(Subscriber&& dest, std::shared_ptr<group_by_state_type>& state) {
122         ++state->observers;
123         dest.add([state](){
124             if (!state->source_lifetime.is_subscribed()) {
125                 return;
126             }
127             --state->observers;
128             if (state->observers == 0) {
129                 state->source_lifetime.unsubscribe();
130             }
131         });
132     }
133 
134     struct group_by_values
135     {
group_by_valuesrxcpp::operators::detail::group_by::group_by_values136         group_by_values(key_selector_type ks, marble_selector_type ms, predicate_type p, duration_selector_type ds)
137             : keySelector(std::move(ks))
138             , marbleSelector(std::move(ms))
139             , predicate(std::move(p))
140             , durationSelector(std::move(ds))
141         {
142         }
143         mutable key_selector_type keySelector;
144         mutable marble_selector_type marbleSelector;
145         mutable predicate_type predicate;
146         mutable duration_selector_type durationSelector;
147     };
148 
149     group_by_values initial;
150 
group_byrxcpp::operators::detail::group_by151     group_by(key_selector_type ks, marble_selector_type ms, predicate_type p, duration_selector_type ds)
152         : initial(std::move(ks), std::move(ms), std::move(p), std::move(ds))
153     {
154     }
155 
156     struct group_by_observable : public rxs::source_base<marble_type>
157     {
158         mutable std::shared_ptr<group_by_state_type> state;
159         subject_type subject;
160         key_type key;
161 
group_by_observablerxcpp::operators::detail::group_by::group_by_observable162         group_by_observable(std::shared_ptr<group_by_state_type> st, subject_type s, key_type k)
163             : state(std::move(st))
164             , subject(std::move(s))
165             , key(k)
166         {
167         }
168 
169         template<class Subscriber>
on_subscriberxcpp::operators::detail::group_by::group_by_observable170         void on_subscribe(Subscriber&& o) const {
171             group_by::stopsource(o, state);
172             subject.get_observable().subscribe(std::forward<Subscriber>(o));
173         }
174 
on_get_keyrxcpp::operators::detail::group_by::group_by_observable175         key_type on_get_key() {
176             return key;
177         }
178     };
179 
180     template<class Subscriber>
181     struct group_by_observer : public group_by_values
182     {
183         typedef group_by_observer<Subscriber> this_type;
184         typedef typename traits_type::grouped_observable_type value_type;
185         typedef rxu::decay_t<Subscriber> dest_type;
186         typedef observer<T, this_type> observer_type;
187 
188         dest_type dest;
189 
190         mutable std::shared_ptr<group_by_state_type> state;
191 
group_by_observerrxcpp::operators::detail::group_by::group_by_observer192         group_by_observer(composite_subscription l, dest_type d, group_by_values v)
193             : group_by_values(v)
194             , dest(std::move(d))
195             , state(std::make_shared<group_by_state_type>(l, group_by_values::predicate))
196         {
197             group_by::stopsource(dest, state);
198         }
on_nextrxcpp::operators::detail::group_by::group_by_observer199         void on_next(T v) const {
200             auto selectedKey = on_exception(
201                 [&](){
202                     return this->keySelector(v);},
203                 [this](rxu::error_ptr e){on_error(e);});
204             if (selectedKey.empty()) {
205                 return;
206             }
207             auto g = state->groups.find(selectedKey.get());
208             if (g == state->groups.end()) {
209                 if (!dest.is_subscribed()) {
210                     return;
211                 }
212                 auto sub = subject_type();
213                 g = state->groups.insert(std::make_pair(selectedKey.get(), sub.get_subscriber())).first;
214                 auto obs = make_dynamic_grouped_observable<key_type, marble_type>(group_by_observable(state, sub, selectedKey.get()));
215                 auto durationObs = on_exception(
216                     [&](){
217                         return this->durationSelector(obs);},
218                     [this](rxu::error_ptr e){on_error(e);});
219                 if (durationObs.empty()) {
220                     return;
221                 }
222 
223                 dest.on_next(obs);
224                 composite_subscription duration_sub;
225                 auto ssub = state->source_lifetime.add(duration_sub);
226 
227                 auto expire_state = state;
228                 auto expire_dest = g->second;
229                 auto expire = [=]() {
230                     auto g = expire_state->groups.find(selectedKey.get());
231                     if (g != expire_state->groups.end()) {
232                         expire_state->groups.erase(g);
233                         expire_dest.on_completed();
234                     }
235                     expire_state->source_lifetime.remove(ssub);
236                 };
237                 auto robs = durationObs.get().take(1);
238                 duration_sub.add(robs.subscribe(
239                     [](const typename decltype(robs)::value_type &){},
240                     [=](rxu::error_ptr) {expire();},
241                     [=](){expire();}
242                 ));
243             }
244             auto selectedMarble = on_exception(
245                 [&](){
246                     return this->marbleSelector(v);},
247                 [this](rxu::error_ptr e){on_error(e);});
248             if (selectedMarble.empty()) {
249                 return;
250             }
251             g->second.on_next(std::move(selectedMarble.get()));
252         }
on_errorrxcpp::operators::detail::group_by::group_by_observer253         void on_error(rxu::error_ptr e) const {
254             for(auto& g : state->groups) {
255                 g.second.on_error(e);
256             }
257             dest.on_error(e);
258         }
on_completedrxcpp::operators::detail::group_by::group_by_observer259         void on_completed() const {
260             for(auto& g : state->groups) {
261                 g.second.on_completed();
262             }
263             dest.on_completed();
264         }
265 
makerxcpp::operators::detail::group_by::group_by_observer266         static subscriber<T, observer_type> make(dest_type d, group_by_values v) {
267             auto cs = composite_subscription();
268             return make_subscriber<T>(cs, observer_type(this_type(cs, std::move(d), std::move(v))));
269         }
270     };
271 
272     template<class Subscriber>
operator ()rxcpp::operators::detail::group_by273     auto operator()(Subscriber dest) const
274         -> decltype(group_by_observer<Subscriber>::make(std::move(dest), initial)) {
275         return      group_by_observer<Subscriber>::make(std::move(dest), initial);
276     }
277 };
278 
279 template<class KeySelector, class MarbleSelector, class BinaryPredicate, class DurationSelector>
280 class group_by_factory
281 {
282     typedef rxu::decay_t<KeySelector> key_selector_type;
283     typedef rxu::decay_t<MarbleSelector> marble_selector_type;
284     typedef rxu::decay_t<BinaryPredicate> predicate_type;
285     typedef rxu::decay_t<DurationSelector> duration_selector_type;
286     key_selector_type keySelector;
287     marble_selector_type marbleSelector;
288     predicate_type predicate;
289     duration_selector_type durationSelector;
290 public:
group_by_factory(key_selector_type ks,marble_selector_type ms,predicate_type p,duration_selector_type ds)291     group_by_factory(key_selector_type ks, marble_selector_type ms, predicate_type p, duration_selector_type ds)
292         : keySelector(std::move(ks))
293         , marbleSelector(std::move(ms))
294         , predicate(std::move(p))
295         , durationSelector(std::move(ds))
296     {
297     }
298     template<class Observable>
299     struct group_by_factory_traits
300     {
301         typedef rxu::value_type_t<rxu::decay_t<Observable>> value_type;
302         typedef detail::group_by_traits<value_type, Observable, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector> traits_type;
303         typedef detail::group_by<value_type, Observable, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector> group_by_type;
304     };
305     template<class Observable>
operator ()(Observable && source)306     auto operator()(Observable&& source)
307         -> decltype(source.template lift<typename group_by_factory_traits<Observable>::traits_type::grouped_observable_type>(typename group_by_factory_traits<Observable>::group_by_type(std::move(keySelector), std::move(marbleSelector), std::move(predicate), std::move(durationSelector)))) {
308         return      source.template lift<typename group_by_factory_traits<Observable>::traits_type::grouped_observable_type>(typename group_by_factory_traits<Observable>::group_by_type(std::move(keySelector), std::move(marbleSelector), std::move(predicate), std::move(durationSelector)));
309     }
310 };
311 
312 }
313 
314 /*! @copydoc rx-group_by.hpp
315 */
316 template<class... AN>
group_by(AN &&...an)317 auto group_by(AN&&... an)
318     ->     operator_factory<group_by_tag, AN...> {
319     return operator_factory<group_by_tag, AN...>(std::make_tuple(std::forward<AN>(an)...));
320 }
321 
322 }
323 
324 template<>
325 struct member_overload<group_by_tag>
326 {
327     template<class Observable, class KeySelector, class MarbleSelector, class BinaryPredicate, class DurationSelector,
328         class SourceValue = rxu::value_type_t<Observable>,
329         class Traits = rxo::detail::group_by_traits<SourceValue, rxu::decay_t<Observable>, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector>,
330         class GroupBy = rxo::detail::group_by<SourceValue, rxu::decay_t<Observable>, rxu::decay_t<KeySelector>, rxu::decay_t<MarbleSelector>, rxu::decay_t<BinaryPredicate>, rxu::decay_t<DurationSelector>>,
331         class Value = typename Traits::grouped_observable_type>
memberrxcpp::member_overload332     static auto member(Observable&& o, KeySelector&& ks, MarbleSelector&& ms, BinaryPredicate&& p, DurationSelector&& ds)
333         -> decltype(o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), std::forward<BinaryPredicate>(p), std::forward<DurationSelector>(ds)))) {
334         return      o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), std::forward<BinaryPredicate>(p), std::forward<DurationSelector>(ds)));
335     }
336 
337     template<class Observable, class KeySelector, class MarbleSelector, class BinaryPredicate,
338         class DurationSelector=rxu::ret<observable<int, rxs::detail::never<int>>>,
339         class SourceValue = rxu::value_type_t<Observable>,
340         class Traits = rxo::detail::group_by_traits<SourceValue, rxu::decay_t<Observable>, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector>,
341         class GroupBy = rxo::detail::group_by<SourceValue, rxu::decay_t<Observable>, rxu::decay_t<KeySelector>, rxu::decay_t<MarbleSelector>, rxu::decay_t<BinaryPredicate>, rxu::decay_t<DurationSelector>>,
342         class Value = typename Traits::grouped_observable_type>
memberrxcpp::member_overload343     static auto member(Observable&& o, KeySelector&& ks, MarbleSelector&& ms, BinaryPredicate&& p)
344         -> decltype(o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), std::forward<BinaryPredicate>(p), rxu::ret<observable<int, rxs::detail::never<int>>>()))) {
345         return      o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), std::forward<BinaryPredicate>(p), rxu::ret<observable<int, rxs::detail::never<int>>>()));
346     }
347 
348     template<class Observable, class KeySelector, class MarbleSelector,
349         class BinaryPredicate=rxu::less,
350         class DurationSelector=rxu::ret<observable<int, rxs::detail::never<int>>>,
351         class SourceValue = rxu::value_type_t<Observable>,
352         class Traits = rxo::detail::group_by_traits<SourceValue, rxu::decay_t<Observable>, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector>,
353         class GroupBy = rxo::detail::group_by<SourceValue, rxu::decay_t<Observable>, rxu::decay_t<KeySelector>, rxu::decay_t<MarbleSelector>, rxu::decay_t<BinaryPredicate>, rxu::decay_t<DurationSelector>>,
354         class Value = typename Traits::grouped_observable_type>
memberrxcpp::member_overload355     static auto member(Observable&& o, KeySelector&& ks, MarbleSelector&& ms)
356         -> decltype(o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()))) {
357         return      o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()));
358     }
359 
360 
361     template<class Observable, class KeySelector,
362         class MarbleSelector=rxu::detail::take_at<0>,
363         class BinaryPredicate=rxu::less,
364         class DurationSelector=rxu::ret<observable<int, rxs::detail::never<int>>>,
365         class SourceValue = rxu::value_type_t<Observable>,
366         class Traits = rxo::detail::group_by_traits<SourceValue, rxu::decay_t<Observable>, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector>,
367         class GroupBy = rxo::detail::group_by<SourceValue, rxu::decay_t<Observable>, rxu::decay_t<KeySelector>, rxu::decay_t<MarbleSelector>, rxu::decay_t<BinaryPredicate>, rxu::decay_t<DurationSelector>>,
368         class Value = typename Traits::grouped_observable_type>
memberrxcpp::member_overload369     static auto member(Observable&& o, KeySelector&& ks)
370         -> decltype(o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), rxu::detail::take_at<0>(), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()))) {
371         return      o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), rxu::detail::take_at<0>(), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()));
372     }
373 
374     template<class Observable,
375         class KeySelector=rxu::detail::take_at<0>,
376         class MarbleSelector=rxu::detail::take_at<0>,
377         class BinaryPredicate=rxu::less,
378         class DurationSelector=rxu::ret<observable<int, rxs::detail::never<int>>>,
379         class Enabled = rxu::enable_if_all_true_type_t<
380             all_observables<Observable>>,
381         class SourceValue = rxu::value_type_t<Observable>,
382         class Traits = rxo::detail::group_by_traits<SourceValue, rxu::decay_t<Observable>, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector>,
383         class GroupBy = rxo::detail::group_by<SourceValue, rxu::decay_t<Observable>, rxu::decay_t<KeySelector>, rxu::decay_t<MarbleSelector>, rxu::decay_t<BinaryPredicate>, rxu::decay_t<DurationSelector>>,
384         class Value = typename Traits::grouped_observable_type>
memberrxcpp::member_overload385     static auto member(Observable&& o)
386         -> decltype(o.template lift<Value>(GroupBy(rxu::detail::take_at<0>(), rxu::detail::take_at<0>(), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()))) {
387         return      o.template lift<Value>(GroupBy(rxu::detail::take_at<0>(), rxu::detail::take_at<0>(), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()));
388     }
389 
390     template<class... AN>
memberrxcpp::member_overload391     static operators::detail::group_by_invalid_t<AN...> member(const AN&...) {
392         std::terminate();
393         return {};
394         static_assert(sizeof...(AN) == 10000, "group_by takes (optional KeySelector, optional MarbleSelector, optional BinaryKeyPredicate, optional DurationSelector), KeySelector takes (Observable::value_type) -> KeyValue, MarbleSelector takes (Observable::value_type) -> MarbleValue, BinaryKeyPredicate takes (KeyValue, KeyValue) -> bool, DurationSelector takes (Observable::value_type) -> Observable");
395     }
396 
397 };
398 
399 }
400 
401 #endif
402 
403