# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Any, Callable, ClassVar, Dict, Optional

# Workaround: `grpc` must be imported before `google.protobuf.json_format`,
# to prevent "Segmentation fault". Ref https://github.com/grpc/grpc/issues/24897
import grpc
from google.protobuf import json_format
import google.protobuf.message

logger = logging.getLogger(__name__)

# Type aliases
Message = google.protobuf.message.Message


class GrpcClientHelper:
    """Wraps a generated gRPC stub and issues logged unary calls on it."""
    channel: grpc.Channel
    DEFAULT_CONNECTION_TIMEOUT_SEC: ClassVar[int] = 60
    DEFAULT_WAIT_FOR_READY_SEC: ClassVar[int] = 60

    def __init__(self, channel: grpc.Channel,
                 stub_class: Callable[[grpc.Channel], Any]):
        """Create the stub by calling `stub_class(channel)`.

        Args:
            channel: The channel the stub sends RPCs over.
            stub_class: A generated stub class (or any factory) that is
                called with `channel` to build the stub.
        """
        self.channel = channel
        self.stub = stub_class(channel)
        # This is purely cosmetic to make RPC logs look like method calls.
        self.log_service_name = re.sub('Stub$', '',
                                       self.stub.__class__.__name__)

    def call_unary_with_deadline(
            self,
            *,
            rpc: str,
            req: Message,
            wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC,
            connection_timeout_sec: Optional[
                int] = DEFAULT_CONNECTION_TIMEOUT_SEC,
            log_level: Optional[int] = logging.DEBUG) -> Message:
        """Log and invoke the unary RPC method named `rpc` on the stub.

        The call is made with `wait_for_ready=True` and a timeout equal to
        `wait_for_ready_sec + connection_timeout_sec`.

        Args:
            rpc: Name of the stub attribute to call (e.g. 'GetClientStats').
            req: The request message.
            wait_for_ready_sec: Seconds budgeted for the channel to become
                ready; None means DEFAULT_WAIT_FOR_READY_SEC.
            connection_timeout_sec: Seconds budgeted for the call itself;
                None means DEFAULT_CONNECTION_TIMEOUT_SEC.
            log_level: Level used to log the request; None means DEBUG.

        Returns:
            Whatever the stub's callable returns (the response message).
        """
        if wait_for_ready_sec is None:
            wait_for_ready_sec = self.DEFAULT_WAIT_FOR_READY_SEC
        if connection_timeout_sec is None:
            connection_timeout_sec = self.DEFAULT_CONNECTION_TIMEOUT_SEC

        timeout_sec = wait_for_ready_sec + connection_timeout_sec
        rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)

        call_kwargs = dict(wait_for_ready=True, timeout=timeout_sec)
        self._log_rpc_request(rpc, req, call_kwargs, log_level)
        return rpc_callable(req, **call_kwargs)

    def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):
        """Log the RPC as if it were a method call, at `log_level`."""
        level = logging.DEBUG if log_level is None else log_level
        # Converting the request to a dict can be costly; skip it when the
        # record would be discarded anyway.
        if not logger.isEnabledFor(level):
            return
        # Join in insertion order (a set would make the output order
        # arbitrary from run to run).
        kwargs_str = ', '.join(f'{k}={v}' for k, v in call_kwargs.items())
        logger.log(level, 'RPC %s.%s(request=%s(%r), %s)',
                   self.log_service_name, rpc, req.__class__.__name__,
                   json_format.MessageToDict(req), kwargs_str)


class GrpcApp:
    """Base for clients of a gRPC app on `rpc_host`; caches one channel per port."""
    channels: Dict[int, grpc.Channel]

    class NotFound(Exception):
        """Requested resource not found"""

        def __init__(self, message):
            self.message = message
            super().__init__(message)

    def __init__(self, rpc_host):
        self.rpc_host = rpc_host
        # Cache gRPC channels per port
        self.channels = dict()

    def _make_channel(self, port) -> grpc.Channel:
        """Return the cached insecure channel to rpc_host:port, creating it."""
        if port not in self.channels:
            target = f'{self.rpc_host}:{port}'
            self.channels[port] = grpc.insecure_channel(target)
        return self.channels[port]

    def close(self):
        """Close all cached channels and drop them from the cache.

        The cache is detached first so that a later `_make_channel()` opens a
        fresh channel instead of returning a closed one, and so that repeated
        calls (e.g. `__exit__` followed by `__del__`) are no-ops.
        """
        channels, self.channels = self.channels, {}
        for channel in channels.values():
            channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False

    def __del__(self):
        # The instance may be partially initialized if a subclass __init__
        # raised before GrpcApp.__init__ assigned self.channels.
        if hasattr(self, 'channels'):
            self.close()