1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Translates gRPC's client-side API into gRPC's client-side Beta API."""
15
16import grpc
17from grpc import _common
18from grpc.beta import _metadata
19from grpc.beta import interfaces
20from grpc.framework.common import cardinality
21from grpc.framework.foundation import future
22from grpc.framework.interfaces.face import face
23
24# pylint: disable=too-many-arguments,too-many-locals,unused-argument
25
# Maps a gRPC status code to the Beta-API pair
# (abortion kind, abortion error class) used to report it. Status codes that
# are not listed here are given a default by the code performing the lookup.
_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = {
    grpc.StatusCode.CANCELLED: (
        face.Abortion.Kind.CANCELLED,
        face.CancellationError,
    ),
    grpc.StatusCode.UNKNOWN: (
        face.Abortion.Kind.REMOTE_FAILURE,
        face.RemoteError,
    ),
    grpc.StatusCode.DEADLINE_EXCEEDED: (
        face.Abortion.Kind.EXPIRED,
        face.ExpirationError,
    ),
    grpc.StatusCode.UNIMPLEMENTED: (
        face.Abortion.Kind.LOCAL_FAILURE,
        face.LocalError,
    ),
}
36
37
38def _effective_metadata(metadata, metadata_transformer):
39    non_none_metadata = () if metadata is None else metadata
40    if metadata_transformer is None:
41        return non_none_metadata
42    else:
43        return metadata_transformer(non_none_metadata)
44
45
46def _credentials(grpc_call_options):
47    return None if grpc_call_options is None else grpc_call_options.credentials
48
49
def _abortion(rpc_error_call):
    """Build a face.Abortion describing how an RPC terminated.

    Args:
      rpc_error_call: An object exposing code(), initial_metadata(),
        trailing_metadata() and details().

    Returns:
      A face.Abortion whose kind is taken from the status-code table, or is
      LOCAL_FAILURE when the table has no entry for the call's code.
    """
    status_code = rpc_error_call.code()
    entry = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(
        status_code)
    if entry is None:
        kind = face.Abortion.Kind.LOCAL_FAILURE
    else:
        kind = entry[0]
    return face.Abortion(
        kind,
        rpc_error_call.initial_metadata(),
        rpc_error_call.trailing_metadata(),
        status_code,
        rpc_error_call.details(),
    )
57
58
def _abortion_error(rpc_error_call):
    """Build the Beta-API exception corresponding to a failed RPC.

    Args:
      rpc_error_call: An object exposing code(), initial_metadata(),
        trailing_metadata() and details().

    Returns:
      An instance of the error class the status-code table associates with
      the call's code, or of face.AbortionError when the code is not listed.
    """
    status_code = rpc_error_call.code()
    entry = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(
        status_code)
    if entry is None:
        error_class = face.AbortionError
    else:
        error_class = entry[1]
    return error_class(
        rpc_error_call.initial_metadata(),
        rpc_error_call.trailing_metadata(),
        status_code,
        rpc_error_call.details(),
    )
66
67
class _InvocationProtocolContext(interfaces.GRPCInvocationContext):
    """Invocation-side protocol context with no implemented behavior yet."""

    def disable_next_request_compression(self):
        """Do nothing; this control is not implemented."""
        # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
        return None
72
73
class _Rendezvous(future.Future, face.Call):
    """Presents a gRPC call as a Beta-API future, response iterator and call.

    The factory functions in this module build instances with a response
    future only, with a response iterator only, or with neither (blocking
    invocations made with with_call). Methods that delegate to a part that
    was passed as None are therefore unusable on that instance.
    """

    def __init__(self, response_future, response_iterator, call):
        """Store the wrapped objects.

        Args:
          response_future: The future the Future methods delegate to, or None.
          response_iterator: The iterator that yields responses, or None.
          call: The call object answering cancellation, liveness, deadline,
            metadata, code and details queries.
        """
        self._future = response_future
        self._iterator = response_iterator
        self._call = call

    def cancel(self):
        """Cancel the underlying call and return what its cancel() returns."""
        return self._call.cancel()

    def cancelled(self):
        """Return the wrapped future's cancelled() value."""
        return self._future.cancelled()

    def running(self):
        """Return the wrapped future's running() value."""
        return self._future.running()

    def done(self):
        """Return the wrapped future's done() value."""
        return self._future.done()

    def result(self, timeout=None):
        """Return the wrapped future's result, translating gRPC exceptions.

        Args:
          timeout: Passed through to the wrapped future's result().

        Raises:
          face.AbortionError: (or the subclass matching the status code) if
            the wrapped future raised grpc.RpcError.
          future.TimeoutError: If the wrapped future raised
            grpc.FutureTimeoutError.
          future.CancelledError: If the wrapped future raised
            grpc.FutureCancelledError.
        """
        try:
            return self._future.result(timeout=timeout)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def exception(self, timeout=None):
        """Return the Beta-API exception for the RPC, or None if there is none.

        Args:
          timeout: Passed through to the wrapped future's exception().

        Raises:
          future.TimeoutError: If the wrapped future raised
            grpc.FutureTimeoutError.
          future.CancelledError: If the wrapped future raised
            grpc.FutureCancelledError.
        """
        try:
            rpc_error_call = self._future.exception(timeout=timeout)
            if rpc_error_call is None:
                return None
            else:
                return _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def traceback(self, timeout=None):
        """Return the wrapped future's traceback, translating gRPC exceptions.

        Args:
          timeout: Passed through to the wrapped future's traceback().

        Raises:
          future.TimeoutError: If the wrapped future raised
            grpc.FutureTimeoutError.
          future.CancelledError: If the wrapped future raised
            grpc.FutureCancelledError.
        """
        try:
            return self._future.traceback(timeout=timeout)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def add_done_callback(self, fn):
        """Register fn to be called with this object when the future is done."""
        # The wrapped future passes itself to the callback; hand the caller
        # this adapter instead so it sees the Beta-API interface.
        self._future.add_done_callback(lambda ignored_callback: fn(self))

    def __iter__(self):
        return self

    def _next(self):
        """Return the next response, translating grpc.RpcError failures."""
        try:
            return next(self._iterator)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)

    def __next__(self):
        return self._next()

    def next(self):
        """Python 2 spelling of __next__."""
        return self._next()

    def is_active(self):
        """Return the underlying call's is_active() value."""
        return self._call.is_active()

    def time_remaining(self):
        """Return the underlying call's time_remaining() value."""
        return self._call.time_remaining()

    def add_abortion_callback(self, abortion_callback):
        """Arrange for abortion_callback to be called if the RPC is not OK.

        Args:
          abortion_callback: A callable accepting a face.Abortion.

        Returns:
          None.
        """

        def done_callback():
            if self.code() is not grpc.StatusCode.OK:
                abortion_callback(_abortion(self._call))

        registered = self._call.add_callback(done_callback)
        # When the call declines to register the callback, perform the check
        # right away instead of dropping the notification.
        return None if registered else done_callback()

    def protocol_context(self):
        """Return a fresh _InvocationProtocolContext."""
        return _InvocationProtocolContext()

    def initial_metadata(self):
        """Return the call's initial metadata converted with _metadata.beta."""
        return _metadata.beta(self._call.initial_metadata())

    def terminal_metadata(self):
        """Return the call's trailing metadata converted with _metadata.beta."""
        # grpc.Call names this accessor trailing_metadata(); _abortion reads
        # the very same call object through that name. The previous
        # self._call.terminal_metadata() lookup would raise AttributeError.
        return _metadata.beta(self._call.trailing_metadata())

    def code(self):
        """Return the underlying call's status code."""
        return self._call.code()

    def details(self):
        """Return the underlying call's details."""
        return self._call.details()
170
171
def _blocking_unary_unary(channel, group, method, timeout, with_call,
                          protocol_options, metadata, metadata_transformer,
                          request, request_serializer, response_deserializer):
    """Invoke a unary-unary RPC on the channel and wait for it to finish.

    Returns:
      The response, or a (response, _Rendezvous) pair when with_call is
      truthy.

    Raises:
      face.AbortionError: (or the subclass matching the status code) in place
        of any grpc.RpcError raised while creating or invoking the callable.
    """
    try:
        invoker = channel.unary_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        call_kwargs = {
            'timeout': timeout,
            'metadata': _metadata.unbeta(
                _effective_metadata(metadata, metadata_transformer)),
            'credentials': _credentials(protocol_options),
        }
        if not with_call:
            return invoker(request, **call_kwargs)
        response, call = invoker.with_call(request, **call_kwargs)
        return response, _Rendezvous(None, None, call)
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)
196
197
def _future_unary_unary(channel, group, method, timeout, protocol_options,
                        metadata, metadata_transformer, request,
                        request_serializer, response_deserializer):
    """Start a unary-unary RPC and return a _Rendezvous wrapping its future.

    The object returned by the multi-callable's future() serves as both the
    response future and the call of the resulting _Rendezvous.
    """
    invoker = channel.unary_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    grpc_metadata = _metadata.unbeta(
        _effective_metadata(metadata, metadata_transformer))
    pending = invoker.future(
        request,
        timeout=timeout,
        metadata=grpc_metadata,
        credentials=_credentials(protocol_options))
    return _Rendezvous(pending, None, pending)
212
213
def _unary_stream(channel, group, method, timeout, protocol_options, metadata,
                  metadata_transformer, request, request_serializer,
                  response_deserializer):
    """Start a unary-stream RPC and return a _Rendezvous over its responses.

    The object returned by invoking the multi-callable serves as both the
    response iterator and the call of the resulting _Rendezvous.
    """
    invoker = channel.unary_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    grpc_metadata = _metadata.unbeta(
        _effective_metadata(metadata, metadata_transformer))
    responses = invoker(
        request,
        timeout=timeout,
        metadata=grpc_metadata,
        credentials=_credentials(protocol_options))
    return _Rendezvous(None, responses, responses)
228
229
def _blocking_stream_unary(channel, group, method, timeout, with_call,
                           protocol_options, metadata, metadata_transformer,
                           request_iterator, request_serializer,
                           response_deserializer):
    """Invoke a stream-unary RPC on the channel and wait for it to finish.

    Returns:
      The response, or a (response, _Rendezvous) pair when with_call is
      truthy.

    Raises:
      face.AbortionError: (or the subclass matching the status code) in place
        of any grpc.RpcError raised while creating or invoking the callable.
    """
    try:
        invoker = channel.stream_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        call_kwargs = {
            'timeout': timeout,
            'metadata': _metadata.unbeta(
                _effective_metadata(metadata, metadata_transformer)),
            'credentials': _credentials(protocol_options),
        }
        if not with_call:
            return invoker(request_iterator, **call_kwargs)
        response, call = invoker.with_call(request_iterator, **call_kwargs)
        return response, _Rendezvous(None, None, call)
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)
255
256
def _future_stream_unary(channel, group, method, timeout, protocol_options,
                         metadata, metadata_transformer, request_iterator,
                         request_serializer, response_deserializer):
    """Start a stream-unary RPC and return a _Rendezvous wrapping its future.

    The object returned by the multi-callable's future() serves as both the
    response future and the call of the resulting _Rendezvous.
    """
    invoker = channel.stream_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    grpc_metadata = _metadata.unbeta(
        _effective_metadata(metadata, metadata_transformer))
    pending = invoker.future(
        request_iterator,
        timeout=timeout,
        metadata=grpc_metadata,
        credentials=_credentials(protocol_options))
    return _Rendezvous(pending, None, pending)
271
272
def _stream_stream(channel, group, method, timeout, protocol_options, metadata,
                   metadata_transformer, request_iterator, request_serializer,
                   response_deserializer):
    """Start a stream-stream RPC and return a _Rendezvous over its responses.

    The object returned by invoking the multi-callable serves as both the
    response iterator and the call of the resulting _Rendezvous.
    """
    invoker = channel.stream_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    grpc_metadata = _metadata.unbeta(
        _effective_metadata(metadata, metadata_transformer))
    responses = invoker(
        request_iterator,
        timeout=timeout,
        metadata=grpc_metadata,
        credentials=_credentials(protocol_options))
    return _Rendezvous(None, responses, responses)
287
288
class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
    """Beta-API multi-callable for one unary-unary method of a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        """Remember the channel, method identity and message behaviors."""
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request,
                 timeout,
                 metadata=None,
                 with_call=False,
                 protocol_options=None):
        """Invoke the method synchronously via _blocking_unary_unary."""
        return _blocking_unary_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            with_call=with_call,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def future(self, request, timeout, metadata=None, protocol_options=None):
        """Invoke the method asynchronously via _future_unary_unary."""
        return _future_unary_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(self,
              request,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Event-driven invocation; not implemented."""
        raise NotImplementedError()
325
326
class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
    """Beta-API multi-callable for one unary-stream method of a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        """Remember the channel, method identity and message behaviors."""
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self, request, timeout, metadata=None, protocol_options=None):
        """Invoke the method via _unary_stream."""
        return _unary_stream(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(self,
              request,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Event-driven invocation; not implemented."""
        raise NotImplementedError()
352
353
class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
    """Beta-API multi-callable for one stream-unary method of a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        """Remember the channel, method identity and message behaviors."""
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request_iterator,
                 timeout,
                 metadata=None,
                 with_call=False,
                 protocol_options=None):
        """Invoke the method synchronously via _blocking_stream_unary."""
        return _blocking_stream_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            with_call=with_call,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def future(self,
               request_iterator,
               timeout,
               metadata=None,
               protocol_options=None):
        """Invoke the method asynchronously via _future_stream_unary."""
        return _future_stream_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(self,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Event-driven invocation; not implemented."""
        raise NotImplementedError()
394
395
class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
    """Beta-API multi-callable for one stream-stream method of a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        """Remember the channel, method identity and message behaviors."""
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request_iterator,
                 timeout,
                 metadata=None,
                 protocol_options=None):
        """Invoke the method via _stream_stream."""
        return _stream_stream(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(self,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Event-driven invocation; not implemented."""
        raise NotImplementedError()
424
425
class _GenericStub(face.GenericStub):
    """Beta-API generic stub that addresses methods by (group, method) name.

    Serializers and deserializers are looked up per call in the mappings
    given at construction, keyed by (group, method); a missing entry yields
    None for that behavior.
    """

    def __init__(self, channel, metadata_transformer, request_serializers,
                 response_deserializers):
        """Store the channel, transformer and the two behavior mappings.

        A falsy request_serializers or response_deserializers is replaced by
        an empty dict.
        """
        self._channel = channel
        self._metadata_transformer = metadata_transformer
        self._request_serializers = request_serializers or {}
        self._response_deserializers = response_deserializers or {}

    def _behaviors(self, group, method):
        """Return (request serializer, response deserializer) for a method."""
        key = (group, method)
        return (self._request_serializers.get(key),
                self._response_deserializers.get(key))

    def blocking_unary_unary(self,
                             group,
                             method,
                             request,
                             timeout,
                             metadata=None,
                             with_call=None,
                             protocol_options=None):
        """Invoke a unary-unary method synchronously."""
        serializer, deserializer = self._behaviors(group, method)
        return _blocking_unary_unary(
            self._channel, group, method, timeout, with_call,
            protocol_options, metadata, self._metadata_transformer, request,
            serializer, deserializer)

    def future_unary_unary(self,
                           group,
                           method,
                           request,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        """Invoke a unary-unary method asynchronously."""
        serializer, deserializer = self._behaviors(group, method)
        return _future_unary_unary(
            self._channel, group, method, timeout, protocol_options, metadata,
            self._metadata_transformer, request, serializer, deserializer)

    def inline_unary_stream(self,
                            group,
                            method,
                            request,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        """Invoke a unary-stream method."""
        serializer, deserializer = self._behaviors(group, method)
        return _unary_stream(
            self._channel, group, method, timeout, protocol_options, metadata,
            self._metadata_transformer, request, serializer, deserializer)

    def blocking_stream_unary(self,
                              group,
                              method,
                              request_iterator,
                              timeout,
                              metadata=None,
                              with_call=None,
                              protocol_options=None):
        """Invoke a stream-unary method synchronously."""
        serializer, deserializer = self._behaviors(group, method)
        return _blocking_stream_unary(
            self._channel, group, method, timeout, with_call,
            protocol_options, metadata, self._metadata_transformer,
            request_iterator, serializer, deserializer)

    def future_stream_unary(self,
                            group,
                            method,
                            request_iterator,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        """Invoke a stream-unary method asynchronously."""
        serializer, deserializer = self._behaviors(group, method)
        return _future_stream_unary(
            self._channel, group, method, timeout, protocol_options, metadata,
            self._metadata_transformer, request_iterator, serializer,
            deserializer)

    def inline_stream_stream(self,
                             group,
                             method,
                             request_iterator,
                             timeout,
                             metadata=None,
                             protocol_options=None):
        """Invoke a stream-stream method."""
        serializer, deserializer = self._behaviors(group, method)
        return _stream_stream(
            self._channel, group, method, timeout, protocol_options, metadata,
            self._metadata_transformer, request_iterator, serializer,
            deserializer)

    def event_unary_unary(self,
                          group,
                          method,
                          request,
                          receiver,
                          abortion_callback,
                          timeout,
                          metadata=None,
                          protocol_options=None):
        """Event-driven unary-unary invocation; not implemented."""
        raise NotImplementedError()

    def event_unary_stream(self,
                           group,
                           method,
                           request,
                           receiver,
                           abortion_callback,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        """Event-driven unary-stream invocation; not implemented."""
        raise NotImplementedError()

    def event_stream_unary(self,
                           group,
                           method,
                           receiver,
                           abortion_callback,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        """Event-driven stream-unary invocation; not implemented."""
        raise NotImplementedError()

    def event_stream_stream(self,
                            group,
                            method,
                            receiver,
                            abortion_callback,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        """Event-driven stream-stream invocation; not implemented."""
        raise NotImplementedError()

    def unary_unary(self, group, method):
        """Return a _UnaryUnaryMultiCallable bound to the named method."""
        serializer, deserializer = self._behaviors(group, method)
        return _UnaryUnaryMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            serializer, deserializer)

    def unary_stream(self, group, method):
        """Return a _UnaryStreamMultiCallable bound to the named method."""
        serializer, deserializer = self._behaviors(group, method)
        return _UnaryStreamMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            serializer, deserializer)

    def stream_unary(self, group, method):
        """Return a _StreamUnaryMultiCallable bound to the named method."""
        serializer, deserializer = self._behaviors(group, method)
        return _StreamUnaryMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            serializer, deserializer)

    def stream_stream(self, group, method):
        """Return a _StreamStreamMultiCallable bound to the named method."""
        serializer, deserializer = self._behaviors(group, method)
        return _StreamStreamMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            serializer, deserializer)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning False lets any exception raised in the with-body propagate.
        return False
656
657
class _DynamicStub(face.DynamicStub):
    """Beta-API stub exposing a service's methods as attributes.

    Attribute names are looked up in the cardinalities mapping and resolved
    to the matching multi-callable of the backing generic stub.
    """

    def __init__(self, backing_generic_stub, group, cardinalities):
        """Store the generic stub, the group name and the cardinality map."""
        self._generic_stub = backing_generic_stub
        self._group = group
        self._cardinalities = cardinalities

    def __getattr__(self, attr):
        """Return the multi-callable for the method named attr.

        Raises:
          AttributeError: If attr has none of the four known cardinalities
            in the cardinalities mapping.
        """
        method_cardinality = self._cardinalities.get(attr)
        # Pairs of (cardinality, name of the generic-stub factory method).
        factories = (
            (cardinality.Cardinality.UNARY_UNARY, 'unary_unary'),
            (cardinality.Cardinality.UNARY_STREAM, 'unary_stream'),
            (cardinality.Cardinality.STREAM_UNARY, 'stream_unary'),
            (cardinality.Cardinality.STREAM_STREAM, 'stream_stream'),
        )
        for known_cardinality, factory_name in factories:
            if method_cardinality is known_cardinality:
                factory = getattr(self._generic_stub, factory_name)
                return factory(self._group, attr)
        raise AttributeError(
            '_DynamicStub object has no attribute "%s"!' % attr)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning False lets any exception raised in the with-body propagate.
        return False
684
685
def generic_stub(channel, host, metadata_transformer, request_serializers,
                 response_deserializers):
    """Create a Beta-API generic stub on top of the given channel.

    The host argument is accepted but not used.
    """
    return _GenericStub(
        channel=channel,
        metadata_transformer=metadata_transformer,
        request_serializers=request_serializers,
        response_deserializers=response_deserializers)
690
691
def dynamic_stub(channel, service, cardinalities, host, metadata_transformer,
                 request_serializers, response_deserializers):
    """Create a Beta-API dynamic stub for one service of the given channel.

    A new generic stub is built to back the dynamic stub. The host argument
    is accepted but not used.
    """
    backing_stub = _GenericStub(channel, metadata_transformer,
                                request_serializers, response_deserializers)
    return _DynamicStub(backing_stub, service, cardinalities)
697