1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests server and client side metadata API."""
15
16import unittest
17import weakref
18
19import grpc
20from grpc import _channel
21
22from tests.unit import test_common
23from tests.unit.framework.common import test_constants
24
# Channel options used by the test client. validate_client_metadata checks
# that the user-agent seen by the server starts with the primary agent and
# ends with the secondary agent.
_CHANNEL_ARGS = (
    ('grpc.primary_user_agent', 'primary-agent'),
    ('grpc.secondary_user_agent', 'secondary-agent'),
)

# Fixed payloads exchanged by every RPC in this module.
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'

# Method names routed by _GenericHandler, one per RPC arity.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'

# Metadata supplied by the client with each call. Keys and values mix bytes
# and unicode literals; the "-bin" entry carries a bytes value.
_INVOCATION_METADATA = (
    (b'invocation-md-key', u'invocation-md-value'),
    (u'invocation-md-key-bin', b'\x00\x01'),
)
# Form of the invocation metadata the server-side check compares against.
_EXPECTED_INVOCATION_METADATA = (
    ('invocation-md-key', 'invocation-md-value'),
    ('invocation-md-key-bin', b'\x00\x01'),
)

# Metadata the handlers send as initial metadata, again mixing bytes and
# unicode literals, followed by the form the client-side checks compare
# against.
_INITIAL_METADATA = (
    (b'initial-md-key', u'initial-md-value'),
    (u'initial-md-key-bin', b'\x00\x02'),
)
_EXPECTED_INITIAL_METADATA = (
    ('initial-md-key', 'initial-md-value'),
    ('initial-md-key-bin', b'\x00\x02'),
)

# Metadata the handlers set as trailing metadata. It is already written with
# native string literals, so the expected form is the very same object.
_TRAILING_METADATA = (
    ('server-trailing-md-key', 'server-trailing-md-value'),
    ('server-trailing-md-key-bin', b'\x00\x03'),
)
_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA
81
82
def _user_agent(metadata):
    """Return the value of the first 'user-agent' entry in metadata.

    Args:
      metadata: An iterable of (key, value) pairs.

    Returns:
      The value paired with the first key equal to 'user-agent'.

    Raises:
      KeyError: If no pair has the key 'user-agent'.
    """
    # A private sentinel distinguishes "no match" from any real value.
    not_found = object()
    agent = next(
        (value for name, value in metadata if name == 'user-agent'),
        not_found)
    if agent is not_found:
        raise KeyError('No user agent!')
    return agent
88
89
def validate_client_metadata(test, servicer_context):
    """Assert that an RPC arrived carrying the expected client metadata.

    Args:
      test: Object providing assertTrue; failures are reported through it.
      servicer_context: The RPC's context; its invocation_metadata() is
        inspected.

    Raises:
      KeyError: If the invocation metadata has no 'user-agent' entry.
    """
    received_metadata = servicer_context.invocation_metadata()
    transmitted = test_common.metadata_transmitted(
        _EXPECTED_INVOCATION_METADATA, received_metadata)
    test.assertTrue(transmitted)

    # The user-agent must begin with the primary agent followed by the
    # library's own agent string, and finish with the secondary agent.
    agent = _user_agent(received_metadata)
    expected_prefix = 'primary-agent ' + _channel._USER_AGENT
    test.assertTrue(agent.startswith(expected_prefix))
    test.assertTrue(agent.endswith('secondary-agent'))
99
100
def _exchange_metadata(test, servicer_context):
    """Run the metadata steps shared by all four handlers below.

    Validates the client's invocation metadata, then sends _INITIAL_METADATA
    as initial metadata and sets _TRAILING_METADATA as trailing metadata.

    Args:
      test: Object providing assertTrue, passed to validate_client_metadata.
      servicer_context: The context of the RPC being served.
    """
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)


def handle_unary_unary(test, request, servicer_context):
    """Serve a unary-unary RPC.

    Args:
      test: Object providing assertTrue for the metadata validation.
      request: The request message (unused).
      servicer_context: The context of the RPC being served.

    Returns:
      _RESPONSE.
    """
    _exchange_metadata(test, servicer_context)
    return _RESPONSE


def handle_unary_stream(test, request, servicer_context):
    """Serve a unary-stream RPC.

    This is a generator, so the metadata steps run when iteration starts.

    Args:
      test: Object providing assertTrue for the metadata validation.
      request: The request message (unused).
      servicer_context: The context of the RPC being served.

    Yields:
      _RESPONSE, test_constants.STREAM_LENGTH times.
    """
    _exchange_metadata(test, servicer_context)
    for _ in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE


def handle_stream_unary(test, request_iterator, servicer_context):
    """Serve a stream-unary RPC.

    Args:
      test: Object providing assertTrue for the metadata validation.
      request_iterator: Iterator of request messages; it is drained and the
        messages are discarded.
      servicer_context: The context of the RPC being served.

    Returns:
      _RESPONSE.
    """
    _exchange_metadata(test, servicer_context)
    # TODO(issue:#6891) We should be able to remove this loop
    for _ in request_iterator:
        pass
    return _RESPONSE


def handle_stream_stream(test, request_iterator, servicer_context):
    """Serve a stream-stream RPC.

    This is a generator, so the metadata steps run when iteration starts.

    Args:
      test: Object providing assertTrue for the metadata validation.
      request_iterator: Iterator of request messages; the messages themselves
        are ignored.
      servicer_context: The context of the RPC being served.

    Yields:
      _RESPONSE, once per request received.
    """
    _exchange_metadata(test, servicer_context)
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _ in request_iterator:
        yield _RESPONSE
134
135
class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler exposing exactly one of this module's four behaviors.

    The behavior matching the requested streaming arity is bound to the
    corresponding attribute; the other three attributes stay None.
    """

    def __init__(self, test, request_streaming, response_streaming):
        """Select the behavior for the given arity.

        Args:
          test: Object forwarded as the first argument to the chosen handle_*
            function.
          request_streaming: Whether the request side is a stream.
          response_streaming: Whether the response side is a stream.
        """
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        # No (de)serializers are configured for these handlers.
        self.request_deserializer = None
        self.response_serializer = None

        # Start with every behavior unset, then fill in the single match.
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        if self.request_streaming:
            if self.response_streaming:
                self.stream_stream = (
                    lambda request_iterator, servicer_context:
                    handle_stream_stream(test, request_iterator,
                                         servicer_context))
            else:
                self.stream_unary = (
                    lambda request_iterator, servicer_context:
                    handle_stream_unary(test, request_iterator,
                                        servicer_context))
        elif self.response_streaming:
            self.unary_stream = (
                lambda request, servicer_context: handle_unary_stream(
                    test, request, servicer_context))
        else:
            self.unary_unary = (
                lambda request, servicer_context: handle_unary_unary(
                    test, request, servicer_context))
155
156
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler that routes the four test methods to _MethodHandler."""

    def __init__(self, test):
        """Remember the object that handlers report assertions through.

        Args:
          test: Object passed on to every _MethodHandler created.
        """
        self._test = test

    def service(self, handler_call_details):
        """Return the handler for the called method, or None if unknown.

        Args:
          handler_call_details: Details of the incoming call; only its
            `method` attribute is read.

        Returns:
          A _MethodHandler with the arity of the matching method name, or
          None when the method is not one of the four test methods.
        """
        # (method name, request_streaming, response_streaming)
        routes = (
            (_UNARY_UNARY, False, False),
            (_UNARY_STREAM, False, True),
            (_STREAM_UNARY, True, False),
            (_STREAM_STREAM, True, True),
        )
        for method, request_streaming, response_streaming in routes:
            if handler_call_details.method == method:
                return _MethodHandler(self._test, request_streaming,
                                      response_streaming)
        return None
173
174
class MetadataTest(unittest.TestCase):
    """Checks initial and trailing metadata for all four RPC arities.

    Each test invokes one method with _INVOCATION_METADATA over a channel
    created with _CHANNEL_ARGS. The server-side handlers validate what the
    client sent; the tests validate what the server sent back.
    """

    def setUp(self):
        self._server = test_common.test_server()
        # The handler receives a weak proxy rather than the test case itself,
        # presumably so the server does not keep the test case alive.
        self._server.add_generic_rpc_handlers((_GenericHandler(
            weakref.proxy(self)),))
        port = self._server.add_insecure_port('[::]:0')
        self._server.start()
        self._channel = grpc.insecure_channel(
            'localhost:%d' % port, options=_CHANNEL_ARGS)

    def tearDown(self):
        self._server.stop(0)
        # Close the channel too, so each test releases the client-side
        # resources it created in setUp instead of leaking them.
        self._channel.close()

    def _assert_initial_metadata(self, call):
        """Assert that `call` received the expected initial metadata."""
        self.assertTrue(
            test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA,
                                             call.initial_metadata()))

    def _assert_trailing_metadata(self, call):
        """Assert that `call` received the expected trailing metadata."""
        self.assertTrue(
            test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA,
                                             call.trailing_metadata()))

    def testUnaryUnary(self):
        multi_callable = self._channel.unary_unary(_UNARY_UNARY)
        unused_response, call = multi_callable.with_call(
            _REQUEST, metadata=_INVOCATION_METADATA)
        self._assert_initial_metadata(call)
        self._assert_trailing_metadata(call)

    def testUnaryStream(self):
        multi_callable = self._channel.unary_stream(_UNARY_STREAM)
        call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
        self._assert_initial_metadata(call)
        # Consume every response before checking the trailing metadata.
        for _ in call:
            pass
        self._assert_trailing_metadata(call)

    def testStreamUnary(self):
        multi_callable = self._channel.stream_unary(_STREAM_UNARY)
        unused_response, call = multi_callable.with_call(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA)
        self._assert_initial_metadata(call)
        self._assert_trailing_metadata(call)

    def testStreamStream(self):
        multi_callable = self._channel.stream_stream(_STREAM_STREAM)
        call = multi_callable(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA)
        self._assert_initial_metadata(call)
        # Consume every response before checking the trailing metadata.
        for _ in call:
            pass
        self._assert_trailing_metadata(call)
237
238
# When executed as a script, run this module's tests with verbosity level 2.
if __name__ == '__main__':
    unittest.main(verbosity=2)
241