1# Copyright 2020 gRPC authors.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
xDS Test Client.

TODO(sergiitk): separate XdsTestClient and KubernetesClientRunner into
individual modules.
20import datetime
21import functools
22import logging
23from typing import Iterator, Optional
25from framework.helpers import retryers
26from framework.infrastructure import k8s
27import framework.rpc
28from framework.rpc import grpc_channelz
29from framework.rpc import grpc_testing
30from framework.test_app import base_runner
logger = logging.getLogger(__name__)

# Type aliases, grouped by the module they come from.
# -- standard library
_timedelta = datetime.timedelta
# -- grpc.testing: LoadBalancerStatsService client
_LoadBalancerStatsServiceClient = grpc_testing.LoadBalancerStatsServiceClient
# -- channelz: service client and the message types handled in this module
_ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
_ChannelzChannel = grpc_channelz.Channel
_ChannelzChannelState = grpc_channelz.ChannelState
_ChannelzSubchannel = grpc_channelz.Subchannel
_ChannelzSocket = grpc_channelz.Socket
class XdsTestClient(framework.rpc.grpc.GrpcApp):
    """
    Represents RPC services implemented in Client component of the xds test app.
    https://github.com/grpc/grpc/blob/master/doc/xds-test-descriptions.md#client
    """

    def __init__(self,
                 *,
                 ip: str,
                 rpc_port: int,
                 server_target: str,
                 rpc_host: Optional[str] = None,
                 maintenance_port: Optional[int] = None):
        """Create a handle to a running xDS test client.

        Args:
            ip: IP address of the client. Used as the RPC host when rpc_host
                is not given, and in log messages.
            rpc_port: Port the LoadBalancerStatsService client connects to.
            server_target: Target the client sends RPCs to. Used to look up
                the client's channelz channels to the server.
            rpc_host: Host to connect to instead of ip (e.g. a local
                port-forwarding address). Defaults to ip.
            maintenance_port: Port the Channelz client connects to.
                Defaults to rpc_port.
        """
        super().__init__(rpc_host=(rpc_host or ip))
        self.ip = ip
        self.rpc_port = rpc_port
        self.server_target = server_target
        self.maintenance_port = maintenance_port or rpc_port

    # NOTE(review): functools.lru_cache on a method keys on `self` and keeps
    # every instance alive for the lifetime of the process (ruff B019).
    # Consider functools.cached_property once Python >= 3.8 is guaranteed.
    @property
    @functools.lru_cache(None)
    def load_balancer_stats(self) -> _LoadBalancerStatsServiceClient:
        """LoadBalancerStatsService client over a channel to rpc_port.

        Created on first access and reused on subsequent accesses.
        """
        return _LoadBalancerStatsServiceClient(self._make_channel(
            self.rpc_port))

    @property
    @functools.lru_cache(None)
    def channelz(self) -> _ChannelzServiceClient:
        """Channelz service client over a channel to maintenance_port.

        Created on first access and reused on subsequent accesses.
        """
        return _ChannelzServiceClient(self._make_channel(self.maintenance_port))

    def get_load_balancer_stats(
            self,
            *,
            num_rpcs: int,
            timeout_sec: Optional[int] = None,
    ) -> grpc_testing.LoadBalancerStatsResponse:
        """
        Shortcut to LoadBalancerStatsServiceClient.get_client_stats()
        """
        return self.load_balancer_stats.get_client_stats(
            num_rpcs=num_rpcs, timeout_sec=timeout_sec)

    def get_server_channels(self) -> Iterator[_ChannelzChannel]:
        """Return the client's channelz channels for server_target."""
        return self.channelz.find_channels_for_target(self.server_target)

    def wait_for_active_server_channel(self) -> _ChannelzChannel:
        """Wait for the channel to the server to transition to READY.

        Raises:
            GrpcApp.NotFound: If the channel never transitioned to READY.
        """
        return self.wait_for_server_channel_state(_ChannelzChannelState.READY)

    def get_active_server_channel(self) -> _ChannelzChannel:
        """Return a READY channel to the server.

        Raises:
            GrpcApp.NotFound: If there's no READY channel to the server.
        """
        return self.find_server_channel_with_state(_ChannelzChannelState.READY)

    def get_active_server_channel_socket(self) -> _ChannelzSocket:
        """Return the socket of the READY channel to the server.

        Uses the first subchannel of the READY channel and the first socket
        of that subchannel. Any additional subchannels or sockets are logged
        as unexpected.

        Raises:
            GrpcApp.NotFound: If there's no READY channel to the server, the
                channel has no subchannels, or the subchannel has no sockets.
        """
        channel = self.get_active_server_channel()

        # Get the first subchannel of the active channel to the server.
        # List first, so an empty channel yields NotFound rather than an
        # IndexError/ValueError.
        subchannels = list(self.channelz.list_channel_subchannels(channel))
        if not subchannels:
            raise self.NotFound(
                f'No subchannels for channel_id {channel.ref.channel_id}')
        subchannel, *extra_subchannels = subchannels
        logger.debug(
            'Retrieving client -> server socket, '
            'channel_id: %s, subchannel: %s', channel.ref.channel_id,
            subchannel.ref.name)
        if extra_subchannels:
            logger.warning('Unexpected subchannels: %r', extra_subchannels)

        # Get the first socket of the subchannel.
        sockets = list(self.channelz.list_subchannels_sockets(subchannel))
        if not sockets:
            raise self.NotFound(
                f'No sockets for subchannel {subchannel.ref.name}')
        socket, *extra_sockets = sockets
        if extra_sockets:
            logger.warning('Unexpected sockets: %r', extra_sockets)

        logger.debug('Found client -> server socket: %s', socket.ref.name)
        return socket

    def wait_for_server_channel_state(self,
                                      state: _ChannelzChannelState,
                                      *,
                                      timeout: Optional[_timedelta] = None
                                     ) -> _ChannelzChannel:
        """Poll until a channel to the server is in the given state.

        Retries find_server_channel_with_state() with an exponential retryer
        that waits between 10 and 25 seconds per attempt.

        Args:
            state: Channelz connectivity state to wait for.
            timeout: Overall time limit for retrying. Defaults to 3 minutes.

        Returns:
            The first matching channel found.
        """
        # Fine-tuned to wait for the channel to the server.
        retryer = retryers.exponential_retryer_with_timeout(
            wait_min=_timedelta(seconds=10),
            wait_max=_timedelta(seconds=25),
            timeout=_timedelta(minutes=3) if timeout is None else timeout)

        logger.info('Waiting for client %s to report a %s channel to %s',
                    self.ip, _ChannelzChannelState.Name(state),
                    self.server_target)
        channel = retryer(self.find_server_channel_with_state, state)
        logger.info('Client %s channel to %s transitioned to state %s:\n%s',
                    self.ip, self.server_target,
                    _ChannelzChannelState.Name(state), channel)
        return channel

    def find_server_channel_with_state(self,
                                       state: _ChannelzChannelState,
                                       *,
                                       check_subchannel: bool = True
                                      ) -> _ChannelzChannel:
        """Return the first channel to the server that is in the given state.

        Args:
            state: Channelz connectivity state to look for.
            check_subchannel: When true, a channel only matches if it also has
                at least one subchannel in the same state.

        Raises:
            GrpcApp.NotFound: If no channel to the server matches.
        """
        for channel in self.get_server_channels():
            channel_state: _ChannelzChannelState = channel.data.state.state
            logger.info('Server channel: %s, state: %s', channel.ref.name,
                        _ChannelzChannelState.Name(channel_state))
            # Enum values are plain ints: compare by value, not identity.
            if channel_state != state:
                continue
            if check_subchannel:
                # When requested, check if the channel has at least
                # one subchannel in the requested state.
                try:
                    subchannel = self.find_subchannel_with_state(
                        channel, state)
                except self.NotFound as e:
                    # Otherwise, keep searching.
                    logger.info(e.message)
                    continue
                logger.info('Found subchannel in state %s: %s',
                            _ChannelzChannelState.Name(state), subchannel)
            return channel

        raise self.NotFound(
            f'Client has no {_ChannelzChannelState.Name(state)} channel with '
            'the server')

    def find_subchannel_with_state(self, channel: _ChannelzChannel,
                                   state: _ChannelzChannelState
                                  ) -> _ChannelzSubchannel:
        """Return the first subchannel of `channel` in the given state.

        Raises:
            GrpcApp.NotFound: If no subchannel of the channel is in `state`.
        """
        for subchannel in self.channelz.list_channel_subchannels(channel):
            # Enum values are plain ints: compare by value, not identity.
            if subchannel.data.state.state == state:
                return subchannel

        raise self.NotFound(
            f'Not found a {_ChannelzChannelState.Name(state)} '
            f'subchannel for channel_id {channel.ref.channel_id}')
class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
    """Runs the xDS test client as a Kubernetes deployment.

    run() creates the service account and deployment and returns an
    XdsTestClient for the started pod. cleanup() removes what run() created.
    """

    def __init__(self,
                 k8s_namespace,
                 *,
                 deployment_name,
                 image_name,
                 gcp_service_account,
                 td_bootstrap_image,
                 service_account_name=None,
                 stats_port=8079,
                 network='default',
                 deployment_template='client.deployment.yaml',
                 service_account_template='service-account.yaml',
                 reuse_namespace=False,
                 namespace_template=None,
                 debug_use_port_forwarding=False):
        super().__init__(k8s_namespace, namespace_template, reuse_namespace)

        # Deployment settings.
        self.deployment_name = deployment_name
        self.deployment_template = deployment_template
        self.image_name = image_name
        self.network = network
        self.stats_port = stats_port

        # Service account settings. The Kubernetes service account name falls
        # back to the deployment name when not given explicitly.
        self.service_account_template = service_account_template
        self.service_account_name = (service_account_name
                                     if service_account_name else
                                     deployment_name)
        self.gcp_service_account = gcp_service_account

        # Image of the xDS bootstrap generator.
        self.td_bootstrap_image = td_bootstrap_image

        # Experimental: reach the client pod through local port forwarding.
        self.debug_use_port_forwarding = debug_use_port_forwarding

        # Resources created by run(); reset by cleanup().
        self.deployment: Optional[k8s.V1Deployment] = None
        self.service_account: Optional[k8s.V1ServiceAccount] = None
        self.port_forwarder = None

    def run(self,
            *,
            server_target,
            rpc='UnaryCall',
            qps=25,
            secure_mode=False,
            print_response=False) -> XdsTestClient:
        """Deploy the test client and return a handle to it.

        Creates the service account and a new deployment, waits for the
        deployment to have available replicas and for its first pod to start,
        then returns an XdsTestClient pointed at that pod.

        Args:
            server_target: Target the client sends RPCs to.
            rpc, qps, secure_mode, print_response: Passed to the deployment
                template along with server_target.
        """
        super().run()
        # TODO(sergiitk): make rpc UnaryCall enum or get it from proto

        self.service_account = self._create_service_account(
            self.service_account_template,
            service_account_name=self.service_account_name,
            namespace_name=self.k8s_namespace.name,
            gcp_service_account=self.gcp_service_account)

        # A fresh deployment is created on every run.
        self.deployment = self._create_deployment(
            self.deployment_template,
            deployment_name=self.deployment_name,
            image_name=self.image_name,
            namespace_name=self.k8s_namespace.name,
            service_account_name=self.service_account_name,
            td_bootstrap_image=self.td_bootstrap_image,
            network_name=self.network,
            stats_port=self.stats_port,
            server_target=server_target,
            rpc=rpc,
            qps=qps,
            secure_mode=secure_mode,
            print_response=print_response)
        self._wait_deployment_with_available_replicas(self.deployment_name)

        # Only a single client pod is used for now: take the first one.
        client_pod = self.k8s_namespace.list_deployment_pods(
            self.deployment)[0]
        self._wait_pod_started(client_pod.metadata.name)
        client_ip = client_pod.status.pod_ip

        if not self.debug_use_port_forwarding:
            # Talk to the pod IP directly.
            client_rpc_host = None
        else:
            # Experimental, for local debugging: tunnel the stats port and
            # connect through the local forwarding address instead.
            logger.info('LOCAL DEV MODE: Enabling port forwarding to %s:%s',
                        client_ip, self.stats_port)
            self.port_forwarder = self.k8s_namespace.port_forward_pod(
                client_pod, remote_port=self.stats_port)
            client_rpc_host = self.k8s_namespace.PORT_FORWARD_LOCAL_ADDRESS

        return XdsTestClient(ip=client_ip,
                             rpc_port=self.stats_port,
                             server_target=server_target,
                             rpc_host=client_rpc_host)

    def cleanup(self, *, force=False, force_namespace=False):
        """Tear down the resources created by run().

        Stops port forwarding if active. Deletes the deployment and service
        account when run() created them, or unconditionally when `force` is
        set. The base-class namespace cleanup is forced only when both
        `force` and `force_namespace` are set.
        """
        if self.port_forwarder:
            self.k8s_namespace.port_forward_stop(self.port_forwarder)
            self.port_forwarder = None

        if self.deployment or force:
            self._delete_deployment(self.deployment_name)
            self.deployment = None

        if self.service_account or force:
            self._delete_service_account(self.service_account_name)
            self.service_account = None

        super().cleanup(force=force_namespace and force)