1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging
15import re
16from typing import ClassVar, Dict, Optional
17
18# Workaround: `grpc` must be imported before `google.protobuf.json_format`,
19# to prevent "Segmentation fault". Ref https://github.com/grpc/grpc/issues/24897
20import grpc
21from google.protobuf import json_format
22import google.protobuf.message
23
# Module-level logger; GrpcClientHelper logs outgoing RPC requests through it.
logger = logging.getLogger(__name__)

# Type aliases
# `Message` is the base protobuf message class, used below to annotate RPC
# requests and responses.
Message = google.protobuf.message.Message
28
29
class GrpcClientHelper:
    """Wraps a generated gRPC stub, adding a deadline and request logging.

    Unary RPCs are invoked by name through `call_unary_with_deadline`, which
    always passes ``wait_for_ready=True`` and a timeout, and logs the request
    before sending it.
    """
    channel: grpc.Channel
    DEFAULT_CONNECTION_TIMEOUT_SEC: ClassVar[int] = 60
    DEFAULT_WAIT_FOR_READY_SEC: ClassVar[int] = 60

    def __init__(self, channel: grpc.Channel, stub_class: type):
        """Instantiate ``stub_class`` bound to ``channel``.

        Args:
            channel: The gRPC channel the stub sends requests over.
            stub_class: A stub class; it is instantiated as
                ``stub_class(channel)``.
        """
        self.channel = channel
        self.stub = stub_class(channel)
        # This is purely cosmetic to make RPC logs look like method calls.
        self.log_service_name = re.sub('Stub$', '',
                                       self.stub.__class__.__name__)

    def call_unary_with_deadline(
            self,
            *,
            rpc: str,
            req: Message,
            wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC,
            connection_timeout_sec: Optional[
                int] = DEFAULT_CONNECTION_TIMEOUT_SEC,
            log_level: Optional[int] = logging.DEBUG) -> Message:
        """Call the stub attribute named ``rpc`` with wait-for-ready and a timeout.

        The timeout passed to the call is the sum of ``wait_for_ready_sec``
        and ``connection_timeout_sec``.

        Args:
            rpc: Name of the stub attribute to invoke (the RPC method name).
            req: The request message.
            wait_for_ready_sec: Seconds added to the timeout; None means
                DEFAULT_WAIT_FOR_READY_SEC.
            connection_timeout_sec: Seconds added to the timeout; None means
                DEFAULT_CONNECTION_TIMEOUT_SEC.
            log_level: Level for the request log line; None means DEBUG.

        Returns:
            Whatever the stub call returns (the response message).

        Raises:
            AttributeError: If the stub has no attribute named ``rpc``.
            Errors raised by the stub call itself (for grpc, typically
            ``grpc.RpcError``) propagate unchanged.
        """
        if wait_for_ready_sec is None:
            wait_for_ready_sec = self.DEFAULT_WAIT_FOR_READY_SEC
        if connection_timeout_sec is None:
            connection_timeout_sec = self.DEFAULT_CONNECTION_TIMEOUT_SEC

        timeout_sec = wait_for_ready_sec + connection_timeout_sec
        rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)

        call_kwargs = dict(wait_for_ready=True, timeout=timeout_sec)
        self._log_rpc_request(rpc, req, call_kwargs, log_level)
        return rpc_callable(req, **call_kwargs)

    def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):
        """Log an outgoing RPC formatted like a method call.

        ``call_kwargs`` are rendered in their insertion order so the log
        line is stable from run to run.
        """
        level = logging.DEBUG if log_level is None else log_level
        # Converting the request to a dict can be costly for large messages;
        # skip it entirely when the record would be discarded anyway.
        if not logger.isEnabledFor(level):
            return
        logger.log(level, 'RPC %s.%s(request=%s(%r), %s)',
                   self.log_service_name, rpc, req.__class__.__name__,
                   json_format.MessageToDict(req),
                   ', '.join(f'{k}={v}' for k, v in call_kwargs.items()))
68
69
class GrpcApp:
    """Client for an app reachable over gRPC on one host and possibly several ports.

    Insecure channels are created lazily per port and cached in
    ``self.channels``; `close()` closes them. Usable as a context manager.
    """
    channels: Dict[int, grpc.Channel]

    class NotFound(Exception):
        """Requested resource not found"""

        def __init__(self, message):
            self.message = message
            super().__init__(message)

    def __init__(self, rpc_host):
        """Store the host used to build ``f'{rpc_host}:{port}'`` targets."""
        self.rpc_host = rpc_host
        # Cache gRPC channels per port
        self.channels = {}

    def _make_channel(self, port) -> grpc.Channel:
        """Return the cached channel to ``rpc_host:port``, creating it on first use."""
        if port not in self.channels:
            target = f'{self.rpc_host}:{port}'
            self.channels[port] = grpc.insecure_channel(target)
        return self.channels[port]

    def close(self):
        """Close all cached channels and empty the cache.

        Safe to call more than once. After closing, `_make_channel` opens a
        fresh channel instead of handing back a closed one.
        """
        # Detach the cache first so a repeated close() (e.g. __exit__
        # followed by __del__) does not close the same channels again.
        channels, self.channels = self.channels, {}
        for channel in channels.values():
            channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False

    def __del__(self):
        # __init__ may not have run to completion (e.g. a subclass raised
        # before calling it); then there is no cache and nothing to close.
        if hasattr(self, 'channels'):
            self.close()
105