1# Copyright 2018 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implementations of fork support test methods."""
15
16import enum
17import json
18import logging
19import multiprocessing
20import os
21import threading
22import time
23
24import grpc
25
26from six.moves import queue
27
28from src.proto.grpc.testing import empty_pb2
29from src.proto.grpc.testing import messages_pb2
30from src.proto.grpc.testing import test_pb2_grpc
31
32_LOGGER = logging.getLogger(__name__)
33
34
def _channel(args):
    """Creates a channel to the server described by ``args``.

    Args:
      args: Object exposing ``server_host``, ``server_port`` and ``use_tls``.

    Returns:
      The result of ``grpc.secure_channel`` (using default SSL channel
      credentials) when ``args.use_tls`` is truthy, otherwise the result of
      ``grpc.insecure_channel``.
    """
    target = '{}:{}'.format(args.server_host, args.server_port)
    if not args.use_tls:
        return grpc.insecure_channel(target)
    return grpc.secure_channel(target, grpc.ssl_channel_credentials())
43
44
def _validate_payload_type_and_length(response, expected_type, expected_length):
    """Checks the payload type and body length of a response.

    Args:
      response: Message exposing ``payload.type`` and ``payload.body``.
      expected_type: Payload type value the response must carry.
      expected_length: Required length of ``response.payload.body``.

    Raises:
      ValueError: If the payload type or the body length does not match.
    """
    payload = response.payload
    # Compare by value: an identity check (`is`) on equal numbers only holds
    # when the interpreter happens to reuse the same object.
    if payload.type != expected_type:
        # Report the value that was received, not its Python class.
        raise ValueError('expected payload type %s, got %s' %
                         (expected_type, payload.type))
    actual_length = len(payload.body)
    if actual_length != expected_length:
        raise ValueError('expected payload body size %d, got %d' %
                         (expected_length, actual_length))
52
53
def _async_unary(stub):
    """Issues one UnaryCall through the future API and validates the reply.

    Args:
      stub: Object whose ``UnaryCall.future(request)`` returns a future.

    Raises:
      ValueError: Propagated from the payload validation on a mismatch.
    """
    response_size = 314159
    payload = messages_pb2.Payload(body=b'\x00' * 271828)
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=response_size,
        payload=payload)
    response = stub.UnaryCall.future(request).result()
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                      response_size)
63
64
def _blocking_unary(stub):
    """Issues one blocking UnaryCall and validates the reply.

    Args:
      stub: Object whose ``UnaryCall(request)`` returns the response.

    Raises:
      ValueError: Propagated from the payload validation on a mismatch.
    """
    response_size = 314159
    payload = messages_pb2.Payload(body=b'\x00' * 271828)
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=response_size,
        payload=payload)
    _validate_payload_type_and_length(
        stub.UnaryCall(request), messages_pb2.COMPRESSABLE, response_size)
73
74
class _Pipe(object):
    """Blocking, closable FIFO that can be iterated.

    Values passed to ``add`` are handed out in insertion order. Iteration
    blocks while the pipe is open and empty, and stops once the pipe has been
    closed and every queued value has been consumed. Usable as a context
    manager: leaving the ``with`` block closes the pipe.
    """

    def __init__(self):
        # Guards _values and _open, and wakes consumers blocked in next().
        self._condition = threading.Condition()
        self._values = []
        self._open = True

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Returns the oldest queued value, waiting for one if necessary.

        Raises:
          StopIteration: If the pipe is closed and no values remain.
        """
        with self._condition:
            while not self._values and self._open:
                self._condition.wait()
            if not self._values:
                raise StopIteration()
            return self._values.pop(0)

    def add(self, value):
        """Queues ``value`` and wakes one waiting consumer."""
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        """Marks the pipe closed and wakes every waiting consumer."""
        with self._condition:
            self._open = False
            # Closing concerns all blocked consumers: waking only one would
            # leave the others waiting forever.
            self._condition.notify_all()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()
112
113
class _ChildProcess(object):
    """Runs a task in a separate process and reports its failures.

    Args:
      task: Callable executed in the child process.
      args: Optional sequence of positional arguments for ``task``.
    """

    def __init__(self, task, args=None):
        self._task = task
        self._args = () if args is None else args
        self._exceptions = multiprocessing.Queue()
        self._process = multiprocessing.Process(target=self._record_exceptions)

    def _record_exceptions(self):
        """Runs the task, forwarding a description of any failure."""
        try:
            self._task(*self._args)
        except Exception as e:  # pylint: disable=broad-except
            # Items put on a multiprocessing queue are pickled in a background
            # thread; an exception object that cannot be pickled would be
            # dropped there while the child still exits with code 0. A plain
            # string always survives the trip to the parent.
            self._exceptions.put('%s: %s' % (type(e).__name__, e))

    def start(self):
        """Starts the child process."""
        self._process.start()

    def finish(self):
        """Waits for the child process and surfaces its failure, if any.

        Raises:
          ValueError: If the child exited with a non-zero exit code or
            reported an exception raised by the task.
        """
        self._process.join()
        exitcode = self._process.exitcode
        if exitcode != 0:
            raise ValueError('Child process failed with exitcode %d' % exitcode)
        try:
            failure = self._exceptions.get(block=False)
        except queue.Empty:
            return
        raise ValueError('Child process failed: %s' % failure)
142
143
def _async_unary_same_channel(channel):
    """Checks that a forked child cannot issue a future RPC on this channel.

    The parent performs an RPC before the fork and another one while the child
    runs. The child passes only if its own RPC on the inherited stub raises
    ValueError.
    """
    stub = test_pb2_grpc.TestServiceStub(channel)

    def child_target():
        try:
            _async_unary(stub)
        except ValueError:
            # Expected: the inherited channel is unusable in the child.
            return
        raise Exception(
            'Child should not be able to re-use channel after fork')

    _async_unary(stub)
    forked = _ChildProcess(child_target)
    forked.start()
    _async_unary(stub)
    forked.finish()
160
161
def _async_unary_new_channel(channel, args):
    """Checks that a forked child can issue a future RPC on a new channel.

    The parent performs an RPC on ``channel`` before the fork and another one
    while the child runs; the child builds its own channel from ``args``.
    """
    stub = test_pb2_grpc.TestServiceStub(channel)

    def child_target():
        fresh_channel = _channel(args)
        _async_unary(test_pb2_grpc.TestServiceStub(fresh_channel))
        fresh_channel.close()

    _async_unary(stub)
    forked = _ChildProcess(child_target)
    forked.start()
    _async_unary(stub)
    forked.finish()
176
177
def _blocking_unary_same_channel(channel):
    """Checks that a forked child cannot issue a blocking RPC on this channel.

    The parent performs one RPC before the fork. The child passes only if its
    own RPC on the inherited stub raises ValueError.
    """
    stub = test_pb2_grpc.TestServiceStub(channel)

    def child_target():
        try:
            _blocking_unary(stub)
        except ValueError:
            # Expected: the inherited channel is unusable in the child.
            return
        raise Exception(
            'Child should not be able to re-use channel after fork')

    _blocking_unary(stub)
    forked = _ChildProcess(child_target)
    forked.start()
    forked.finish()
193
194
def _blocking_unary_new_channel(channel, args):
    """Checks that a forked child can issue a blocking RPC on a new channel.

    The parent performs an RPC on ``channel`` before the fork and another one
    while the child runs; the child builds its own channel from ``args``.
    """
    stub = test_pb2_grpc.TestServiceStub(channel)

    def child_target():
        fresh_channel = _channel(args)
        _blocking_unary(test_pb2_grpc.TestServiceStub(fresh_channel))
        fresh_channel.close()

    _blocking_unary(stub)
    forked = _ChildProcess(child_target)
    forked.start()
    _blocking_unary(stub)
    forked.finish()
209
210
211# Verify that the fork channel registry can handle already closed channels
def _close_channel_before_fork(channel, args):
    """Forks after ``channel`` has been used and closed.

    The parent closes ``channel``, opens a second channel from ``args`` and
    forks while that one is in use. The child closes its inherited copy of the
    second channel and performs an RPC on a channel of its own.

    Args:
      channel: Channel used for one RPC and then closed before the fork.
      args: Connection settings passed to ``_channel``.
    """

    def child_target():
        new_channel.close()
        child_channel = _channel(args)
        child_stub = test_pb2_grpc.TestServiceStub(child_channel)
        _blocking_unary(child_stub)
        child_channel.close()

    stub = test_pb2_grpc.TestServiceStub(channel)
    _blocking_unary(stub)
    channel.close()

    new_channel = _channel(args)
    try:
        new_stub = test_pb2_grpc.TestServiceStub(new_channel)
        child_process = _ChildProcess(child_target)
        child_process.start()
        _blocking_unary(new_stub)
        child_process.finish()
    finally:
        # The caller only owns `channel`; the channel opened here has to be
        # released by this function, on success and on failure alike.
        new_channel.close()
231
232
def _connectivity_watch(channel, args):
    """Checks connectivity subscriptions in the parent and in a forked child.

    Both processes subscribe a callback to a channel, perform an RPC and
    require that at least two states were observed, the last one being READY.
    The child uses its own channel and additionally fails when its copy of the
    parent's state list holds more than one entry.

    Args:
      channel: Parent channel to watch.
      args: Connection settings used by the child to open its own channel.

    Raises:
      ValueError: If the parent channel did not reach READY, or if the child
        process reported a failure.
    """

    def child_target():

        def child_connectivity_callback(state):
            child_states.append(state)

        child_states = []
        child_channel = _channel(args)
        child_stub = test_pb2_grpc.TestServiceStub(child_channel)
        child_channel.subscribe(child_connectivity_callback)
        _async_unary(child_stub)
        if len(child_states
              ) < 2 or child_states[-1] != grpc.ChannelConnectivity.READY:
            raise ValueError('Channel did not move to READY')
        if len(parent_states) > 1:
            raise ValueError('Received connectivity updates on parent callback')
        child_channel.unsubscribe(child_connectivity_callback)
        child_channel.close()

    def parent_connectivity_callback(state):
        parent_states.append(state)

    parent_states = []
    channel.subscribe(parent_connectivity_callback)
    try:
        stub = test_pb2_grpc.TestServiceStub(channel)
        child_process = _ChildProcess(child_target)
        child_process.start()
        _async_unary(stub)
        if len(parent_states
              ) < 2 or parent_states[-1] != grpc.ChannelConnectivity.READY:
            raise ValueError('Channel did not move to READY')
    finally:
        # Need to unsubscribe or _channel.py in _poll_connectivity triggers a
        # "Cannot invoke RPC on closed channel!" error.
        # TODO(ericgribkoff) Fix issue with channel.close() and connectivity
        # polling
        # Done exactly once, and on the failure path too.
        channel.unsubscribe(parent_connectivity_callback)
    child_process.finish()
272
273
def _ping_pong_with_child_processes_after_first_response(
        channel, args, child_target, run_after_close=True):
    """Runs a bidi ping-pong, forking children while the call is in flight.

    One child is forked after every received response and, from the second
    round on, one more between sending a request and reading its response.

    Args:
      channel: Channel on which the FullDuplexCall is started.
      args: Forwarded unchanged to every child as its third argument.
      child_target: Callable run in each child with
        ``(parent_bidi_call, channel, args)``.
      run_after_close: Whether to fork one more child after the request
        stream has been closed.

    Raises:
      ValueError: If a response fails validation or a child reports a failure.
    """
    # (response size, request payload size) of each round trip.
    exchanges = (
        (31415, 27182),
        (9, 8),
        (2653, 1828),
        (58979, 45904),
    )
    stub = test_pb2_grpc.TestServiceStub(channel)
    pipe = _Pipe()
    parent_bidi_call = stub.FullDuplexCall(pipe)
    started_children = []

    def fork_child():
        child = _ChildProcess(child_target, (parent_bidi_call, channel, args))
        child.start()
        started_children.append(child)

    for round_index, (response_size, payload_size) in enumerate(exchanges):
        pipe.add(
            messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(
                    messages_pb2.ResponseParameters(size=response_size),),
                payload=messages_pb2.Payload(body=b'\x00' * payload_size)))
        if round_index > 0:
            # A response has already been received: fork while the next
            # request is outstanding.
            fork_child()
        response = next(parent_bidi_call)
        fork_child()
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          response_size)
    pipe.close()
    if run_after_close:
        fork_child()
    for child in started_children:
        child.finish()
322
323
def _in_progress_bidi_continue_call(channel):
    """Forks during a bidi call and inspects the call object in the child.

    Each child requires that an RPC on the inherited channel raises ValueError
    and that the inherited call reports CANCELLED with the details
    'Channel closed due to fork'.
    """

    def child_target(parent_bidi_call, parent_channel, args):
        stub = test_pb2_grpc.TestServiceStub(parent_channel)
        try:
            _async_unary(stub)
        except ValueError:
            # Expected: the inherited channel is unusable in the child.
            pass
        else:
            raise Exception(
                'Child should not be able to re-use channel after fork')
        code = parent_bidi_call.code()
        details = parent_bidi_call.details()
        if code != grpc.StatusCode.CANCELLED:
            raise ValueError('Expected inherited code CANCELLED, got %s' % code)
        if details != 'Channel closed due to fork':
            raise ValueError(
                'Expected inherited details Channel closed due to fork, got %s'
                % details)

    # Don't run child_target after closing the parent call, as the call may
    # have received a status from the server before fork occurs.
    _ping_pong_with_child_processes_after_first_response(
        channel, None, child_target, run_after_close=False)
348
349
def _in_progress_bidi_same_channel_async_call(channel):
    """Forks during a bidi call; children try a future RPC on the same channel.

    Each child passes only if the RPC on the inherited channel raises
    ValueError.
    """

    def child_target(parent_bidi_call, parent_channel, args):
        inherited_stub = test_pb2_grpc.TestServiceStub(parent_channel)
        try:
            _async_unary(inherited_stub)
        except ValueError:
            # Expected: the inherited channel is unusable in the child.
            return
        raise Exception(
            'Child should not be able to re-use channel after fork')

    _ping_pong_with_child_processes_after_first_response(
        channel, None, child_target)
363
364
def _in_progress_bidi_same_channel_blocking_call(channel):
    """Forks during a bidi call; children try a blocking RPC on the same channel.

    Each child passes only if the RPC on the inherited channel raises
    ValueError.
    """

    def child_target(parent_bidi_call, parent_channel, args):
        inherited_stub = test_pb2_grpc.TestServiceStub(parent_channel)
        try:
            _blocking_unary(inherited_stub)
        except ValueError:
            # Expected: the inherited channel is unusable in the child.
            return
        raise Exception(
            'Child should not be able to re-use channel after fork')

    _ping_pong_with_child_processes_after_first_response(
        channel, None, child_target)
378
379
def _in_progress_bidi_new_channel_async_call(channel, args):
    """Forks during a bidi call; children issue a future RPC on a new channel.

    Args:
      channel: Parent channel carrying the in-progress bidi call.
      args: Connection settings each child uses to open its own channel.
    """

    def child_target(parent_bidi_call, parent_channel, args):
        child_channel = _channel(args)
        child_stub = test_pb2_grpc.TestServiceStub(child_channel)
        _async_unary(child_stub)
        # Close the child's channel, as the other new-channel tests do.
        child_channel.close()

    _ping_pong_with_child_processes_after_first_response(
        channel, args, child_target)
389
390
def _in_progress_bidi_new_channel_blocking_call(channel, args):
    """Forks during a bidi call; children issue a blocking RPC on a new channel.

    Args:
      channel: Parent channel carrying the in-progress bidi call.
      args: Connection settings each child uses to open its own channel.
    """

    def child_target(parent_bidi_call, parent_channel, args):
        child_channel = _channel(args)
        child_stub = test_pb2_grpc.TestServiceStub(child_channel)
        _blocking_unary(child_stub)
        # Close the child's channel, as the other new-channel tests do.
        child_channel.close()

    _ping_pong_with_child_processes_after_first_response(
        channel, args, child_target)
400
401
@enum.unique
class TestCase(enum.Enum):
    """Fork support test cases; each value is the case's string identifier."""

    CONNECTIVITY_WATCH = 'connectivity_watch'
    CLOSE_CHANNEL_BEFORE_FORK = 'close_channel_before_fork'
    ASYNC_UNARY_SAME_CHANNEL = 'async_unary_same_channel'
    ASYNC_UNARY_NEW_CHANNEL = 'async_unary_new_channel'
    BLOCKING_UNARY_SAME_CHANNEL = 'blocking_unary_same_channel'
    BLOCKING_UNARY_NEW_CHANNEL = 'blocking_unary_new_channel'
    IN_PROGRESS_BIDI_CONTINUE_CALL = 'in_progress_bidi_continue_call'
    IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL = 'in_progress_bidi_same_channel_async_call'
    IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_same_channel_blocking_call'
    IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL = 'in_progress_bidi_new_channel_async_call'
    IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_new_channel_blocking_call'

    def run_test(self, args):
        """Runs this test case against a channel built from ``args``.

        The channel is closed when the test finishes, whether it passed or
        raised.

        Args:
          args: Connection settings passed to ``_channel`` and, for the cases
            that open further channels, to the test implementation.

        Raises:
          NotImplementedError: If no implementation is registered for this
            case.
        """
        _LOGGER.info("Running %s", self)
        # Each entry is (implementation, whether it also takes ``args``).
        # Built here rather than in the class body, where a plain attribute
        # would be turned into an enum member.
        implementations = {
            TestCase.ASYNC_UNARY_SAME_CHANNEL:
                (_async_unary_same_channel, False),
            TestCase.ASYNC_UNARY_NEW_CHANNEL:
                (_async_unary_new_channel, True),
            TestCase.BLOCKING_UNARY_SAME_CHANNEL:
                (_blocking_unary_same_channel, False),
            TestCase.BLOCKING_UNARY_NEW_CHANNEL:
                (_blocking_unary_new_channel, True),
            TestCase.CLOSE_CHANNEL_BEFORE_FORK:
                (_close_channel_before_fork, True),
            TestCase.CONNECTIVITY_WATCH:
                (_connectivity_watch, True),
            TestCase.IN_PROGRESS_BIDI_CONTINUE_CALL:
                (_in_progress_bidi_continue_call, False),
            TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL:
                (_in_progress_bidi_same_channel_async_call, False),
            TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL:
                (_in_progress_bidi_same_channel_blocking_call, False),
            TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL:
                (_in_progress_bidi_new_channel_async_call, True),
            TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL:
                (_in_progress_bidi_new_channel_blocking_call, True),
        }
        channel = _channel(args)
        try:
            entry = implementations.get(self)
            if entry is None:
                raise NotImplementedError(
                    'Test case "%s" not implemented!' % self.name)
            test_method, takes_args = entry
            if takes_args:
                test_method(channel, args)
            else:
                test_method(channel)
        finally:
            # Release the channel even when the test raised.
            channel.close()
446