1# Copyright 2020 gRPC authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14""" 15This contains helpers for gRPC services defined in 16https://github.com/grpc/grpc-proto/blob/master/grpc/channelz/v1/channelz.proto 17""" 18import ipaddress 19import logging 20from typing import Iterator, Optional 21 22import grpc 23from grpc_channelz.v1 import channelz_pb2 24from grpc_channelz.v1 import channelz_pb2_grpc 25 26import framework.rpc 27 28logger = logging.getLogger(__name__) 29 30# Type aliases 31# Channel 32Channel = channelz_pb2.Channel 33ChannelConnectivityState = channelz_pb2.ChannelConnectivityState 34ChannelState = ChannelConnectivityState.State # pylint: disable=no-member 35_GetTopChannelsRequest = channelz_pb2.GetTopChannelsRequest 36_GetTopChannelsResponse = channelz_pb2.GetTopChannelsResponse 37# Subchannel 38Subchannel = channelz_pb2.Subchannel 39_GetSubchannelRequest = channelz_pb2.GetSubchannelRequest 40_GetSubchannelResponse = channelz_pb2.GetSubchannelResponse 41# Server 42Server = channelz_pb2.Server 43_GetServersRequest = channelz_pb2.GetServersRequest 44_GetServersResponse = channelz_pb2.GetServersResponse 45# Sockets 46Socket = channelz_pb2.Socket 47SocketRef = channelz_pb2.SocketRef 48_GetSocketRequest = channelz_pb2.GetSocketRequest 49_GetSocketResponse = channelz_pb2.GetSocketResponse 50Address = channelz_pb2.Address 51Security = channelz_pb2.Security 52# Server Sockets 53_GetServerSocketsRequest = channelz_pb2.GetServerSocketsRequest 
_GetServerSocketsResponse = channelz_pb2.GetServerSocketsResponse


class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
    """Convenience client for the gRPC Channelz service.

    Wraps ``channelz_pb2_grpc.ChannelzStub`` and exposes helpers to list and
    look up channels, subchannels, servers and sockets. The paginated listing
    RPCs (GetTopChannels, GetServers, GetServerSockets) are exposed as
    generators that request subsequent pages lazily while being iterated.
    """
    stub: channelz_pb2_grpc.ChannelzStub

    def __init__(self, channel: grpc.Channel):
        """Create a Channelz client over an existing gRPC ``channel``."""
        super().__init__(channel, channelz_pb2_grpc.ChannelzStub)

    @staticmethod
    def is_sock_tcpip_address(address: Address) -> bool:
        """Return True if the ``address`` oneof is set to ``tcpip_address``."""
        return address.WhichOneof('address') == 'tcpip_address'

    @staticmethod
    def is_ipv4(tcpip_address: Address.TcpIpAddress) -> bool:
        """Return True if ``tcpip_address.ip_address`` holds an IPv4 address."""
        # According to proto, tcpip_address.ip_address is either IPv4 or IPv6.
        # Correspondingly, it's either 4 bytes or 16 bytes in length.
        return len(tcpip_address.ip_address) == 4

    @classmethod
    def sock_address_to_str(cls, address: Address) -> str:
        """Format a TCP/IP socket ``address`` as ``'<ip>:<port>'``.

        Raises:
            NotImplementedError: ``address`` is not a ``tcpip_address``.
        """
        if not cls.is_sock_tcpip_address(address):
            raise NotImplementedError('Only tcpip_address implemented')
        tcpip_address: Address.TcpIpAddress = address.tcpip_address
        if cls.is_ipv4(tcpip_address):
            ip = ipaddress.IPv4Address(tcpip_address.ip_address)
        else:
            ip = ipaddress.IPv6Address(tcpip_address.ip_address)
        # NOTE(review): IPv6 addresses are not wrapped in brackets, so e.g.
        # '::1:8080' is ambiguous; kept as-is to preserve the output format.
        return f'{ip}:{tcpip_address.port}'

    @classmethod
    def sock_addresses_pretty(cls, socket: Socket) -> str:
        """Return ``'local=<ip:port>, remote=<ip:port>'`` for ``socket``.

        Raises:
            NotImplementedError: either endpoint is not a ``tcpip_address``.
        """
        return (f'local={cls.sock_address_to_str(socket.local)}, '
                f'remote={cls.sock_address_to_str(socket.remote)}')

    @staticmethod
    def find_server_socket_matching_client(
            server_sockets: Iterator[Socket],
            client_socket: Socket) -> Optional[Socket]:
        """Find the server-side socket connected to ``client_socket``.

        A server socket matches when its remote address equals the client
        socket's local address.

        Returns:
            The first matching server socket, or None if none matches.
        """
        for server_socket in server_sockets:
            if server_socket.remote == client_socket.local:
                return server_socket
        return None

    def find_channels_for_target(self, target: str) -> Iterator[Channel]:
        """Lazily yield root channels whose ``data.target`` equals ``target``."""
        return (channel for channel in self.list_channels()
                if channel.data.target == target)

    def find_server_listening_on_port(self, port: int) -> Optional[Server]:
        """Return the first server with a TCP/IP listen socket on ``port``.

        Issues one GetSocket RPC per listen socket inspected.

        Returns:
            The matching server, or None if no server listens on ``port``.
        """
        for server in self.list_servers():
            listen_socket_ref: SocketRef
            for listen_socket_ref in server.listen_socket:
                listen_socket = self.get_socket(listen_socket_ref.socket_id)
                listen_address: Address = listen_socket.local
                # Only TCP/IP addresses carry a port to compare against.
                if (self.is_sock_tcpip_address(listen_address) and
                        listen_address.tcpip_address.port == port):
                    return server
        return None

    def list_channels(self) -> Iterator[Channel]:
        """
        Iterate over all pages of all root channels.

        Root channels are those which application has directly created.
        This does not include subchannels nor non-top level channels.
        """
        start: int = -1
        response: Optional[_GetTopChannelsResponse] = None
        # start < 0 only before the first request; afterwards keep paging
        # until the server marks the last page with `end`.
        while start < 0 or not response.end:
            # From proto: To request subsequent pages, the client generates this
            # value by adding 1 to the highest seen result ID.
            start += 1
            response = self.call_unary_with_deadline(
                rpc='GetTopChannels',
                req=_GetTopChannelsRequest(start_channel_id=start))
            for channel in response.channel:
                start = max(start, channel.ref.channel_id)
                yield channel

    def list_servers(self) -> Iterator[Server]:
        """Iterate over all pages of all servers that exist in the process."""
        start: int = -1
        response: Optional[_GetServersResponse] = None
        # start < 0 only before the first request; afterwards keep paging
        # until the server marks the last page with `end`.
        while start < 0 or not response.end:
            # From proto: To request subsequent pages, the client generates this
            # value by adding 1 to the highest seen result ID.
            start += 1
            response = self.call_unary_with_deadline(
                rpc='GetServers', req=_GetServersRequest(start_server_id=start))
            for server in response.server:
                start = max(start, server.ref.server_id)
                yield server

    def list_server_sockets(self, server: Server) -> Iterator[Socket]:
        """List all sockets of the given ``server``.

        Each page returns socket references; every reference is resolved to
        a full Socket with a separate GetSocket RPC. Iterating over the
        results will resolve additional pages automatically.
        """
        start: int = -1
        response: Optional[_GetServerSocketsResponse] = None
        # start < 0 only before the first request; afterwards keep paging
        # until the server marks the last page with `end`.
        while start < 0 or not response.end:
            # From proto: To request subsequent pages, the client generates this
            # value by adding 1 to the highest seen result ID.
            start += 1
            response = self.call_unary_with_deadline(
                rpc='GetServerSockets',
                req=_GetServerSocketsRequest(server_id=server.ref.server_id,
                                             start_socket_id=start))
            socket_ref: SocketRef
            for socket_ref in response.socket_ref:
                start = max(start, socket_ref.socket_id)
                # Yield actual socket
                yield self.get_socket(socket_ref.socket_id)

    def list_channel_sockets(self, channel: Channel) -> Iterator[Socket]:
        """List all sockets of all subchannels of a given channel."""
        for subchannel in self.list_channel_subchannels(channel):
            yield from self.list_subchannels_sockets(subchannel)

    def list_channel_subchannels(self,
                                 channel: Channel) -> Iterator[Subchannel]:
        """List all subchannels of a given channel (one RPC per subchannel)."""
        for subchannel_ref in channel.subchannel_ref:
            yield self.get_subchannel(subchannel_ref.subchannel_id)

    def list_subchannels_sockets(self,
                                 subchannel: Subchannel) -> Iterator[Socket]:
        """List all sockets of a given subchannel (one RPC per socket)."""
        for socket_ref in subchannel.socket_ref:
            yield self.get_socket(socket_ref.socket_id)

    def get_subchannel(self, subchannel_id: int) -> Subchannel:
        """Return a single Subchannel, otherwise raises RpcError."""
        response: _GetSubchannelResponse = self.call_unary_with_deadline(
            rpc='GetSubchannel',
            req=_GetSubchannelRequest(subchannel_id=subchannel_id))
        return response.subchannel

    def get_socket(self, socket_id: int) -> Socket:
        """Return a single Socket, otherwise raises RpcError."""
        response: _GetSocketResponse = self.call_unary_with_deadline(
            rpc='GetSocket', req=_GetSocketRequest(socket_id=socket_id))
        return response.socket