1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Translates gRPC's server-side API into gRPC's server-side Beta API."""
15
16import collections
17import threading
18
19import grpc
20from grpc import _common
21from grpc.beta import _metadata
22from grpc.beta import interfaces
23from grpc.framework.common import cardinality
24from grpc.framework.common import style
25from grpc.framework.foundation import abandonment
26from grpc.framework.foundation import logging_pool
27from grpc.framework.foundation import stream
28from grpc.framework.interfaces.face import face
29
# pylint: disable=too-many-return-statements

# Worker-thread count for the pool that server() creates when the caller
# passes neither a thread_pool nor a thread_pool_size.
_DEFAULT_POOL_SIZE = 8
33
34
class _ServerProtocolContext(interfaces.GRPCServicerContext):
    """Beta GRPCServicerContext view over a grpc.ServicerContext."""

    def __init__(self, servicer_context):
        self._servicer_context = servicer_context

    def peer(self):
        """Returns whatever the wrapped servicer context reports as its peer."""
        wrapped = self._servicer_context
        return wrapped.peer()

    def disable_next_response_compression(self):
        """Does nothing; per-response compression control is not implemented."""
        # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
45
46
class _FaceServicerContext(face.ServicerContext):
    """Exposes a grpc.ServicerContext through the Beta face.ServicerContext API.

    Calls are forwarded to the wrapped context. Metadata passes through
    _metadata.beta (outgoing to the Beta caller) and _metadata.unbeta
    (incoming from the Beta caller) on the way.
    """

    def __init__(self, servicer_context):
        self._servicer_context = servicer_context

    def is_active(self):
        """Forwards to the wrapped context's is_active()."""
        wrapped = self._servicer_context
        return wrapped.is_active()

    def time_remaining(self):
        """Forwards to the wrapped context's time_remaining()."""
        wrapped = self._servicer_context
        return wrapped.time_remaining()

    def add_abortion_callback(self, abortion_callback):
        """Always raises NotImplementedError; unsupported on the server side."""
        raise NotImplementedError(
            'add_abortion_callback no longer supported server-side!')

    def cancel(self):
        """Forwards to the wrapped context's cancel()."""
        wrapped = self._servicer_context
        wrapped.cancel()

    def protocol_context(self):
        """Returns a new _ServerProtocolContext over the wrapped context."""
        wrapped = self._servicer_context
        return _ServerProtocolContext(wrapped)

    def invocation_metadata(self):
        """Returns the wrapped context's invocation metadata in Beta form."""
        raw_metadata = self._servicer_context.invocation_metadata()
        return _metadata.beta(raw_metadata)

    def initial_metadata(self, initial_metadata):
        """Converts Beta initial metadata and sends it via the wrapped context."""
        converted = _metadata.unbeta(initial_metadata)
        self._servicer_context.send_initial_metadata(converted)

    def terminal_metadata(self, terminal_metadata):
        """Converts Beta terminal metadata and sets it on the wrapped context."""
        converted = _metadata.unbeta(terminal_metadata)
        self._servicer_context.set_terminal_metadata(converted)

    def code(self, code):
        """Forwards the status code to the wrapped context's set_code()."""
        wrapped = self._servicer_context
        wrapped.set_code(code)

    def details(self, details):
        """Forwards the status details to the wrapped context's set_details()."""
        wrapped = self._servicer_context
        wrapped.set_details(details)
84
85
def _adapt_unary_request_inline(unary_request_inline):
    """Wraps a Beta inline unary-request behavior for the gRPC server.

    The returned callable accepts (request, servicer_context). It wraps the
    gRPC servicer context in a _FaceServicerContext and returns whatever the
    Beta behavior returns.
    """

    def adapted_behavior(request, servicer_context):
        face_context = _FaceServicerContext(servicer_context)
        return unary_request_inline(request, face_context)

    return adapted_behavior
93
94
def _adapt_stream_request_inline(stream_request_inline):
    """Wraps a Beta inline stream-request behavior for the gRPC server.

    The returned callable accepts (request_iterator, servicer_context). It
    wraps the gRPC servicer context in a _FaceServicerContext and returns
    whatever the Beta behavior returns.
    """

    def adapted_behavior(request_iterator, servicer_context):
        face_context = _FaceServicerContext(servicer_context)
        return stream_request_inline(request_iterator, face_context)

    return adapted_behavior
102
103
class _Callback(stream.Consumer):
    """Thread-safe buffer that an event-style Beta servicer feeds values into.

    Producers call consume(), terminate(), consume_and_terminate() or
    cancel(). A serving thread blocks in draw_one_value() or
    draw_all_values() until a value, termination or cancellation is
    available. All state is guarded by a single condition variable.
    """

    def __init__(self):
        self._condition = threading.Condition()
        # A deque makes removal from the front (draw_one_value) O(1);
        # list.pop(0) would cost O(n) per draw on a long queued stream.
        self._values = collections.deque()
        self._terminated = False
        self._cancelled = False

    def consume(self, value):
        """Queues one value and wakes any waiting drawer."""
        with self._condition:
            self._values.append(value)
            self._condition.notify_all()

    def terminate(self):
        """Marks the end of the value stream and wakes any waiting drawer."""
        with self._condition:
            self._terminated = True
            self._condition.notify_all()

    def consume_and_terminate(self, value):
        """Queues a final value and marks the end of the stream atomically."""
        with self._condition:
            self._values.append(value)
            self._terminated = True
            self._condition.notify_all()

    def cancel(self):
        """Cancels the buffer; pending and later draws raise Abandoned."""
        with self._condition:
            self._cancelled = True
            self._condition.notify_all()

    def draw_one_value(self):
        """Blocks until a value, termination or cancellation is available.

        Returns:
          The oldest queued value, or None once the stream has terminated and
          every queued value has been drawn.

        Raises:
          abandonment.Abandoned: If cancel() has been called. This applies
            even when values remain queued.
        """
        with self._condition:
            while True:
                if self._cancelled:
                    raise abandonment.Abandoned()
                elif self._values:
                    return self._values.popleft()
                elif self._terminated:
                    return None
                else:
                    self._condition.wait()

    def draw_all_values(self):
        """Blocks until termination or cancellation and returns all values.

        Returns:
          A tuple of every queued value, in the order it was consumed.

        Raises:
          abandonment.Abandoned: If cancel() has been called.
        """
        with self._condition:
            while True:
                if self._cancelled:
                    raise abandonment.Abandoned()
                elif self._terminated:
                    all_values = tuple(self._values)
                    # The buffer is discarded; it is not meant to be used
                    # after this final draw.
                    self._values = None
                    return all_values
                else:
                    self._condition.wait()
156
157
158def _run_request_pipe_thread(request_iterator, request_consumer,
159                             servicer_context):
160    thread_joined = threading.Event()
161
162    def pipe_requests():
163        for request in request_iterator:
164            if not servicer_context.is_active() or thread_joined.is_set():
165                return
166            request_consumer.consume(request)
167            if not servicer_context.is_active() or thread_joined.is_set():
168                return
169        request_consumer.terminate()
170
171    request_pipe_thread = threading.Thread(target=pipe_requests)
172    request_pipe_thread.daemon = True
173    request_pipe_thread.start()
174
175
def _adapt_unary_unary_event(unary_unary_event):
    """Adapts a Beta event-style unary-unary behavior to a blocking callable.

    The returned callable registers a _Callback's cancel() with
    servicer_context.add_callback(). It raises abandonment.Abandoned if that
    registration is refused. Otherwise it invokes the Beta behavior with a
    response callback and blocks until that callback delivers the response.
    """

    def adapted_behavior(request, servicer_context):
        response_buffer = _Callback()
        registered = servicer_context.add_callback(response_buffer.cancel)
        if not registered:
            raise abandonment.Abandoned()
        face_context = _FaceServicerContext(servicer_context)
        unary_unary_event(request, response_buffer.consume_and_terminate,
                          face_context)
        responses = response_buffer.draw_all_values()
        return responses[0]

    return adapted_behavior
187
188
def _adapt_unary_stream_event(unary_stream_event):
    """Adapts a Beta event-style unary-stream behavior to a response generator.

    The returned generator function registers a _Callback's cancel() with
    servicer_context.add_callback(). It raises abandonment.Abandoned if that
    registration is refused. It then hands the _Callback to the Beta behavior
    as its response consumer and yields each value the behavior produces
    until the stream terminates.
    """

    def adapted_behavior(request, servicer_context):
        response_buffer = _Callback()
        registered = servicer_context.add_callback(response_buffer.cancel)
        if not registered:
            raise abandonment.Abandoned()
        face_context = _FaceServicerContext(servicer_context)
        unary_stream_event(request, response_buffer, face_context)
        # draw_one_value() returns None once the stream has ended.
        response = response_buffer.draw_one_value()
        while response is not None:
            yield response
            response = response_buffer.draw_one_value()

    return adapted_behavior
205
206
def _adapt_stream_unary_event(stream_unary_event):
    """Adapts a Beta event-style stream-unary behavior to a blocking callable.

    The returned callable:
      * registers a _Callback's cancel() with servicer_context.add_callback(),
        raising abandonment.Abandoned if that registration is refused;
      * obtains a request consumer from the Beta behavior;
      * pipes incoming requests into that consumer on a background thread;
      * blocks until the behavior delivers its response.
    """

    def adapted_behavior(request_iterator, servicer_context):
        response_buffer = _Callback()
        registered = servicer_context.add_callback(response_buffer.cancel)
        if not registered:
            raise abandonment.Abandoned()
        face_context = _FaceServicerContext(servicer_context)
        request_consumer = stream_unary_event(
            response_buffer.consume_and_terminate, face_context)
        _run_request_pipe_thread(request_iterator, request_consumer,
                                 servicer_context)
        responses = response_buffer.draw_all_values()
        return responses[0]

    return adapted_behavior
221
222
def _adapt_stream_stream_event(stream_stream_event):
    """Adapts a Beta event-style stream-stream behavior to a response generator.

    The returned generator function:
      * registers a _Callback's cancel() with servicer_context.add_callback(),
        raising abandonment.Abandoned if that registration is refused;
      * obtains a request consumer from the Beta behavior;
      * pipes incoming requests into that consumer on a background thread;
      * yields each response value until the response stream terminates.
    """

    def adapted_behavior(request_iterator, servicer_context):
        response_buffer = _Callback()
        registered = servicer_context.add_callback(response_buffer.cancel)
        if not registered:
            raise abandonment.Abandoned()
        face_context = _FaceServicerContext(servicer_context)
        request_consumer = stream_stream_event(response_buffer, face_context)
        _run_request_pipe_thread(request_iterator, request_consumer,
                                 servicer_context)
        # draw_one_value() returns None once the stream has ended.
        response = response_buffer.draw_one_value()
        while response is not None:
            yield response
            response = response_buffer.draw_one_value()

    return adapted_behavior
241
242
class _SimpleMethodHandler(
        collections.namedtuple(
            '_MethodHandler',
            'request_streaming response_streaming request_deserializer '
            'response_serializer unary_unary unary_stream stream_unary '
            'stream_stream'), grpc.RpcMethodHandler):
    """A grpc.RpcMethodHandler whose attributes are plain tuple fields.

    _simple_method_handler fills exactly one of the four behavior fields
    (unary_unary, unary_stream, stream_unary, stream_stream) and leaves the
    others None.
    """
255
256
def _simple_method_handler(implementation, request_deserializer,
                           response_serializer):
    """Builds the grpc.RpcMethodHandler for one Beta-API method implementation.

    Args:
      implementation: A Beta-API method implementation. Its ``style``
        (INLINE or EVENT) and ``cardinality`` select which behavior attribute
        is adapted, e.g. ``unary_unary_inline`` or ``stream_stream_event``.
      request_deserializer: The method's request deserializer, or None.
      response_serializer: The method's response serializer, or None.

    Returns:
      A _SimpleMethodHandler. Its streaming flags match the implementation's
      cardinality, its matching behavior slot holds the adapted behavior, and
      the other three behavior slots are None.

    Raises:
      ValueError: If the style/cardinality combination is not recognized.
    """

    def handler(request_streaming, response_streaming, behavior):
        # Fill the one behavior slot named by the streaming shape (for
        # example 'stream_unary'), so the three None placeholders cannot be
        # placed in the wrong positions.
        behaviors = dict.fromkeys(
            ('unary_unary', 'unary_stream', 'stream_unary', 'stream_stream'))
        slot = '%s_%s' % ('stream' if request_streaming else 'unary',
                          'stream' if response_streaming else 'unary')
        behaviors[slot] = behavior
        return _SimpleMethodHandler(request_streaming, response_streaming,
                                    request_deserializer, response_serializer,
                                    **behaviors)

    cardinalities = cardinality.Cardinality
    service_style = implementation.style
    if service_style is style.Service.INLINE:
        method_cardinality = implementation.cardinality
        if method_cardinality is cardinalities.UNARY_UNARY:
            return handler(False, False, _adapt_unary_request_inline(
                implementation.unary_unary_inline))
        elif method_cardinality is cardinalities.UNARY_STREAM:
            return handler(False, True, _adapt_unary_request_inline(
                implementation.unary_stream_inline))
        elif method_cardinality is cardinalities.STREAM_UNARY:
            return handler(True, False, _adapt_stream_request_inline(
                implementation.stream_unary_inline))
        elif method_cardinality is cardinalities.STREAM_STREAM:
            return handler(True, True, _adapt_stream_request_inline(
                implementation.stream_stream_inline))
    elif service_style is style.Service.EVENT:
        method_cardinality = implementation.cardinality
        if method_cardinality is cardinalities.UNARY_UNARY:
            return handler(False, False, _adapt_unary_unary_event(
                implementation.unary_unary_event))
        elif method_cardinality is cardinalities.UNARY_STREAM:
            return handler(False, True, _adapt_unary_stream_event(
                implementation.unary_stream_event))
        elif method_cardinality is cardinalities.STREAM_UNARY:
            return handler(True, False, _adapt_stream_unary_event(
                implementation.stream_unary_event))
        elif method_cardinality is cardinalities.STREAM_STREAM:
            return handler(True, True, _adapt_stream_stream_event(
                implementation.stream_stream_event))
    # getattr keeps the diagnostic itself from failing when the
    # implementation has no cardinality attribute.
    raise ValueError(
        'Unsupported method implementation: style=%r, cardinality=%r' %
        (service_style, getattr(implementation, 'cardinality', None)))
309
310
311def _flatten_method_pair_map(method_pair_map):
312    method_pair_map = method_pair_map or {}
313    flat_map = {}
314    for method_pair in method_pair_map:
315        method = _common.fully_qualified_method(method_pair[0], method_pair[1])
316        flat_map[method] = method_pair_map[method_pair]
317    return flat_map
318
319
class _GenericRpcHandler(grpc.GenericRpcHandler):
    """Routes incoming RPCs to Beta-API method implementations.

    Implementations, request deserializers and response serializers are
    supplied keyed by (group, method) pairs. They are re-keyed by
    fully-qualified method name so service() can look them up directly.
    """

    def __init__(self, method_implementations, multi_method_implementation,
                 request_deserializers, response_serializers):
        self._method_implementations = _flatten_method_pair_map(
            method_implementations)
        self._request_deserializers = _flatten_method_pair_map(
            request_deserializers)
        self._response_serializers = _flatten_method_pair_map(
            response_serializers)
        self._multi_method_implementation = multi_method_implementation

    def service(self, handler_call_details):
        """Returns an RpcMethodHandler for the call, or None if unhandled.

        Only methods with a registered implementation are serviced. The
        multi-method implementation is stored but not yet consulted.
        """
        method = handler_call_details.method
        implementation = self._method_implementations.get(method)
        if implementation is not None:
            return _simple_method_handler(
                implementation,
                self._request_deserializers.get(method),
                self._response_serializers.get(method))
        if self._multi_method_implementation is not None:
            # TODO(nathaniel): call the multimethod (presumably treating
            # face.NoSuchMethodError as "not handled" -> None).
            pass
        return None
348
349
class _Server(interfaces.Server):
    """Beta interfaces.Server that forwards every call to a grpc.Server.

    As a context manager it starts the wrapped server on entry. On exit it
    calls stop(None) and never suppresses an exception from the with-body.
    """

    def __init__(self, grpc_server):
        self._grpc_server = grpc_server

    def add_insecure_port(self, address):
        """Forwards to the wrapped server; returns its result."""
        wrapped = self._grpc_server
        return wrapped.add_insecure_port(address)

    def add_secure_port(self, address, server_credentials):
        """Forwards to the wrapped server; returns its result."""
        wrapped = self._grpc_server
        return wrapped.add_secure_port(address, server_credentials)

    def start(self):
        """Starts the wrapped server."""
        wrapped = self._grpc_server
        wrapped.start()

    def stop(self, grace):
        """Stops the wrapped server with the given grace; returns its result."""
        wrapped = self._grpc_server
        return wrapped.stop(grace)

    def __enter__(self):
        wrapped = self._grpc_server
        wrapped.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        wrapped = self._grpc_server
        wrapped.stop(None)
        # A false return lets any exception from the with-body propagate.
        return False
374
375
def server(service_implementations, multi_method_implementation,
           request_deserializers, response_serializers, thread_pool,
           thread_pool_size):
    """Creates a Beta-API server backed by grpc.server.

    Args:
      service_implementations: Mapping of (group, method) pairs to Beta method
        implementations, or None.
      multi_method_implementation: Stored on the RPC handler; not yet used.
      request_deserializers: Mapping of (group, method) pairs to request
        deserializers, or None.
      response_serializers: Mapping of (group, method) pairs to response
        serializers, or None.
      thread_pool: Pool passed to grpc.server, or None to create one.
      thread_pool_size: Size of the created pool when thread_pool is None.
        None means _DEFAULT_POOL_SIZE.

    Returns:
      A _Server wrapping the newly created grpc.Server.
    """
    generic_rpc_handler = _GenericRpcHandler(
        service_implementations, multi_method_implementation,
        request_deserializers, response_serializers)
    if thread_pool is not None:
        effective_thread_pool = thread_pool
    else:
        if thread_pool_size is None:
            pool_size = _DEFAULT_POOL_SIZE
        else:
            pool_size = thread_pool_size
        effective_thread_pool = logging_pool.pool(pool_size)
    handlers = (generic_rpc_handler,)
    grpc_server = grpc.server(effective_thread_pool, handlers=handlers)
    return _Server(grpc_server)
390