1 // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
2 
3 #pragma once
4 
5 /*! \file rx-coroutine.hpp
6 
    \brief The proposal to add coroutines to the standard adds `co_await`, `for co_await`, `co_yield` and `co_return`. This file adds `begin(observable<>)` & `end(observable<>)` which enables `for co_await` to work with the `observable<>` type.
8 
9     for co_await (auto c : interval(seconds(1), observe_on_event_loop()) | take(3)) {
        printf("%ld\n", c);
11     }
12 
13 */
14 
15 #if !defined(RXCPP_RX_COROUTINE_HPP)
16 #define RXCPP_RX_COROUTINE_HPP
17 
18 #include "rx-includes.hpp"
19 
20 #ifdef _RESUMABLE_FUNCTIONS_SUPPORTED
21 
22 #include <rxcpp/operators/rx-finally.hpp>
23 
24 #include <experimental/resumable>
25 
26 namespace rxcpp {
27 namespace coroutine {
28 
29 using namespace std;
30 using namespace std::chrono;
31 using namespace std::experimental;
32 
33 template<typename Source>
34 struct co_observable_iterator;
35 
36 template<typename Source>
37 struct co_observable_iterator_state : std::enable_shared_from_this<co_observable_iterator_state<Source>>
38 {
39     using value_type = typename Source::value_type;
40 
~co_observable_iterator_staterxcpp::coroutine::co_observable_iterator_state41     ~co_observable_iterator_state() {
42         lifetime.unsubscribe();
43     }
co_observable_iterator_staterxcpp::coroutine::co_observable_iterator_state44     explicit co_observable_iterator_state(const Source& o) : o(o) {}
45 
46     coroutine_handle<> caller{};
47     composite_subscription lifetime{};
48     const value_type* value{nullptr};
49     exception_ptr error{nullptr};
50     Source o;
51 };
52 
53 template<typename Source>
54 struct co_observable_inc_awaiter
55 {
await_readyrxcpp::coroutine::co_observable_inc_awaiter56     bool await_ready() {
57         return false;
58     }
59 
await_suspendrxcpp::coroutine::co_observable_inc_awaiter60     bool await_suspend(coroutine_handle<> handle) {
61         if (!state->lifetime.is_subscribed()) {return false;}
62         state->caller = handle;
63         return true;
64     }
65 
66     co_observable_iterator<Source> await_resume();
67 
68     shared_ptr<co_observable_iterator_state<Source>> state;
69 };
70 
71 template<typename Source>
72 struct co_observable_iterator : public iterator<input_iterator_tag, typename Source::value_type>
73 {
74     using value_type = typename Source::value_type;
75 
co_observable_iteratorrxcpp::coroutine::co_observable_iterator76     co_observable_iterator() {}
77 
co_observable_iteratorrxcpp::coroutine::co_observable_iterator78     explicit co_observable_iterator(const Source& o) : state(make_shared<co_observable_iterator_state<Source>>(o)) {}
co_observable_iteratorrxcpp::coroutine::co_observable_iterator79     explicit co_observable_iterator(const shared_ptr<co_observable_iterator_state<Source>>& o) : state(o) {}
80 
81     co_observable_iterator(co_observable_iterator&&)=default;
82     co_observable_iterator& operator=(co_observable_iterator&&)=default;
83 
operator ++rxcpp::coroutine::co_observable_iterator84     co_observable_inc_awaiter<Source> operator++()
85     {
86         return co_observable_inc_awaiter<Source>{state};
87     }
88 
89     co_observable_iterator& operator++(int) = delete;
90     // not implementing postincrement
91 
operator ==rxcpp::coroutine::co_observable_iterator92     bool operator==(co_observable_iterator const &rhs) const
93     {
94         return !!state && !rhs.state && !state->lifetime.is_subscribed();
95     }
96 
operator !=rxcpp::coroutine::co_observable_iterator97     bool operator!=(co_observable_iterator const &rhs) const
98     {
99         return !(*this == rhs);
100     }
101 
operator *rxcpp::coroutine::co_observable_iterator102     value_type const &operator*() const
103     {
104         return *(state->value);
105     }
106 
operator ->rxcpp::coroutine::co_observable_iterator107     value_type const *operator->() const
108     {
109         return std::addressof(operator*());
110     }
111 
112     shared_ptr<co_observable_iterator_state<Source>> state;
113 };
114 
115 template<typename Source>
await_resume()116 co_observable_iterator<Source> co_observable_inc_awaiter<Source>::await_resume() {
117     if (!!state->error) {rethrow_exception(state->error);}
118     return co_observable_iterator<Source>{state};
119 }
120 
121 template<typename Source>
122 struct co_observable_iterator_awaiter
123 {
124     using iterator=co_observable_iterator<Source>;
125     using value_type=typename iterator::value_type;
126 
co_observable_iterator_awaiterrxcpp::coroutine::co_observable_iterator_awaiter127     explicit co_observable_iterator_awaiter(const Source& o) : it(o) {
128     }
129 
await_readyrxcpp::coroutine::co_observable_iterator_awaiter130     bool await_ready() {
131         return false;
132     }
133 
await_suspendrxcpp::coroutine::co_observable_iterator_awaiter134     void await_suspend(coroutine_handle<> handle) {
135         weak_ptr<co_observable_iterator_state<Source>> wst=it.state;
136         it.state->caller = handle;
137         it.state->o |
138             rxo::finally([wst](){
139                 auto st = wst.lock();
140                 if (st && !!st->caller) {
141                     auto caller = st->caller;
142                     st->caller = nullptr;
143                     caller();
144                 }
145             }) |
146             rxo::subscribe<value_type>(
147                 it.state->lifetime,
148                 // next
149                 [wst](const value_type& v){
150                     auto st = wst.lock();
151                     if (!st || !st->caller) {terminate();}
152                     st->value = addressof(v);
153                     auto caller = st->caller;
154                     st->caller = nullptr;
155                     caller();
156                 },
157                 // error
158                 [wst](exception_ptr e){
159                     auto st = wst.lock();
160                     if (!st || !st->caller) {terminate();}
161                     st->error = e;
162                     auto caller = st->caller;
163                     st->caller = nullptr;
164                     caller();
165                 });
166     }
167 
await_resumerxcpp::coroutine::co_observable_iterator_awaiter168     iterator await_resume() {
169         if (!!it.state->error) {rethrow_exception(it.state->error);}
170         return std::move(it);
171     }
172 
173     iterator it;
174 };
175 
176 }
177 }
178 
179 namespace std
180 {
181 
182 template<typename T, typename SourceOperator>
begin(const rxcpp::observable<T,SourceOperator> & o)183 auto begin(const rxcpp::observable<T, SourceOperator>& o)
184     ->      rxcpp::coroutine::co_observable_iterator_awaiter<rxcpp::observable<T, SourceOperator>> {
185     return  rxcpp::coroutine::co_observable_iterator_awaiter<rxcpp::observable<T, SourceOperator>>{o};
186 }
187 
188 template<typename T, typename SourceOperator>
end(const rxcpp::observable<T,SourceOperator> &)189 auto end(const rxcpp::observable<T, SourceOperator>&)
190     ->      rxcpp::coroutine::co_observable_iterator<rxcpp::observable<T, SourceOperator>> {
191     return  rxcpp::coroutine::co_observable_iterator<rxcpp::observable<T, SourceOperator>>{};
192 }
193 
194 }
195 
196 #endif
197 
198 #endif
199