# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains helpers for the gRPC Channelz service defined in
https://github.com/grpc/grpc-proto/blob/master/grpc/channelz/v1/channelz.proto
"""
import ipaddress
import logging
from typing import Any, Callable, Iterable, Iterator, Optional

import grpc
from grpc_channelz.v1 import channelz_pb2
from grpc_channelz.v1 import channelz_pb2_grpc

import framework.rpc
27
logger = logging.getLogger(__name__)

# Type aliases for the Channelz protobuf messages used in this module.
# Aliases without a leading underscore are the messages handled by the
# helper methods below; underscore-prefixed aliases are the request/response
# wrappers this module uses to build RPC calls and annotate their results.
# Channel
Channel = channelz_pb2.Channel
ChannelConnectivityState = channelz_pb2.ChannelConnectivityState
# The nested State enum is generated dynamically by protobuf, hence the
# pylint suppression.
ChannelState = ChannelConnectivityState.State  # pylint: disable=no-member
_GetTopChannelsRequest = channelz_pb2.GetTopChannelsRequest
_GetTopChannelsResponse = channelz_pb2.GetTopChannelsResponse
# Subchannel
Subchannel = channelz_pb2.Subchannel
_GetSubchannelRequest = channelz_pb2.GetSubchannelRequest
_GetSubchannelResponse = channelz_pb2.GetSubchannelResponse
# Server
Server = channelz_pb2.Server
_GetServersRequest = channelz_pb2.GetServersRequest
_GetServersResponse = channelz_pb2.GetServersResponse
# Sockets
Socket = channelz_pb2.Socket
SocketRef = channelz_pb2.SocketRef
_GetSocketRequest = channelz_pb2.GetSocketRequest
_GetSocketResponse = channelz_pb2.GetSocketResponse
Address = channelz_pb2.Address
Security = channelz_pb2.Security
# Server Sockets
_GetServerSocketsRequest = channelz_pb2.GetServerSocketsRequest
_GetServerSocketsResponse = channelz_pb2.GetServerSocketsResponse
55
56
class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
    """Helper around the Channelz stub.

    Provides paginated listing of channels, servers and server sockets,
    single-entity lookups (subchannel, socket), and small utilities for
    inspecting and formatting socket addresses.
    """
    stub: channelz_pb2_grpc.ChannelzStub

    def __init__(self, channel: grpc.Channel):
        """Initialize the helper with ``channel`` and the Channelz stub class.

        Args:
            channel: gRPC channel handed to the base helper together with
                ``ChannelzStub``.
        """
        super().__init__(channel, channelz_pb2_grpc.ChannelzStub)

    @staticmethod
    def is_sock_tcpip_address(address: Address) -> bool:
        """Return True if the ``address`` oneof holds a ``tcpip_address``."""
        return address.WhichOneof('address') == 'tcpip_address'

    @staticmethod
    def is_ipv4(tcpip_address: Address.TcpIpAddress) -> bool:
        """Return True if the packed IP address is 4 bytes long (IPv4)."""
        # According to proto, tcpip_address.ip_address is either IPv4 or IPv6.
        # Correspondingly, it's either 4 bytes or 16 bytes in length.
        return len(tcpip_address.ip_address) == 4

    @classmethod
    def sock_address_to_str(cls, address: Address) -> str:
        """Format a TCP/IP socket address as ``<ip>:<port>``.

        Note that IPv6 addresses are not wrapped in brackets.

        Args:
            address: Channelz address to format.

        Returns:
            The human-readable ``ip:port`` string.

        Raises:
            NotImplementedError: If ``address`` is not a ``tcpip_address``.
        """
        if not cls.is_sock_tcpip_address(address):
            raise NotImplementedError('Only tcpip_address implemented')

        tcpip_address: Address.TcpIpAddress = address.tcpip_address
        if cls.is_ipv4(tcpip_address):
            ip = ipaddress.IPv4Address(tcpip_address.ip_address)
        else:
            ip = ipaddress.IPv6Address(tcpip_address.ip_address)
        return f'{ip}:{tcpip_address.port}'

    @classmethod
    def sock_addresses_pretty(cls, socket: Socket) -> str:
        """Format the local and remote addresses of ``socket`` for display.

        Raises:
            NotImplementedError: If either address is not a ``tcpip_address``
                (propagated from ``sock_address_to_str``).
        """
        return (f'local={cls.sock_address_to_str(socket.local)}, '
                f'remote={cls.sock_address_to_str(socket.remote)}')

    @staticmethod
    def find_server_socket_matching_client(
            server_sockets: Iterable[Socket],
            client_socket: Socket) -> Optional[Socket]:
        """Find the server socket whose remote is the client's local address.

        Args:
            server_sockets: Server-side sockets to search; consumed lazily
                up to the first match.
            client_socket: Client-side socket to match against.

        Returns:
            The first server socket with ``remote == client_socket.local``,
            or None if there is no such socket.
        """
        for server_socket in server_sockets:
            if server_socket.remote == client_socket.local:
                return server_socket
        return None

    def find_channels_for_target(self, target: str) -> Iterator[Channel]:
        """Lazily iterate over top-level channels whose target is ``target``."""
        return (channel for channel in self.list_channels()
                if channel.data.target == target)

    def find_server_listening_on_port(self, port: int) -> Optional[Server]:
        """Return the first server with a TCP/IP listen socket on ``port``.

        Every listen socket reference is resolved with a ``GetSocket`` call.

        Returns:
            The matching server, or None if no server listens on ``port``.
        """
        for server in self.list_servers():
            listen_socket_ref: SocketRef
            for listen_socket_ref in server.listen_socket:
                listen_socket = self.get_socket(listen_socket_ref.socket_id)
                listen_address: Address = listen_socket.local
                if (self.is_sock_tcpip_address(listen_address) and
                        listen_address.tcpip_address.port == port):
                    return server
        return None

    def _list_paginated(self, rpc: str, make_request: Callable[[int], Any],
                        get_items: Callable[[Any], Iterable[Any]],
                        get_id: Callable[[Any], int]) -> Iterator[Any]:
        """Iterate over all pages of a paginated Channelz list RPC.

        Args:
            rpc: Name of the stub method to call.
            make_request: Builds the request message for a given start ID.
            get_items: Extracts the repeated result field from a response.
            get_id: Extracts the numeric ID from a single result item.

        Yields:
            Every item of every page; the next page is requested only after
            the current one has been fully consumed.
        """
        start = 0
        while True:
            response = self.call_unary_with_deadline(rpc=rpc,
                                                     req=make_request(start))
            for item in get_items(response):
                start = max(start, get_id(item))
                yield item
            if response.end:
                return
            # From proto: To request subsequent pages, the client generates this
            # value by adding 1 to the highest seen result ID.
            start += 1

    def list_channels(self) -> Iterator[Channel]:
        """
        Iterate over all pages of all root channels.

        Root channels are those which application has directly created.
        This does not include subchannels nor non-top level channels.
        """
        yield from self._list_paginated(
            rpc='GetTopChannels',
            make_request=lambda start: _GetTopChannelsRequest(
                start_channel_id=start),
            get_items=lambda response: response.channel,
            get_id=lambda channel: channel.ref.channel_id)

    def list_servers(self) -> Iterator[Server]:
        """Iterate over all pages of all servers that exist in the process."""
        yield from self._list_paginated(
            rpc='GetServers',
            make_request=lambda start: _GetServersRequest(
                start_server_id=start),
            get_items=lambda response: response.server,
            get_id=lambda server: server.ref.server_id)

    def list_server_sockets(self, server: Server) -> Iterator[Socket]:
        """List all server sockets that exist in server process.

        Iterating over the results will resolve additional pages automatically.
        """
        socket_refs = self._list_paginated(
            rpc='GetServerSockets',
            make_request=lambda start: _GetServerSocketsRequest(
                server_id=server.ref.server_id, start_socket_id=start),
            get_items=lambda response: response.socket_ref,
            get_id=lambda socket_ref: socket_ref.socket_id)
        socket_ref: SocketRef
        for socket_ref in socket_refs:
            # Yield actual socket
            yield self.get_socket(socket_ref.socket_id)

    def list_channel_sockets(self, channel: Channel) -> Iterator[Socket]:
        """List all sockets of all subchannels of a given channel."""
        for subchannel in self.list_channel_subchannels(channel):
            yield from self.list_subchannels_sockets(subchannel)

    def list_channel_subchannels(self,
                                 channel: Channel) -> Iterator[Subchannel]:
        """List all subchannels of a given channel."""
        for subchannel_ref in channel.subchannel_ref:
            yield self.get_subchannel(subchannel_ref.subchannel_id)

    def list_subchannels_sockets(self,
                                 subchannel: Subchannel) -> Iterator[Socket]:
        """List all sockets of a given subchannel."""
        for socket_ref in subchannel.socket_ref:
            yield self.get_socket(socket_ref.socket_id)

    def get_subchannel(self, subchannel_id: int) -> Subchannel:
        """Return a single Subchannel, otherwise raises RpcError."""
        response: _GetSubchannelResponse = self.call_unary_with_deadline(
            rpc='GetSubchannel',
            req=_GetSubchannelRequest(subchannel_id=subchannel_id))
        return response.subchannel

    def get_socket(self, socket_id: int) -> Socket:
        """Return a single Socket, otherwise raises RpcError."""
        response: _GetSocketResponse = self.call_unary_with_deadline(
            rpc='GetSocket', req=_GetSocketRequest(socket_id=socket_id))
        return response.socket
197