1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test making many calls and immediately cancelling most of them."""
15
16import threading
17import unittest
18
19from grpc._cython import cygrpc
20from grpc.framework.foundation import logging_pool
21from tests.unit.framework.common import test_constants
22from tests.unit._cython import test_utilities
23
# Neutral values passed wherever an operation takes flags or metadata.
_EMPTY_FLAGS = 0
_EMPTY_METADATA = ()

# Tags identifying the server-side completion-queue events of this test.
_SERVER_SHUTDOWN_TAG = 'server_shutdown'
_REQUEST_CALL_TAG = 'request_call'
_RECEIVE_CLOSE_ON_SERVER_TAG = 'receive_close_on_server'
_RECEIVE_MESSAGE_TAG = 'receive_message'
_SERVER_COMPLETE_CALL_TAG = 'server_complete_call'

# The client waits for this share of the RPC_CONCURRENCY calls before it
# cancels every call; the remaining _UNSUCCESSFUL_CALLS events are drained
# after the cancellation.
_SUCCESS_CALL_FRACTION = 1.0 / 8.0
_SUCCESSFUL_CALLS = int(_SUCCESS_CALL_FRACTION * test_constants.RPC_CONCURRENCY)
_UNSUCCESSFUL_CALLS = test_constants.RPC_CONCURRENCY - _SUCCESSFUL_CALLS
36
37
38class _State(object):
39
40    def __init__(self):
41        self.condition = threading.Condition()
42        self.handlers_released = False
43        self.parked_handlers = 0
44        self.handled_rpcs = 0
45
46
def _is_cancellation_event(event):
    """Report whether ``event`` is a cancelled receive-close-on-server batch.

    Args:
      event: A completion-queue event whose ``tag`` and first batch
        operation are inspected.

    Returns:
      False if the event's tag is not _RECEIVE_CLOSE_ON_SERVER_TAG;
      otherwise the result of that operation's ``cancelled()``.
    """
    # Compare tags by value: an identity check ('is') only succeeds if the
    # very same string object survives the round trip through the queue.
    return (event.tag == _RECEIVE_CLOSE_ON_SERVER_TAG and
            event.batch_operations[0].cancelled())
50
51
class _Handler(object):
    """Serves a single server-side RPC from the server thread pool.

    The handler first parks on the shared _State barrier. Once released, it
    starts receive batches. It then either drains a cancelled call, or
    answers the call with a fixed message and an OK status.
    """

    def __init__(self, state, completion_queue, rpc_event):
        self._state = state
        self._lock = threading.Lock()
        self._completion_queue = completion_queue
        self._call = rpc_event.call

    def _park(self):
        """Block until the test thread sets ``handlers_released``."""
        shared = self._state
        with shared.condition:
            shared.parked_handlers += 1
            if shared.parked_handlers == test_constants.THREAD_CONCURRENCY:
                # The handler that saturates the pool wakes the test thread.
                shared.condition.notify_all()
            while not shared.handlers_released:
                shared.condition.wait()

    def _respond(self):
        """Start a batch sending metadata, one message and an OK status."""
        reply = (
            cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA, _EMPTY_FLAGS),
            cygrpc.SendMessageOperation(b'\x79\x57', _EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
                _EMPTY_FLAGS),
        )
        with self._lock:
            self._call.start_server_batch(reply, _SERVER_COMPLETE_CALL_TAG)

    def __call__(self):
        self._park()

        with self._lock:
            self._call.start_server_batch(
                (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
                _RECEIVE_CLOSE_ON_SERVER_TAG)
            self._call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _RECEIVE_MESSAGE_TAG)

        poll = self._completion_queue.poll
        if _is_cancellation_event(poll()):
            # Only the receive-message batch is still outstanding.
            poll()
            return

        self._respond()
        # Collect the other receive batch and the reply batch's completion.
        poll()
        poll()
92
93
def _serve(state, server, server_completion_queue, thread_pool):
    """Accept RPC_CONCURRENCY calls and hand each to ``thread_pool``.

    Each requested call gets its own completion queue and a _Handler.
    ``state.handled_rpcs`` counts the dispatched calls. Waiters on
    ``state.condition`` are notified once every call has been dispatched.
    """
    expected = test_constants.RPC_CONCURRENCY
    for _ in range(expected):
        per_call_queue = cygrpc.CompletionQueue()
        server.request_call(per_call_queue, server_completion_queue,
                            _REQUEST_CALL_TAG)
        # The next server-queue event carries the requested call.
        handler = _Handler(state, per_call_queue, server_completion_queue.poll())
        thread_pool.submit(handler)
        with state.condition:
            state.handled_rpcs += 1
            if state.handled_rpcs >= expected:
                state.condition.notify_all()
    # Consume one further server-queue event. This is presumably the
    # notification for the test's server.shutdown(..., _SERVER_SHUTDOWN_TAG).
    server_completion_queue.poll()
106
107
108class _QueueDriver(object):
109
110    def __init__(self, condition, completion_queue, due):
111        self._condition = condition
112        self._completion_queue = completion_queue
113        self._due = due
114        self._events = []
115        self._returned = False
116
117    def start(self):
118
119        def in_thread():
120            while True:
121                event = self._completion_queue.poll()
122                with self._condition:
123                    self._events.append(event)
124                    self._due.remove(event.tag)
125                    self._condition.notify_all()
126                    if not self._due:
127                        self._returned = True
128                        return
129
130        thread = threading.Thread(target=in_thread)
131        thread.start()
132
133    def events(self, at_least):
134        with self._condition:
135            while len(self._events) < at_least:
136                self._condition.wait()
137            return tuple(self._events)
138
139
class CancelManyCallsTest(unittest.TestCase):
    """Starts many calls, waits for a few call events, then cancels them all."""

    def testCancelManyCalls(self):
        """Drive RPC_CONCURRENCY calls through a server whose handlers park.

        Server handlers stay parked until the pool is saturated and every
        call has been accepted. The client then waits for _SUCCESSFUL_CALLS
        call events and cancels every call. Finally it drains the remaining
        _UNSUCCESSFUL_CALLS events and closes the channel and the server.
        """
        server_thread_pool = logging_pool.pool(
            test_constants.THREAD_CONCURRENCY)

        server_completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        server.register_completion_queue(server_completion_queue)
        # Port 0 lets the server pick a port; the result is used to dial it.
        port = server.add_http2_port(b'[::]:0')
        server.start()
        channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None,
                                 None)

        state = _State()

        # _serve accepts RPC_CONCURRENCY calls and submits a _Handler for each.
        server_thread_args = (
            state,
            server,
            server_completion_queue,
            server_thread_pool,
        )
        server_thread = threading.Thread(target=_serve, args=server_thread_args)
        server_thread.start()

        client_condition = threading.Condition()
        # NOTE(review): client_due is populated below but never read here.
        client_due = set()

        # Start every client call as a single batch with a unique tag.
        with client_condition:
            client_calls = []
            for index in range(test_constants.RPC_CONCURRENCY):
                tag = 'client_complete_call_{0:04d}_tag'.format(index)
                client_call = channel.integrated_call(
                    _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA,
                    None, ((
                        (
                            cygrpc.SendInitialMetadataOperation(
                                _EMPTY_METADATA, _EMPTY_FLAGS),
                            cygrpc.SendMessageOperation(b'\x45\x56',
                                                        _EMPTY_FLAGS),
                            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                            cygrpc.ReceiveInitialMetadataOperation(
                                _EMPTY_FLAGS),
                            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                        ),
                        tag,
                    ),))
                client_due.add(tag)
                client_calls.append(client_call)

        # Collect _SUCCESSFUL_CALLS call events in the background while the
        # handlers are released below.
        client_events_future = test_utilities.SimpleFuture(
            lambda: tuple(channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS)))

        # Release the parked handlers only once THREAD_CONCURRENCY of them are
        # parked and _serve has dispatched all RPC_CONCURRENCY calls.
        with state.condition:
            while True:
                if state.parked_handlers < test_constants.THREAD_CONCURRENCY:
                    state.condition.wait()
                elif state.handled_rpcs < test_constants.RPC_CONCURRENCY:
                    state.condition.wait()
                else:
                    state.handlers_released = True
                    state.condition.notify_all()
                    break

        client_events_future.result()
        # Cancel every call, including any that may already have finished.
        with client_condition:
            for client_call in client_calls:
                client_call.cancel(cygrpc.StatusCode.cancelled, 'Cancelled!')
        # Drain the remaining call events, one per call not counted above.
        for _ in range(_UNSUCCESSFUL_CALLS):
            channel.next_call_event()

        channel.close(cygrpc.StatusCode.unknown, 'Cancelled on channel close!')
        # NOTE(review): server_thread is never joined and server_thread_pool is
        # never shut down; confirm that leaving them to exit on their own is
        # intended.
        with state.condition:
            server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG)
220
221
if __name__ == '__main__':
    # Runs this module's tests, with verbose output, only when executed directly.
    unittest.main(verbosity=2)
224