1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implementations of interoperability test methods."""
15
16import enum
17import json
18import os
19import threading
20
21from google import auth as google_auth
22from google.auth import environment_vars as google_auth_environment_vars
23from google.auth.transport import grpc as google_auth_transport_grpc
24from google.auth.transport import requests as google_auth_transport_requests
25import grpc
26from grpc.beta import implementations
27
28from src.proto.grpc.testing import empty_pb2
29from src.proto.grpc.testing import messages_pb2
30from src.proto.grpc.testing import test_pb2_grpc
31
# Metadata keys the interop server echoes back to the client: when a request
# carries either key, _maybe_echo_metadata returns its value unchanged, the
# first as initial metadata and the second as trailing metadata.
# NOTE(review): the "-bin" suffix is gRPC's marker for binary-valued metadata,
# so values under the trailing key are expected to be bytes -- confirm against
# the gRPC metadata specification.
_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"
34
35
def _maybe_echo_metadata(servicer_context):
    """Echoes the interop test metadata keys back to the caller.

    The initial-metadata echo key is sent back as initial metadata and the
    trailing-metadata echo key is set as trailing metadata, each only if the
    client included it in its invocation metadata.

    Args:
      servicer_context: the server-side RPC context of the current call.
    """
    received = dict(servicer_context.invocation_metadata())
    echo_plan = (
        (_INITIAL_METADATA_KEY, servicer_context.send_initial_metadata),
        (_TRAILING_METADATA_KEY, servicer_context.set_trailing_metadata),
    )
    for echo_key, emit in echo_plan:
        if echo_key in received:
            # Each emitter takes a sequence of (key, value) pairs.
            emit(((echo_key, received[echo_key]),))
47
48
def _maybe_echo_status_and_message(request, servicer_context):
    """Applies the status the request asks the server to respond with.

    When ``request`` has its ``response_status`` field set, that status'
    ``code`` and ``message`` become the RPC's status code and details on
    ``servicer_context``; otherwise the context is left untouched.
    """
    if not request.HasField('response_status'):
        return
    requested_status = request.response_status
    servicer_context.set_code(requested_status.code)
    servicer_context.set_details(requested_status.message)
54
55
class TestService(test_pb2_grpc.TestServiceServicer):
    """Server-side implementation of the interop ``TestService`` RPCs.

    Every payload produced here is a body of zero bytes whose length is taken
    from the request. Handlers calling ``_maybe_echo_metadata`` echo the test
    metadata keys; handlers calling ``_maybe_echo_status_and_message`` apply
    whatever status the request asks for.
    """

    @staticmethod
    def _zeroed_payload(payload_type, size):
        # Builds a Payload of the given type holding `size` zero bytes.
        return messages_pb2.Payload(type=payload_type, body=b'\x00' * size)

    def EmptyCall(self, request, context):
        """Echoes test metadata and returns an empty message."""
        _maybe_echo_metadata(context)
        return empty_pb2.Empty()

    def UnaryCall(self, request, context):
        """Returns one COMPRESSABLE payload of ``request.response_size`` bytes."""
        _maybe_echo_metadata(context)
        _maybe_echo_status_and_message(request, context)
        payload = self._zeroed_payload(messages_pb2.COMPRESSABLE,
                                       request.response_size)
        return messages_pb2.SimpleResponse(payload=payload)

    def StreamingOutputCall(self, request, context):
        """Yields one response per entry of ``request.response_parameters``.

        This is a generator, so its body (including the status echo) only runs
        once the response stream is iterated.
        """
        _maybe_echo_status_and_message(request, context)
        for params in request.response_parameters:
            yield messages_pb2.StreamingOutputCallResponse(
                payload=self._zeroed_payload(request.response_type,
                                             params.size))

    def StreamingInputCall(self, request_iterator, context):
        """Returns the total body length of all streamed request payloads."""
        total_size = sum(
            len(req.payload.body)
            for req in request_iterator
            if req.payload is not None and req.payload.body)
        return messages_pb2.StreamingInputCallResponse(
            aggregated_payload_size=total_size)

    def FullDuplexCall(self, request_iterator, context):
        """For each request, yields one response per response parameter.

        NOTE(review): the response payload type is copied from
        ``req.payload.type`` rather than ``req.response_type`` (which
        StreamingOutputCall uses) -- confirm this is intended.
        """
        _maybe_echo_metadata(context)
        for req in request_iterator:
            _maybe_echo_status_and_message(req, context)
            for params in req.response_parameters:
                yield messages_pb2.StreamingOutputCallResponse(
                    payload=self._zeroed_payload(req.payload.type,
                                                 params.size))

    # NOTE(nathaniel): Apparently this is the same as the full-duplex call?
    # NOTE(atash): It isn't even called in the interop spec (Oct 22 2015)...
    def HalfDuplexCall(self, request_iterator, context):
        """Delegates to FullDuplexCall, returning its response generator."""
        return self.FullDuplexCall(request_iterator, context)
100
101
def _expect_status_code(call, expected_code):
    """Raises ValueError unless ``call`` reports ``expected_code``.

    Args:
      call: an object exposing ``code()``, e.g. an RPC future or call iterator.
      expected_code: the status code the call must have finished with.
    """
    actual_code = call.code()
    if actual_code != expected_code:
        raise ValueError('expected code %s, got %s' % (expected_code,
                                                       actual_code))
106
107
def _expect_status_details(call, expected_details):
    """Raises ValueError unless ``call`` reports ``expected_details``.

    Args:
      call: an object exposing ``details()``, e.g. an RPC future or iterator.
      expected_details: the status message the call must have finished with.
    """
    actual_details = call.details()
    if actual_details != expected_details:
        raise ValueError('expected message %s, got %s' % (expected_details,
                                                          actual_details))
112
113
def _validate_status_code_and_details(call, expected_code, expected_details):
    """Checks both the status code and then the status details of ``call``.

    Raises:
      ValueError: from the first of the two checks that fails.
    """
    for check, expected in ((_expect_status_code, expected_code),
                            (_expect_status_details, expected_details)):
        check(call, expected)
117
118
def _validate_payload_type_and_length(response, expected_type, expected_length):
    """Checks the payload type and body length of ``response``.

    Args:
      response: a message with a ``payload`` carrying ``type`` and ``body``.
      expected_type: the payload type value the response must have.
      expected_length: the exact number of bytes expected in the body.

    Raises:
      ValueError: if the type differs or the body length differs.
    """
    actual_type = response.payload.type
    # Payload types are plain enum integers: compare by value, not identity
    # (`is` only happens to work for CPython's cached small ints).
    if actual_type != expected_type:
        raise ValueError('expected payload type %s, got %s' %
                         (expected_type, actual_type))
    actual_length = len(response.payload.body)
    if actual_length != expected_length:
        raise ValueError('expected payload body size %d, got %d' %
                         (expected_length, actual_length))
126
127
def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
                                 call_credentials):
    """Performs the shared "large unary" exchange and returns the response.

    Sends a 271828-byte zero payload asking for a 314159-byte COMPRESSABLE
    response, optionally asking the server to fill in the username and OAuth
    scope, and validates the response payload before returning it.

    Args:
      stub: the TestService client stub.
      fill_username: forwarded as the request's ``fill_username`` flag.
      fill_oauth_scope: forwarded as the request's ``fill_oauth_scope`` flag.
      call_credentials: per-call credentials, or None for none.
    """
    response_size = 314159
    request_payload = messages_pb2.Payload(body=b'\x00' * 271828)
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=response_size,
        payload=request_payload,
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope)
    pending = stub.UnaryCall.future(request, credentials=call_credentials)
    response = pending.result()
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                      response_size)
    return response
142
143
def _empty_unary(stub):
    """Runs the empty_unary case: EmptyCall must return an Empty message.

    Raises:
      TypeError: if the response is not an ``empty_pb2.Empty`` instance.
    """
    response = stub.EmptyCall(empty_pb2.Empty())
    if isinstance(response, empty_pb2.Empty):
        return
    raise TypeError(
        'response is of type "%s", not empty_pb2.Empty!' % type(response))
149
150
def _large_unary(stub):
    """Runs the large_unary case without credentials or fill-in flags."""
    _large_unary_common_behavior(stub,
                                 fill_username=False,
                                 fill_oauth_scope=False,
                                 call_credentials=None)
153
154
def _client_streaming(stub):
    """Runs client_streaming: the server must sum all streamed body sizes.

    Streams four zero-filled payloads and requires the aggregated size in the
    single response to equal their total length (74922 bytes).

    Raises:
      ValueError: if the aggregated size is wrong.
    """
    payload_body_sizes = (27182, 8, 1828, 45904)
    expected_total = sum(payload_body_sizes)  # == 74922

    def request_stream():
        # Built lazily so requests are created as gRPC consumes them.
        for body_size in payload_body_sizes:
            yield messages_pb2.StreamingInputCallRequest(
                payload=messages_pb2.Payload(body=b'\x00' * body_size))

    response = stub.StreamingInputCall(request_stream())
    if response.aggregated_payload_size != expected_total:
        raise ValueError(
            'incorrect size %d!' % response.aggregated_payload_size)
170
171
def _server_streaming(stub):
    """Runs server_streaming: four responses of the requested sizes, in order.

    Requests one COMPRESSABLE response per entry of ``sizes`` and checks that
    the stream delivers exactly that many responses, each with the matching
    body size.

    Raises:
      ValueError: if a response has the wrong type or size, or the stream
        yields a different number of responses than were requested.
    """
    sizes = (
        31415,
        9,
        2653,
        58979,
    )

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=[
            messages_pb2.ResponseParameters(size=size) for size in sizes
        ])
    response_count = 0
    for response in stub.StreamingOutputCall(request):
        # Guard explicitly instead of letting sizes[index] raise IndexError.
        if response_count >= len(sizes):
            raise ValueError('expected %d responses, got more' % len(sizes))
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          sizes[response_count])
        response_count += 1
    # A short (or empty) stream must fail rather than pass silently.
    if response_count != len(sizes):
        raise ValueError('expected %d responses, got %d' %
                         (len(sizes), response_count))
192
193
class _Pipe(object):
    """A blocking, closable FIFO iterator fed by ``add``.

    Iteration blocks while the pipe is empty and still open, yields added
    values in insertion order, and stops once the pipe has been closed and
    drained. It is used as the request iterator for streaming RPCs so a test
    can send requests one at a time. Using it as a context manager closes it on
    exit.
    """

    def __init__(self):
        self._condition = threading.Condition()
        self._values = []
        self._open = True

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Returns the next value, blocking until one is added or the pipe closes.

        Raises:
          StopIteration: once the pipe is closed and every value was consumed.
        """
        with self._condition:
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.pop(0)
            raise StopIteration()

    def add(self, value):
        """Appends ``value`` and wakes one blocked consumer."""
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        """Marks the pipe closed; consumers stop after draining remaining values."""
        with self._condition:
            self._open = False
            # Wake every blocked consumer, not just one: each must observe the
            # closed state, otherwise additional waiters would block forever.
            self._condition.notify_all()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()
231
232
def _ping_pong(stub):
    """Runs ping_pong: four request/response round trips on one stream.

    Each request carries a zero payload and asks for one COMPRESSABLE response
    of the paired size. The next request is only sent after the previous
    response has been received and validated.
    """
    exchanges = (
        # (requested response size, request payload size)
        (31415, 27182),
        (9, 8),
        (2653, 1828),
        (58979, 45904),
    )

    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        for response_size, payload_size in exchanges:
            pipe.add(
                messages_pb2.StreamingOutputCallRequest(
                    response_type=messages_pb2.COMPRESSABLE,
                    response_parameters=(
                        messages_pb2.ResponseParameters(size=response_size),),
                    payload=messages_pb2.Payload(
                        body=b'\x00' * payload_size)))
            _validate_payload_type_and_length(
                next(response_iterator), messages_pb2.COMPRESSABLE,
                response_size)
260
261
def _cancel_after_begin(stub):
    """Runs cancel_after_begin: cancel a client stream before sending anything.

    Raises:
      ValueError: if the future does not report itself cancelled or its status
        code is not CANCELLED.
    """
    with _Pipe() as request_pipe:
        call_future = stub.StreamingInputCall.future(request_pipe)
        call_future.cancel()
        was_cancelled = call_future.cancelled()
        if not was_cancelled:
            raise ValueError('expected cancelled method to return True')
        if call_future.code() is not grpc.StatusCode.CANCELLED:
            raise ValueError('expected status code CANCELLED')
270
271
def _cancel_after_first_response(stub):
    """Runs cancel_after_first_response on a full-duplex call.

    Sends a single ping-pong request, reads its response, cancels the call,
    and requires the next read to fail with CANCELLED.

    Raises:
      ValueError: if the stream keeps producing responses after cancellation.
      grpc.RpcError: if the post-cancel read fails with a code other than
        CANCELLED.
    """
    # Only the first ping-pong exchange is used: one round trip, then cancel.
    first_response_size = 31415
    first_payload_size = 27182

    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        pipe.add(
            messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(
                    messages_pb2.ResponseParameters(size=first_response_size),),
                payload=messages_pb2.Payload(
                    body=b'\x00' * first_payload_size)))
        # Response contents are covered by the ping_pong case; only wait for
        # the first response to arrive here.
        next(response_iterator)
        response_iterator.cancel()

        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.CANCELLED:
                raise
        else:
            raise ValueError('expected call to be cancelled')
308
309
def _timeout_on_sleeping_server(stub):
    """Runs timeout_on_sleeping_server with a 0.001 s call deadline.

    Sends one request on a full-duplex call and requires the first read to
    fail with DEADLINE_EXCEEDED.

    Raises:
      ValueError: if a response arrives instead of a deadline failure.
      grpc.RpcError: if the call fails with a code other than
        DEADLINE_EXCEEDED.
    """
    deadline_seconds = 0.001
    request_payload_size = 27182
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, timeout=deadline_seconds)
        pipe.add(
            messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                payload=messages_pb2.Payload(
                    body=b'\x00' * request_payload_size)))
        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
                raise
        else:
            raise ValueError('expected call to exceed deadline')
326
327
def _empty_stream(stub):
    """Runs empty_stream: closing the request stream yields no responses.

    Raises:
      ValueError: if the full-duplex call produces any response.
    """
    with _Pipe() as request_pipe:
        responses = stub.FullDuplexCall(request_pipe)
        request_pipe.close()
        try:
            next(responses)
        except StopIteration:
            pass
        else:
            raise ValueError('expected exactly 0 responses')
337
338
def _status_code_and_message(stub):
    """Runs status_code_and_message on a unary and a full-duplex call.

    Each request asks the server (via ``response_status``) to finish with
    code 2 and a fixed message. The client must then observe exactly
    ``grpc.StatusCode.UNKNOWN`` and that message.
    """
    details = 'test status message'
    code = 2
    status = grpc.StatusCode.UNKNOWN  # the client-side enum for code 2

    unary_request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=code, message=details))
    unary_future = stub.UnaryCall.future(unary_request)
    _validate_status_code_and_details(unary_future, status, details)

    with _Pipe() as pipe:
        duplex_responses = stub.FullDuplexCall(pipe)
        pipe.add(
            messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(messages_pb2.ResponseParameters(size=1),),
                payload=messages_pb2.Payload(body=b'\x00'),
                response_status=messages_pb2.EchoStatus(
                    code=code, message=details)))
    # Leaving the with block closes the pipe, ending the request stream
    # before the status is checked.
    _validate_status_code_and_details(duplex_responses, status, details)
364
365
def _unimplemented_method(test_service_stub):
    """Runs unimplemented_method: the call must end with UNIMPLEMENTED.

    Raises:
      ValueError: if the status code is anything other than UNIMPLEMENTED.
    """
    unimplemented_call = test_service_stub.UnimplementedCall
    call_future = unimplemented_call.future(empty_pb2.Empty())
    _expect_status_code(call_future, grpc.StatusCode.UNIMPLEMENTED)
370
371
def _unimplemented_service(unimplemented_service_stub):
    """Runs unimplemented_service: the call must end with UNIMPLEMENTED.

    Raises:
      ValueError: if the status code is anything other than UNIMPLEMENTED.
    """
    unimplemented_call = unimplemented_service_stub.UnimplementedCall
    call_future = unimplemented_call.future(empty_pb2.Empty())
    _expect_status_code(call_future, grpc.StatusCode.UNIMPLEMENTED)
376
377
def _custom_metadata(stub):
    """Runs custom_metadata: the server echoes initial and trailing metadata.

    Sends both echo keys on a unary call and on a full-duplex call. Each call
    must return the initial value as initial metadata and the binary value as
    trailing metadata.

    Raises:
      ValueError: if either echoed value is missing or differs.
    """
    initial_metadata_value = "test_initial_metadata_value"
    # The trailing key ends in "-bin", so its value travels as binary metadata
    # and comes back as bytes. A bytes literal compares equal on both Python 2
    # and Python 3, whereas a str would never match on Python 3.
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
                (_TRAILING_METADATA_KEY, trailing_metadata_value))

    def _validate_metadata(response):
        # .get() so a missing key is reported as a ValueError, not a KeyError.
        initial_metadata = dict(response.initial_metadata())
        received_initial = initial_metadata.get(_INITIAL_METADATA_KEY)
        if received_initial != initial_metadata_value:
            raise ValueError('expected initial metadata %s, got %s' %
                             (initial_metadata_value, received_initial))
        trailing_metadata = dict(response.trailing_metadata())
        received_trailing = trailing_metadata.get(_TRAILING_METADATA_KEY)
        if received_trailing != trailing_metadata_value:
            # Report the value actually found in the *trailing* metadata.
            raise ValueError('expected trailing metadata %r, got %r' %
                             (trailing_metadata_value, received_trailing))

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'))
    response_future = stub.UnaryCall.future(request, metadata=metadata)
    _validate_metadata(response_future)

    # Testing with FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),))
        pipe.add(request)  # Sends the request
        next(response_iterator)  # Causes server to send trailing metadata
    # Dropping out of the with block closes the pipe
    _validate_metadata(response_iterator)
414
415
def _compute_engine_creds(stub, args):
    """Runs compute_engine_creds: the reported username must be the account.

    Raises:
      ValueError: if the response username differs from
        ``args.default_service_account``.
    """
    response = _large_unary_common_behavior(stub,
                                            fill_username=True,
                                            fill_oauth_scope=True,
                                            call_credentials=None)
    expected_account = args.default_service_account
    if expected_account == response.username:
        return
    raise ValueError('expected username %s, got %s' %
                     (expected_account, response.username))
421
422
def _oauth2_auth_token(stub, args):
    """Runs oauth2_auth_token: check the username and OAuth scope seen by the server.

    The expected email is the ``client_email`` of the JSON key file named by
    the google-auth credentials environment variable. The scope the server
    reports must occur within ``args.oauth_scope``.

    Raises:
      KeyError: if the credentials environment variable is unset.
      ValueError: if the username or the scope does not match.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Close the key file deterministically instead of leaking the handle.
    with open(json_key_filename, 'rb') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))
    # Substring check, equivalent to the former str.find(...) == -1 test.
    if response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected received oauth scope "{}" to be found in "{}"'.format(
                response.oauth_scope, args.oauth_scope))
434
435
def _jwt_token_creds(stub, args):
    """Runs jwt_token_creds: the server must report the key file's client email.

    The expected email is the ``client_email`` of the JSON key file named by
    the google-auth credentials environment variable. ``args`` is accepted
    because every credential case is dispatched as ``(stub, args)``; it is not
    read here.

    Raises:
      KeyError: if the credentials environment variable is unset.
      ValueError: if the response username differs from the expected email.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Close the key file deterministically instead of leaking the handle.
    with open(json_key_filename, 'rb') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))
443
444
def _per_rpc_creds(stub, args):
    """Runs per_rpc_creds: authenticate a single call with google-auth credentials.

    Application-default credentials scoped to ``args.oauth_scope`` are wrapped
    as gRPC call credentials and attached to one large-unary call. The server
    must report the ``client_email`` of the JSON key file named by the
    google-auth credentials environment variable.

    Raises:
      KeyError: if the credentials environment variable is unset.
      ValueError: if the response username differs from the expected email.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Close the key file deterministically instead of leaking the handle.
    with open(json_key_filename, 'rb') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request()))
    response = _large_unary_common_behavior(stub, True, False, call_credentials)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))
458
459
@enum.unique
class TestCase(enum.Enum):
    """The interop test cases, keyed by their command-line names."""

    EMPTY_UNARY = 'empty_unary'
    LARGE_UNARY = 'large_unary'
    SERVER_STREAMING = 'server_streaming'
    CLIENT_STREAMING = 'client_streaming'
    PING_PONG = 'ping_pong'
    CANCEL_AFTER_BEGIN = 'cancel_after_begin'
    CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
    EMPTY_STREAM = 'empty_stream'
    STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
    UNIMPLEMENTED_METHOD = 'unimplemented_method'
    UNIMPLEMENTED_SERVICE = 'unimplemented_service'
    CUSTOM_METADATA = 'custom_metadata'
    COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
    OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
    JWT_TOKEN_CREDS = 'jwt_token_creds'
    PER_RPC_CREDS = 'per_rpc_creds'
    TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'

    def test_interoperability(self, stub, args):
        """Runs the test case this member names against ``stub``.

        Args:
          stub: the client stub the case is exercised through.
          args: parsed options; only the credential cases receive it.

        Raises:
          NotImplementedError: if no test function is registered for this
            member.
        """
        # Built inside the method: dict attributes in an Enum body would
        # themselves become enum members.
        stub_only_cases = {
            TestCase.EMPTY_UNARY: _empty_unary,
            TestCase.LARGE_UNARY: _large_unary,
            TestCase.SERVER_STREAMING: _server_streaming,
            TestCase.CLIENT_STREAMING: _client_streaming,
            TestCase.PING_PONG: _ping_pong,
            TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin,
            TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response,
            TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server,
            TestCase.EMPTY_STREAM: _empty_stream,
            TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message,
            TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method,
            TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service,
            TestCase.CUSTOM_METADATA: _custom_metadata,
        }
        cases_with_args = {
            TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds,
            TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token,
            TestCase.JWT_TOKEN_CREDS: _jwt_token_creds,
            TestCase.PER_RPC_CREDS: _per_rpc_creds,
        }
        if self in stub_only_cases:
            stub_only_cases[self](stub)
        elif self in cases_with_args:
            cases_with_args[self](stub, args)
        else:
            raise NotImplementedError(
                'Test case "%s" not implemented!' % self.name)
518