1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import time
16import threading
17import unittest
18import platform
19
20from grpc._cython import cygrpc
21from tests.unit._cython import test_utilities
22from tests.unit import test_common
23from tests.unit import resources
24
# Host name that SecureServerSecureClient passes as the host override; the
# mixin uses it as the ssl_target_name_override channel argument and as the
# host the server is expected to observe.
_SSL_HOST_OVERRIDE = b'foo.test.google.fr'
# Metadata key/value pair that _metadata_plugin hands to its callback.
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
# Flags value passed to every cygrpc operation constructed in this file.
_EMPTY_FLAGS = 0
29
30
def _metadata_plugin(context, callback):
    """Report one fixed call-credentials metadata pair through `callback`.

    `context` is accepted but not used. `callback` is invoked immediately
    with a one-entry metadata tuple, `cygrpc.StatusCode.ok`, and an empty
    bytes value as its third argument.
    """
    credentials_entry = (
        _CALL_CREDENTIALS_METADATA_KEY,
        _CALL_CREDENTIALS_METADATA_VALUE,
    )
    metadata = (credentials_entry,)
    callback(metadata, cygrpc.StatusCode.ok, b'')
36
37
class TypeSmokeTest(unittest.TestCase):
    """Smoke tests that construct and discard the basic cygrpc objects."""

    def testCompletionQueueUpDown(self):
        # Create a completion queue and drop the only reference to it.
        queue = cygrpc.CompletionQueue()
        del queue

    def testServerUpDown(self):
        # The server arguments are handed over as a set in this test.
        reuseport_disabled = (b'grpc.so_reuseport', 0)
        server = cygrpc.Server({reuseport_disabled})
        del server

    def testChannelUpDown(self):
        # Open a channel with no arguments or credentials, then close it.
        target = b'[::]:0'
        channel = cygrpc.Channel(target, None, None)
        channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!')

    def test_metadata_plugin_call_credentials_up_down(self):
        # Only construction is exercised; the result is discarded.
        plugin_name = b'test plugin name!'
        cygrpc.MetadataPluginCallCredentials(_metadata_plugin, plugin_name)

    def testServerStartNoExplicitShutdown(self):
        # Start a server and drop it without ever calling shutdown().
        server_arguments = [(b'grpc.so_reuseport', 0)]
        server = cygrpc.Server(server_arguments)
        completion_queue = cygrpc.CompletionQueue()
        server.register_completion_queue(completion_queue)
        port = server.add_http2_port(b'[::]:0')
        self.assertIsInstance(port, int)
        server.start()
        del server

    def testServerStartShutdown(self):
        # Start a server, shut it down, and check the shutdown event that
        # arrives on the completion queue.
        completion_queue = cygrpc.CompletionQueue()
        server_arguments = [(b'grpc.so_reuseport', 0)]
        server = cygrpc.Server(server_arguments)
        server.add_http2_port(b'[::]:0')
        server.register_completion_queue(completion_queue)
        server.start()

        shutdown_tag = object()
        server.shutdown(completion_queue, shutdown_tag)
        shutdown_event = completion_queue.poll()
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         shutdown_event.completion_type)
        self.assertIs(shutdown_tag, shutdown_event.tag)

        # Release the server before the queue, as the original test does.
        del server
        del completion_queue
94
95
class ServerClientMixin(object):
    """Shared fixture and tests for a cygrpc server/client pair.

    Concrete test cases in this file mix this class in alongside
    `unittest.TestCase` and call `setUpMixin` / `tearDownMixin` from their
    own `setUp` / `tearDown`.
    """

    def setUpMixin(self, server_credentials, client_credentials, host_override):
        """Start a server and open a client channel to it.

        Args:
          server_credentials: Passed to `add_http2_port` when truthy;
            otherwise the port is added without credentials.
          client_credentials: When truthy, the channel is created with these
            credentials and with `host_override` as its
            `ssl_target_name_override` channel argument; otherwise the
            channel is created with an empty argument set and no
            credentials.
          host_override: When truthy, calls pass None as the host argument
            and the server is expected to observe this value as the call
            host; otherwise a fixed arbitrary host name is both sent and
            expected.
        """
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        self.server.register_completion_queue(self.server_completion_queue)
        if server_credentials:
            self.port = self.server.add_http2_port(b'[::]:0',
                                                   server_credentials)
        else:
            self.port = self.server.add_http2_port(b'[::]:0')
        self.server.start()
        self.client_completion_queue = cygrpc.CompletionQueue()
        # The client connects to whichever port add_http2_port returned.
        target = 'localhost:{}'.format(self.port).encode()
        if client_credentials:
            client_channel_arguments = ((
                cygrpc.ChannelArgKey.ssl_target_name_override,
                host_override,
            ),)
            self.client_channel = cygrpc.Channel(
                target, client_channel_arguments, client_credentials)
        else:
            self.client_channel = cygrpc.Channel(target, set(), None)
        if host_override:
            self.host_argument = None  # default host
            self.expected_host = host_override
        else:
            # arbitrary host name necessitating no further identification
            self.host_argument = b'hostess'
            self.expected_host = self.host_argument

    def tearDownMixin(self):
        """Close the client channel and drop the fixture attributes."""
        self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!')
        del self.client_channel
        del self.server
        del self.client_completion_queue
        del self.server_completion_queue

    def _perform_queue_operations(self, operations, call, queue, deadline,
                                  description):
        """Perform the operations with given call, queue, and deadline.

        Invocation errors are reported as an exception with `description`
        in the message. Performs the operations asynchronously, returning a
        future.

        Args:
          operations: The batch of operations to start on `call`.
          call: The call on which the batch is started.
          queue: The completion queue polled for the batch's event.
          deadline: Deadline passed to `queue.poll`.
          description: Text included in the message of any raised exception.

        Returns:
          A `test_utilities.SimpleFuture` wrapping the work; the wrapped
          function returns the polled event.
        """

        def performer():
            tag = object()
            try:
                call_result = call.start_client_batch(operations, tag)
                self.assertEqual(cygrpc.CallError.ok, call_result)
                event = queue.poll(deadline=deadline)
                self.assertEqual(cygrpc.CompletionType.operation_complete,
                                 event.completion_type)
                self.assertTrue(event.success)
                self.assertIs(tag, event.tag)
            except Exception as error:
                # Exceptions have no `message` attribute on Python 3, so
                # reading it here would raise AttributeError and hide the
                # real failure. Formatting the exception object itself
                # yields its text on both Python 2 and Python 3.
                raise Exception("Error in '{}': {}".format(
                    description, error))
            return event

        return test_utilities.SimpleFuture(performer)

    def test_echo(self):
        # One unary-style exchange: the client sends all six of its
        # operations in a single batch, the server answers with a single
        # five-operation batch, and both sides' results are checked.
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        CLIENT_METADATA_ASCII_KEY = 'key'
        CLIENT_METADATA_ASCII_VALUE = 'val'
        CLIENT_METADATA_BIN_KEY = 'key-bin'
        CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
        SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
        SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
        SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
        SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
        SERVER_STATUS_CODE = cygrpc.StatusCode.ok
        SERVER_STATUS_DETAILS = 'our work is never over'
        REQUEST = b'in death a member of project mayhem has a name'
        RESPONSE = b'his name is robert paulson'
        METHOD = b'twinkies'

        # Ask the server for the next incoming call before the client starts.
        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue, self.server_completion_queue,
            server_request_tag)

        self.assertEqual(cygrpc.CallError.ok, request_call_result)

        client_call_tag = object()
        client_initial_metadata = (
            (
                CLIENT_METADATA_ASCII_KEY,
                CLIENT_METADATA_ASCII_VALUE,
            ),
            (
                CLIENT_METADATA_BIN_KEY,
                CLIENT_METADATA_BIN_VALUE,
            ),
        )
        client_call = self.client_channel.integrated_call(
            0, METHOD, self.host_argument, DEADLINE, client_initial_metadata,
            None, [
                (
                    [
                        cygrpc.SendInitialMetadataOperation(
                            client_initial_metadata, _EMPTY_FLAGS),
                        cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
                        cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ],
                    client_call_tag,
                ),
            ])
        # Wait for the client's event in a future so the server side can be
        # driven from this thread in the meantime.
        client_event_future = test_utilities.SimpleFuture(
            self.client_channel.next_call_event)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         request_event.completion_type)
        self.assertIsInstance(request_event.call, cygrpc.Call)
        self.assertIs(server_request_tag, request_event.tag)
        self.assertTrue(
            test_common.metadata_transmitted(client_initial_metadata,
                                             request_event.invocation_metadata))
        self.assertEqual(METHOD, request_event.call_details.method)
        self.assertEqual(self.expected_host, request_event.call_details.host)
        self.assertLess(
            abs(DEADLINE - request_event.call_details.deadline),
            DEADLINE_TOLERANCE)

        server_call_tag = object()
        server_call = request_event.call
        server_initial_metadata = ((
            SERVER_INITIAL_METADATA_KEY,
            SERVER_INITIAL_METADATA_VALUE,
        ),)
        server_trailing_metadata = ((
            SERVER_TRAILING_METADATA_KEY,
            SERVER_TRAILING_METADATA_VALUE,
        ),)
        server_start_batch_result = server_call.start_server_batch([
            cygrpc.SendInitialMetadataOperation(server_initial_metadata,
                                                _EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                server_trailing_metadata, SERVER_STATUS_CODE,
                SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
        ], server_call_tag)
        self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

        server_event = self.server_completion_queue.poll(deadline=DEADLINE)
        client_event = client_event_future.result()

        # Client side: six results, one per operation type.
        self.assertEqual(6, len(client_event.batch_operations))
        found_client_op_types = set()
        for client_result in client_event.batch_operations:
            # we expect each op type to be unique
            self.assertNotIn(client_result.type(), found_client_op_types)
            found_client_op_types.add(client_result.type())
            if client_result.type(
            ) == cygrpc.OperationType.receive_initial_metadata:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_initial_metadata,
                        client_result.initial_metadata()))
            elif client_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(RESPONSE, client_result.message())
            elif client_result.type(
            ) == cygrpc.OperationType.receive_status_on_client:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_trailing_metadata,
                        client_result.trailing_metadata()))
                self.assertEqual(SERVER_STATUS_DETAILS, client_result.details())
                self.assertEqual(SERVER_STATUS_CODE, client_result.code())
        self.assertEqual(
            set([
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.send_close_from_client,
                cygrpc.OperationType.receive_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.receive_status_on_client
            ]), found_client_op_types)

        # Server side: five results, one per operation type.
        self.assertEqual(5, len(server_event.batch_operations))
        found_server_op_types = set()
        for server_result in server_event.batch_operations:
            self.assertNotIn(server_result.type(), found_server_op_types)
            found_server_op_types.add(server_result.type())
            if server_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(REQUEST, server_result.message())
            elif server_result.type(
            ) == cygrpc.OperationType.receive_close_on_server:
                self.assertFalse(server_result.cancelled())
        self.assertEqual(
            set([
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.receive_close_on_server,
                cygrpc.OperationType.send_status_from_server
            ]), found_server_op_types)

        del client_call
        del server_call

    def test_6522(self):
        # Drives a call in three phases (prologue, ten message round trips,
        # epilogue), forcing both sides' futures to complete after each one.
        DEADLINE = time.time() + 5
        METHOD = b'twinkies'

        empty_metadata = ()

        # Prologue
        server_request_tag = object()
        self.server.request_call(self.server_completion_queue,
                                 self.server_completion_queue,
                                 server_request_tag)
        client_call = self.client_channel.segregated_call(
            0, METHOD, self.host_argument, DEADLINE, None, None, ([(
                [
                    cygrpc.SendInitialMetadataOperation(empty_metadata,
                                                        _EMPTY_FLAGS),
                    cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                ],
                object(),
            ), (
                [
                    cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                ],
                object(),
            )]))

        client_initial_metadata_event_future = test_utilities.SimpleFuture(
            client_call.next_event)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        server_call = request_event.call

        def perform_server_operations(operations, description):
            # Run a batch on the server call against the server's queue.
            return self._perform_queue_operations(operations, server_call,
                                                  self.server_completion_queue,
                                                  DEADLINE, description)

        server_event_future = perform_server_operations([
            cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
        ], "Server prologue")

        client_initial_metadata_event_future.result()  # force completion
        server_event_future.result()

        # Messaging
        for _ in range(10):
            client_call.operate([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Client message")
            client_message_event_future = test_utilities.SimpleFuture(
                client_call.next_event)
            server_event_future = perform_server_operations([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Server receive")

            client_message_event_future.result()  # force completion
            server_event_future.result()

        # Epilogue
        client_call.operate([
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
        ], "Client epilogue")
        # One for ReceiveStatusOnClient, one for SendCloseFromClient.
        client_events_future = test_utilities.SimpleFuture(
            lambda: {
                client_call.next_event(),
                client_call.next_event(),})

        server_event_future = perform_server_operations([
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
        ], "Server epilogue")

        client_events_future.result()  # force completion
        server_event_future.result()
392
393
class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests with no credentials on either side."""

    def setUp(self):
        # No server credentials, no client credentials, no host override.
        self.setUpMixin(
            server_credentials=None,
            client_credentials=None,
            host_override=None)

    def tearDown(self):
        self.tearDownMixin()
401
402
class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests with SSL credentials on both sides."""

    def setUp(self):
        # Server side: a single key/certificate pair from the test resources.
        key_cert_pair = cygrpc.SslPemKeyCertPair(
            resources.private_key(), resources.certificate_chain())
        server_credentials = cygrpc.server_credentials_ssl(
            None, [key_cert_pair], False)

        # Client side: the test root certificates and nothing else.
        root_certificates = resources.test_root_certificates()
        client_credentials = cygrpc.SSLChannelCredentials(
            root_certificates, None, None)

        self.setUpMixin(server_credentials, client_credentials,
                        _SSL_HOST_OVERRIDE)

    def tearDown(self):
        self.tearDownMixin()
418
419
# When executed as a script, run this module's tests with verbose output.
if __name__ == '__main__':
    unittest.main(verbosity=2)
422