1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import functools
15import json
16import logging
17import subprocess
18import time
19from typing import Optional, List, Tuple
20
21# TODO(sergiitk): replace with tenacity
22import retrying
23import kubernetes.config
24from kubernetes import client
25from kubernetes import utils
26
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type aliases: short module-level names for kubernetes client classes.
V1Deployment = client.V1Deployment
V1ServiceAccount = client.V1ServiceAccount
V1Pod = client.V1Pod
V1PodList = client.V1PodList
V1Service = client.V1Service
V1Namespace = client.V1Namespace
ApiException = client.ApiException
36
37
def simple_resource_get(func):
    """Decorate a resource getter so that a 404 response yields None.

    Args:
        func: Callable performing a Kubernetes API read.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result. A ``client.ApiException`` whose
        ``status`` is 404 is swallowed and ``None`` is returned instead;
        any other exception propagates unchanged.
    """

    # functools.wraps keeps the getter's __name__/__doc__ on the wrapper,
    # so decorated methods stay identifiable in logs and tracebacks.
    @functools.wraps(func)
    def wrap_not_found_return_none(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except client.ApiException as e:
            if e.status == 404:
                # Ignore 404
                return None
            raise

    return wrap_not_found_return_none
50
51
def label_dict_to_selector(labels: dict) -> str:
    """Render a label mapping as a comma-separated equality selector.

    Every item becomes a ``key==value`` term; terms follow the mapping's
    iteration order and are joined with commas.
    """
    terms = [f'{key}=={value}' for key, value in labels.items()]
    return ','.join(terms)
54
55
class KubernetesApiManager:
    """Holds the Kubernetes API handles bound to one kubeconfig context.

    Attributes set by the constructor:
        context: The context name this manager was created with.
        client: The ApiClient obtained for that context.
        apps: AppsV1Api built on top of ``client``.
        core: CoreV1Api built on top of ``client``.
    """

    def __init__(self, context):
        self.context = context
        api_client = self._cached_api_client_for_context(context)
        self.client = api_client
        self.apps = client.AppsV1Api(api_client)
        self.core = client.CoreV1Api(api_client)

    def close(self):
        """Close the underlying API client."""
        # NOTE(review): the client comes from the lru_cache below, so the
        # instance being closed may be shared with other managers created
        # for the same context - confirm this is intended.
        self.client.close()

    @classmethod
    @functools.lru_cache(maxsize=None)
    def _cached_api_client_for_context(cls, context: str) -> client.ApiClient:
        """Create, or return the memoized, API client for ``context``.

        Results are memoized without a size bound, keyed on the call
        arguments, so the client is only created (and the info line only
        logged) on the first call for a given context.
        """
        api_client = kubernetes.config.new_client_from_config(context=context)
        active_host = api_client.configuration.host
        logger.info('Using kubernetes context "%s", active host: %s', context,
                    active_host)
        return api_client
75
76
class PortForwardingError(Exception):
    """Signals that forwarding a port failed."""
79
80
class KubernetesNamespace:
    """Operations on the resources of a single Kubernetes namespace.

    Wraps a KubernetesApiManager and scopes reads, deletions, polling
    helpers and kubectl port forwarding to the namespace called ``name``.

    The ``wait_for_*`` helpers poll using the ``retrying`` library: the
    condition is re-checked every ``wait_sec`` seconds for at most
    ``timeout_sec`` seconds (both are converted to milliseconds).
    NOTE(review): on timeout ``retrying`` is expected to raise
    (``retrying.RetryError`` for result-based retries) - confirm against
    the library documentation.
    """
    # Service annotation key whose value is the JSON-encoded NEG status.
    NEG_STATUS_META = 'cloud.google.com/neg-status'
    # Local address kubectl port-forward binds to when none is given.
    PORT_FORWARD_LOCAL_ADDRESS: str = '127.0.0.1'
    # Default grace period, in seconds, passed to the delete calls.
    DELETE_GRACE_PERIOD_SEC: int = 5

    def __init__(self, api: KubernetesApiManager, name: str):
        """Bind the namespace ``name`` to the API handles in ``api``."""
        self.name = name
        self.api = api

    def apply_manifest(self, manifest):
        """Create the object(s) described by ``manifest`` in this namespace.

        Delegates to ``kubernetes.utils.create_from_dict`` and returns its
        result.
        """
        return utils.create_from_dict(self.api.client,
                                      manifest,
                                      namespace=self.name)

    @simple_resource_get
    def get_service(self, name) -> Optional[V1Service]:
        """Read the service ``name``; None when the API responds with 404."""
        return self.api.core.read_namespaced_service(name, self.name)

    @simple_resource_get
    def get_service_account(self, name) -> Optional[V1ServiceAccount]:
        """Read the service account ``name``; None on a 404 response."""
        return self.api.core.read_namespaced_service_account(name, self.name)

    def delete_service(self, name,
                       grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
        """Request deletion of the service ``name``.

        The request uses the 'Foreground' propagation policy and the given
        grace period. Use wait_for_service_deleted() to wait for completion.
        """
        self.api.core.delete_namespaced_service(
            name=name,
            namespace=self.name,
            body=client.V1DeleteOptions(
                propagation_policy='Foreground',
                grace_period_seconds=grace_period_seconds))

    def delete_service_account(self,
                               name,
                               grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
        """Request deletion of the service account ``name``.

        The request uses the 'Foreground' propagation policy and the given
        grace period.
        """
        self.api.core.delete_namespaced_service_account(
            name=name,
            namespace=self.name,
            body=client.V1DeleteOptions(
                propagation_policy='Foreground',
                grace_period_seconds=grace_period_seconds))

    @simple_resource_get
    def get(self) -> Optional[V1Namespace]:
        """Read this namespace object; None when the API responds with 404."""
        return self.api.core.read_namespace(self.name)

    def delete(self, grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
        """Request deletion of this namespace.

        The request uses the 'Foreground' propagation policy and the given
        grace period.
        """
        self.api.core.delete_namespace(
            name=self.name,
            body=client.V1DeleteOptions(
                propagation_policy='Foreground',
                grace_period_seconds=grace_period_seconds))

    def wait_for_service_deleted(self, name: str, timeout_sec=60, wait_sec=1):
        """Poll until get_service(name) returns None."""

        @retrying.retry(retry_on_result=lambda r: r is not None,
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_deleted_service_with_retry():
            service = self.get_service(name)
            if service is not None:
                logger.debug('Waiting for service %s to be deleted',
                             service.metadata.name)
            return service

        _wait_for_deleted_service_with_retry()

    def wait_for_service_account_deleted(self,
                                         name: str,
                                         timeout_sec=60,
                                         wait_sec=1):
        """Poll until get_service_account(name) returns None."""

        @retrying.retry(retry_on_result=lambda r: r is not None,
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_deleted_service_account_with_retry():
            service_account = self.get_service_account(name)
            if service_account is not None:
                logger.debug('Waiting for service account %s to be deleted',
                             service_account.metadata.name)
            return service_account

        _wait_for_deleted_service_account_with_retry()

    def wait_for_namespace_deleted(self, timeout_sec=240, wait_sec=5):
        """Poll until get() returns None for this namespace."""

        @retrying.retry(retry_on_result=lambda r: r is not None,
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_deleted_namespace_with_retry():
            namespace = self.get()
            if namespace is not None:
                logger.debug('Waiting for namespace %s to be deleted',
                             namespace.metadata.name)
            return namespace

        _wait_for_deleted_namespace_with_retry()

    def wait_for_service_neg(self, name: str, timeout_sec=60, wait_sec=1):
        """Poll until service ``name`` carries the NEG status annotation.

        A service that is not found, or that has no annotations at all, is
        treated as "not ready yet" and polled again rather than failing
        with an attribute or type error.
        """

        @retrying.retry(retry_on_result=lambda r: not r,
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_service_neg():
            service = self.get_service(name)
            if service is None:
                # get_service() maps a 404 response to None.
                logger.debug('Waiting for service %s NEG, service not found',
                             name)
                return False
            # The annotations mapping may be unset on the returned object.
            annotations = service.metadata.annotations or {}
            if self.NEG_STATUS_META not in annotations:
                logger.debug('Waiting for service %s NEG',
                             service.metadata.name)
                return False
            return True

        _wait_for_service_neg()

    def get_service_neg(self, service_name: str,
                        service_port: int) -> Tuple[str, List[str]]:
        """Return the NEG name and zones recorded for a service port.

        Parses the JSON stored under the NEG_STATUS_META annotation of the
        service and returns ``(neg_name, neg_zones)``, where ``neg_name`` is
        looked up under ``network_endpoint_groups`` by ``str(service_port)``
        and ``neg_zones`` is the ``zones`` entry.

        The service and its NEG annotation are expected to exist already;
        call wait_for_service_neg() first.
        """
        service = self.get_service(service_name)
        neg_info: dict = json.loads(
            service.metadata.annotations[self.NEG_STATUS_META])
        neg_name: str = neg_info['network_endpoint_groups'][str(service_port)]
        neg_zones: List[str] = neg_info['zones']
        return neg_name, neg_zones

    @simple_resource_get
    def get_deployment(self, name) -> Optional[V1Deployment]:
        """Read the deployment ``name``; None on a 404 response."""
        return self.api.apps.read_namespaced_deployment(name, self.name)

    def delete_deployment(self,
                          name,
                          grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
        """Request deletion of the deployment ``name``.

        The request uses the 'Foreground' propagation policy and the given
        grace period.
        """
        self.api.apps.delete_namespaced_deployment(
            name=name,
            namespace=self.name,
            body=client.V1DeleteOptions(
                propagation_policy='Foreground',
                grace_period_seconds=grace_period_seconds))

    def list_deployment_pods(self, deployment: V1Deployment) -> List[V1Pod]:
        """List the pods matching the deployment's ``match_labels``."""
        # V1LabelSelector.match_expressions not supported at the moment
        return self.list_pods_with_labels(deployment.spec.selector.match_labels)

    def wait_for_deployment_available_replicas(self,
                                               name,
                                               count=1,
                                               timeout_sec=60,
                                               wait_sec=3):
        """Poll until deployment ``name`` has >= ``count`` available replicas.

        A deployment that is not found is treated as "not available yet"
        and polled again, matching _replicas_available().
        """

        @retrying.retry(
            retry_on_result=lambda r: not self._replicas_available(r, count),
            stop_max_delay=timeout_sec * 1000,
            wait_fixed=wait_sec * 1000)
        def _wait_for_deployment_available_replicas():
            deployment = self.get_deployment(name)
            if deployment is None:
                # get_deployment() maps a 404 response to None; don't
                # dereference it just to log progress.
                logger.debug(
                    'Waiting for deployment %s to have %s available '
                    'replicas, deployment not found', name, count)
            else:
                logger.debug(
                    'Waiting for deployment %s to have %s available '
                    'replicas, current count %s', deployment.metadata.name,
                    count, deployment.status.available_replicas)
            return deployment

        _wait_for_deployment_available_replicas()

    def wait_for_deployment_deleted(self,
                                    deployment_name: str,
                                    timeout_sec=60,
                                    wait_sec=1):
        """Poll until get_deployment(deployment_name) returns None."""

        @retrying.retry(retry_on_result=lambda r: r is not None,
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_deleted_deployment_with_retry():
            deployment = self.get_deployment(deployment_name)
            if deployment is not None:
                logger.debug(
                    'Waiting for deployment %s to be deleted. '
                    'Non-terminated replicas: %s', deployment.metadata.name,
                    deployment.status.replicas)
            return deployment

        _wait_for_deleted_deployment_with_retry()

    def list_pods_with_labels(self, labels: dict) -> List[V1Pod]:
        """List the pods in this namespace whose labels equal ``labels``."""
        pod_list: V1PodList = self.api.core.list_namespaced_pod(
            self.name, label_selector=label_dict_to_selector(labels))
        return pod_list.items

    def get_pod(self, name) -> client.V1Pod:
        """Read the pod ``name``.

        Unlike the other getters this one is not wrapped to map 404 to
        None, so API errors propagate to the caller.
        """
        return self.api.core.read_namespaced_pod(name, self.name)

    def wait_for_pod_started(self, pod_name, timeout_sec=60, wait_sec=1):
        """Poll until the pod's phase is neither 'Pending' nor 'Unknown'."""

        @retrying.retry(retry_on_result=lambda r: not self._pod_started(r),
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_pod_started():
            pod = self.get_pod(pod_name)
            logger.debug('Waiting for pod %s to start, current phase: %s',
                         pod.metadata.name, pod.status.phase)
            return pod

        _wait_for_pod_started()

    def port_forward_pod(
            self,
            pod: V1Pod,
            remote_port: int,
            local_port: Optional[int] = None,
            local_address: Optional[str] = None,
    ) -> subprocess.Popen:
        """Experimental: start ``kubectl port-forward`` to a pod.

        Args:
            pod: Pod to forward to; only ``pod.metadata.name`` is used.
            remote_port: Port on the pod.
            local_port: Local port; defaults to ``remote_port``.
            local_address: Local bind address; defaults to
                PORT_FORWARD_LOCAL_ADDRESS.

        Returns:
            The running kubectl process. Stop it with port_forward_stop().

        Raises:
            PortForwardingError: kubectl exited before reporting success,
                or its first non-empty output line was not the expected
                "Forwarding from ..." message. The process is stopped
                before the error propagates.
        """
        local_address = local_address or self.PORT_FORWARD_LOCAL_ADDRESS
        local_port = local_port or remote_port
        cmd = [
            "kubectl", "--context", self.api.context, "--namespace", self.name,
            "port-forward", "--address", local_address,
            f"pod/{pod.metadata.name}", f"{local_port}:{remote_port}"
        ]
        # stderr is merged into stdout so error text is read from one pipe.
        pf = subprocess.Popen(cmd,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT,
                              universal_newlines=True)
        # Wait for stdout line indicating successful start.
        expected = (f"Forwarding from {local_address}:{local_port}"
                    f" -> {remote_port}")
        try:
            while True:
                time.sleep(0.05)
                output = pf.stdout.readline().strip()
                if not output:
                    # Empty line: only an error if kubectl has exited.
                    return_code = pf.poll()
                    if return_code is not None:
                        errors = pf.stdout.readlines()
                        raise PortForwardingError(
                            'Error forwarding port, kubectl return '
                            f'code {return_code}, output {errors}')
                elif output != expected:
                    raise PortForwardingError(
                        f'Error forwarding port, unexpected output {output}')
                else:
                    logger.info(output)
                    break
        except Exception:
            # Don't leave the kubectl process behind on failure.
            self.port_forward_stop(pf)
            raise

        # TODO(sergiitk): return new PortForwarder object
        return pf

    @staticmethod
    def port_forward_stop(pf):
        """Kill the port-forwarding process ``pf`` and drain its output.

        Waits up to 5 seconds for the killed process in ``communicate``.
        """
        logger.info('Shutting down port forwarding, pid %s', pf.pid)
        pf.kill()
        stdout, _stderr = pf.communicate(timeout=5)
        logger.info('Port forwarding stopped')
        logger.debug('Port forwarding remaining stdout: %s', stdout)

    @staticmethod
    def _pod_started(pod: V1Pod):
        """Return True unless the pod phase is 'Pending' or 'Unknown'."""
        return pod.status.phase not in ('Pending', 'Unknown')

    @staticmethod
    def _replicas_available(deployment, count):
        """Return True when ``deployment`` has at least ``count`` replicas.

        False when the deployment is None or reports no available replicas.
        """
        return (deployment is not None and
                deployment.status.available_replicas is not None and
                deployment.status.available_replicas >= count)
343