1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import collections
16import contextlib
17import distutils.spawn
18import errno
19import os
20import shutil
21import subprocess
22import sys
23import tempfile
24import threading
25import unittest
26
27from six import moves
28
29import grpc
30from tests.unit import test_common
31from tests.unit.framework.common import test_constants
32
33import tests.protoc_plugin.protos.payload.test_payload_pb2 as payload_pb2
34import tests.protoc_plugin.protos.requests.r.test_requests_pb2 as request_pb2
35import tests.protoc_plugin.protos.responses.test_responses_pb2 as response_pb2
36import tests.protoc_plugin.protos.service.test_service_pb2_grpc as service_pb2_grpc
37
# Identifiers of entities we expect to find in the generated module.
# They are resolved with getattr() on service_pb2_grpc by the service
# factories and tests in this file rather than referenced as attributes
# directly.
STUB_IDENTIFIER = 'TestServiceStub'
SERVICER_IDENTIFIER = 'TestServiceServicer'
ADD_SERVICER_TO_SERVER_IDENTIFIER = 'add_TestServiceServicer_to_server'
42
43
44class _ServicerMethods(object):
45
46    def __init__(self):
47        self._condition = threading.Condition()
48        self._paused = False
49        self._fail = False
50
51    @contextlib.contextmanager
52    def pause(self):  # pylint: disable=invalid-name
53        with self._condition:
54            self._paused = True
55        yield
56        with self._condition:
57            self._paused = False
58            self._condition.notify_all()
59
60    @contextlib.contextmanager
61    def fail(self):  # pylint: disable=invalid-name
62        with self._condition:
63            self._fail = True
64        yield
65        with self._condition:
66            self._fail = False
67
68    def _control(self):  # pylint: disable=invalid-name
69        with self._condition:
70            if self._fail:
71                raise ValueError()
72            while self._paused:
73                self._condition.wait()
74
75    def UnaryCall(self, request, unused_rpc_context):
76        response = response_pb2.SimpleResponse()
77        response.payload.payload_type = payload_pb2.COMPRESSABLE
78        response.payload.payload_compressable = 'a' * request.response_size
79        self._control()
80        return response
81
82    def StreamingOutputCall(self, request, unused_rpc_context):
83        for parameter in request.response_parameters:
84            response = response_pb2.StreamingOutputCallResponse()
85            response.payload.payload_type = payload_pb2.COMPRESSABLE
86            response.payload.payload_compressable = 'a' * parameter.size
87            self._control()
88            yield response
89
90    def StreamingInputCall(self, request_iter, unused_rpc_context):
91        response = response_pb2.StreamingInputCallResponse()
92        aggregated_payload_size = 0
93        for request in request_iter:
94            aggregated_payload_size += len(request.payload.payload_compressable)
95        response.aggregated_payload_size = aggregated_payload_size
96        self._control()
97        return response
98
99    def FullDuplexCall(self, request_iter, unused_rpc_context):
100        for request in request_iter:
101            for parameter in request.response_parameters:
102                response = response_pb2.StreamingOutputCallResponse()
103                response.payload.payload_type = payload_pb2.COMPRESSABLE
104                response.payload.payload_compressable = 'a' * parameter.size
105                self._control()
106                yield response
107
108    def HalfDuplexCall(self, request_iter, unused_rpc_context):
109        responses = []
110        for request in request_iter:
111            for parameter in request.response_parameters:
112                response = response_pb2.StreamingOutputCallResponse()
113                response.payload.payload_type = payload_pb2.COMPRESSABLE
114                response.payload.payload_compressable = 'a' * parameter.size
115                self._control()
116                responses.append(response)
117        for response in responses:
118            yield response
119
120
121class _Service(
122        collections.namedtuple('_Service', (
123            'servicer_methods',
124            'server',
125            'stub',
126        ))):
127    """A live and running service.
128
129  Attributes:
130    servicer_methods: The _ServicerMethods servicing RPCs.
131    server: The grpc.Server servicing RPCs.
132    stub: A stub on which to invoke RPCs.
133  """
134
135
def _CreateService():
    """Starts a server backed by a fresh _ServicerMethods and returns a stub.

    Each RPC handler of the registered servicer forwards to the same
    _ServicerMethods instance. Tests can therefore pause or fail in-flight
    RPCs through it.

    Returns:
      A _Service with which to test RPCs.
    """
    methods = _ServicerMethods()
    servicer_base = getattr(service_pb2_grpc, SERVICER_IDENTIFIER)

    class Servicer(servicer_base):
        # Thin delegation layer: the real logic lives in _ServicerMethods.

        def UnaryCall(self, request, context):
            return methods.UnaryCall(request, context)

        def StreamingOutputCall(self, request, context):
            return methods.StreamingOutputCall(request, context)

        def StreamingInputCall(self, request_iter, context):
            return methods.StreamingInputCall(request_iter, context)

        def FullDuplexCall(self, request_iter, context):
            return methods.FullDuplexCall(request_iter, context)

        def HalfDuplexCall(self, request_iter, context):
            return methods.HalfDuplexCall(request_iter, context)

    server = test_common.test_server()
    register = getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER)
    register(Servicer(), server)
    bound_port = server.add_insecure_port('[::]:0')
    server.start()
    channel = grpc.insecure_channel('localhost:{}'.format(bound_port))
    stub_class = getattr(service_pb2_grpc, STUB_IDENTIFIER)
    return _Service(methods, server, stub_class(channel))
169
170
def _CreateIncompleteService():
    """Starts a server whose servicer overrides none of the RPC methods.

    The registered servicer is a bare subclass of the generated base servicer.
    Every RPC therefore reaches the base class's default handlers.

    Returns:
      A _Service with which to test RPCs. Its servicer_methods is None, and
      its stub is connected to the running server.
    """
    servicer_base = getattr(service_pb2_grpc, SERVICER_IDENTIFIER)

    class Servicer(servicer_base):
        pass

    server = test_common.test_server()
    register = getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER)
    register(Servicer(), server)
    bound_port = server.add_insecure_port('[::]:0')
    server.start()
    channel = grpc.insecure_channel('localhost:{}'.format(bound_port))
    stub_class = getattr(service_pb2_grpc, STUB_IDENTIFIER)
    return _Service(None, server, stub_class(channel))
190
191
def _streaming_input_request_iterator():
    """Yields three StreamingInputCallRequests, each with payload 'a'."""
    remaining = 3
    while remaining > 0:
        remaining -= 1
        req = request_pb2.StreamingInputCallRequest()
        req.payload.payload_type = payload_pb2.COMPRESSABLE
        req.payload.payload_compressable = 'a'
        yield req
198
199
def _streaming_output_request():
    """Builds a StreamingOutputCallRequest asking for responses of sizes 1-3."""
    req = request_pb2.StreamingOutputCallRequest()
    for response_size in (1, 2, 3):
        req.response_parameters.add(size=response_size, interval_us=0)
    return req
207
208
def _full_duplex_request_iterator():
    """Yields two requests: one asking for size 1, then one for sizes 2 and 3."""
    for sizes_per_request in ((1,), (2, 3)):
        req = request_pb2.StreamingOutputCallRequest()
        for response_size in sizes_per_request:
            req.response_parameters.add(size=response_size, interval_us=0)
        yield req
217
218
class PythonPluginTest(unittest.TestCase):
    """Test case for the gRPC Python protoc-plugin.

    While reading these tests, remember that the futures API
    (`stub.method.future()`) only gives futures for the *response-unary*
    methods and does not exist for response-streaming methods.

    Tests obtain services through _create_service(), which registers the
    server's shutdown with addCleanup(). The server is therefore stopped
    even when an assertion fails part-way through a test.
    """

    def _create_service(self, create=None):
        """Creates a service and schedules its server to be stopped.

        Args:
          create: A zero-argument callable returning a _Service. Defaults to
            _CreateService.

        Returns:
          The _Service returned by ``create``.
        """
        if create is None:
            create = _CreateService
        service = create()
        # A trailing server.stop() call would be skipped if an assertion
        # failed earlier in the test, leaking the running server.
        self.addCleanup(service.server.stop, None)
        return service

    def testImportAttributes(self):
        # check that we can access the generated module and its members.
        self.assertIsNotNone(getattr(service_pb2_grpc, STUB_IDENTIFIER, None))
        self.assertIsNotNone(
            getattr(service_pb2_grpc, SERVICER_IDENTIFIER, None))
        self.assertIsNotNone(
            getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER, None))

    def testUpDown(self):
        service = _CreateService()
        try:
            self.assertIsNotNone(service.servicer_methods)
            self.assertIsNotNone(service.server)
            self.assertIsNotNone(service.stub)
        finally:
            service.server.stop(None)

    def testIncompleteServicer(self):
        service = self._create_service(_CreateIncompleteService)
        request = request_pb2.SimpleRequest(response_size=13)
        with self.assertRaises(grpc.RpcError) as exception_context:
            service.stub.UnaryCall(request)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.UNIMPLEMENTED)

    def testUnaryCall(self):
        service = self._create_service()
        request = request_pb2.SimpleRequest(response_size=13)
        response = service.stub.UnaryCall(request)
        expected_response = service.servicer_methods.UnaryCall(
            request, 'not a real context!')
        self.assertEqual(expected_response, response)

    def testUnaryCallFuture(self):
        service = self._create_service()
        request = request_pb2.SimpleRequest(response_size=13)
        # Check that the call does not block waiting for the server to respond.
        with service.servicer_methods.pause():
            response_future = service.stub.UnaryCall.future(request)
        response = response_future.result()
        expected_response = service.servicer_methods.UnaryCall(
            request, 'not a real RpcContext!')
        self.assertEqual(expected_response, response)

    def testUnaryCallFutureExpired(self):
        service = self._create_service()
        request = request_pb2.SimpleRequest(response_size=13)
        with service.servicer_methods.pause():
            response_future = service.stub.UnaryCall.future(
                request, timeout=test_constants.SHORT_TIMEOUT)
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_future.result()
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)
        self.assertIs(response_future.code(), grpc.StatusCode.DEADLINE_EXCEEDED)

    def testUnaryCallFutureCancelled(self):
        service = self._create_service()
        request = request_pb2.SimpleRequest(response_size=13)
        with service.servicer_methods.pause():
            response_future = service.stub.UnaryCall.future(request)
            response_future.cancel()
        self.assertTrue(response_future.cancelled())
        self.assertIs(response_future.code(), grpc.StatusCode.CANCELLED)

    def testUnaryCallFutureFailed(self):
        service = self._create_service()
        request = request_pb2.SimpleRequest(response_size=13)
        with service.servicer_methods.fail():
            response_future = service.stub.UnaryCall.future(request)
            self.assertIsNotNone(response_future.exception())
        self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)

    def testStreamingOutputCall(self):
        service = self._create_service()
        request = _streaming_output_request()
        responses = service.stub.StreamingOutputCall(request)
        expected_responses = service.servicer_methods.StreamingOutputCall(
            request, 'not a real RpcContext!')
        for expected_response, response in moves.zip_longest(
                expected_responses, responses):
            self.assertEqual(expected_response, response)

    def testStreamingOutputCallExpired(self):
        service = self._create_service()
        request = _streaming_output_request()
        with service.servicer_methods.pause():
            responses = service.stub.StreamingOutputCall(
                request, timeout=test_constants.SHORT_TIMEOUT)
            with self.assertRaises(grpc.RpcError) as exception_context:
                list(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)

    def testStreamingOutputCallCancelled(self):
        service = self._create_service()
        request = _streaming_output_request()
        responses = service.stub.StreamingOutputCall(request)
        next(responses)
        responses.cancel()
        with self.assertRaises(grpc.RpcError) as exception_context:
            next(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.CANCELLED)
        self.assertIs(responses.code(), grpc.StatusCode.CANCELLED)

    def testStreamingOutputCallFailed(self):
        service = self._create_service()
        request = _streaming_output_request()
        with service.servicer_methods.fail():
            responses = service.stub.StreamingOutputCall(request)
            self.assertIsNotNone(responses)
            with self.assertRaises(grpc.RpcError) as exception_context:
                next(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.UNKNOWN)

    def testStreamingInputCall(self):
        service = self._create_service()
        response = service.stub.StreamingInputCall(
            _streaming_input_request_iterator())
        expected_response = service.servicer_methods.StreamingInputCall(
            _streaming_input_request_iterator(), 'not a real RpcContext!')
        self.assertEqual(expected_response, response)

    def testStreamingInputCallFuture(self):
        service = self._create_service()
        with service.servicer_methods.pause():
            response_future = service.stub.StreamingInputCall.future(
                _streaming_input_request_iterator())
        response = response_future.result()
        expected_response = service.servicer_methods.StreamingInputCall(
            _streaming_input_request_iterator(), 'not a real RpcContext!')
        self.assertEqual(expected_response, response)

    def testStreamingInputCallFutureExpired(self):
        service = self._create_service()
        with service.servicer_methods.pause():
            response_future = service.stub.StreamingInputCall.future(
                _streaming_input_request_iterator(),
                timeout=test_constants.SHORT_TIMEOUT)
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_future.result()
        self.assertIsInstance(response_future.exception(), grpc.RpcError)
        self.assertIs(response_future.exception().code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)

    def testStreamingInputCallFutureCancelled(self):
        service = self._create_service()
        with service.servicer_methods.pause():
            response_future = service.stub.StreamingInputCall.future(
                _streaming_input_request_iterator())
            response_future.cancel()
        self.assertTrue(response_future.cancelled())
        with self.assertRaises(grpc.FutureCancelledError):
            response_future.result()

    def testStreamingInputCallFutureFailed(self):
        service = self._create_service()
        with service.servicer_methods.fail():
            response_future = service.stub.StreamingInputCall.future(
                _streaming_input_request_iterator())
            self.assertIsNotNone(response_future.exception())
            self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)

    def testFullDuplexCall(self):
        service = self._create_service()
        responses = service.stub.FullDuplexCall(_full_duplex_request_iterator())
        expected_responses = service.servicer_methods.FullDuplexCall(
            _full_duplex_request_iterator(), 'not a real RpcContext!')
        for expected_response, response in moves.zip_longest(
                expected_responses, responses):
            self.assertEqual(expected_response, response)

    def testFullDuplexCallExpired(self):
        request_iterator = _full_duplex_request_iterator()
        service = self._create_service()
        with service.servicer_methods.pause():
            responses = service.stub.FullDuplexCall(
                request_iterator, timeout=test_constants.SHORT_TIMEOUT)
            with self.assertRaises(grpc.RpcError) as exception_context:
                list(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)

    def testFullDuplexCallCancelled(self):
        service = self._create_service()
        request_iterator = _full_duplex_request_iterator()
        responses = service.stub.FullDuplexCall(request_iterator)
        next(responses)
        responses.cancel()
        with self.assertRaises(grpc.RpcError) as exception_context:
            next(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.CANCELLED)

    def testFullDuplexCallFailed(self):
        request_iterator = _full_duplex_request_iterator()
        service = self._create_service()
        with service.servicer_methods.fail():
            responses = service.stub.FullDuplexCall(request_iterator)
            with self.assertRaises(grpc.RpcError) as exception_context:
                next(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.UNKNOWN)

    def testHalfDuplexCall(self):
        service = self._create_service()
        # The half-duplex request stream has exactly the same shape as the
        # full-duplex one (size 1, then sizes 2 and 3), so reuse its helper.
        responses = service.stub.HalfDuplexCall(
            _full_duplex_request_iterator())
        expected_responses = service.servicer_methods.HalfDuplexCall(
            _full_duplex_request_iterator(), 'not a real RpcContext!')
        for expected_response, response in moves.zip_longest(
                expected_responses, responses):
            self.assertEqual(expected_response, response)

    def testHalfDuplexCallWedged(self):
        condition = threading.Condition()
        # A one-element list stands in for Python 3's 'nonlocal', which
        # Python 2 lacks.
        wait_cell = [False]

        @contextlib.contextmanager
        def wait():  # pylint: disable=invalid-name
            with condition:
                wait_cell[0] = True
            try:
                yield
            finally:
                # Release the request iterator even if the body raised, so
                # the thread consuming it is not left blocked forever.
                with condition:
                    wait_cell[0] = False
                    condition.notify_all()

        def half_duplex_request_iterator():
            request = request_pb2.StreamingOutputCallRequest()
            request.response_parameters.add(size=1, interval_us=0)
            yield request
            # Hold the request stream open until wait() exits.
            with condition:
                while wait_cell[0]:
                    condition.wait()

        service = self._create_service()
        with wait():
            responses = service.stub.HalfDuplexCall(
                half_duplex_request_iterator(),
                timeout=test_constants.SHORT_TIMEOUT)
            # half-duplex waits for the client to send all info
            with self.assertRaises(grpc.RpcError) as exception_context:
                next(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)
503
504
if __name__ == '__main__':
    # Run this module's tests directly, reporting each test by name.
    unittest.main(verbosity=2)
507