1#!/usr/bin/env python
2#
3# Copyright 2016 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Common Utilities.
18
19The following code is copied from chromite with modifications.
20  - class TempDir: chromite/lib/osutils.py
21
22"""
23
import base64
import binascii
import errno
import functools
import getpass
import logging
import os
import shutil
import struct
import subprocess
import sys
import tarfile
import tempfile
import time
import uuid

from acloud.public import errors
40
41
# Module-level logger shared by every helper in this file.
logger = logging.getLogger(__name__)


# Base ssh-keygen invocation (RSA, 4096-bit). CreateSshKeyPairIfNotExist
# appends the "-C <user>" comment and the "-f <private key path>" arguments.
SSH_KEYGEN_CMD = ["ssh-keygen", "-t", "rsa", "-b", "4096"]
46
47
class TempDir(object):
    """Object that creates a temporary directory.

    This object can either be used as a context manager or just as a simple
    object. The temporary directory is stored as self.tempdir in the object, and
    is returned as a string by a 'with' statement.
    """

    def __init__(self, prefix='tmp', base_dir=None, delete=True):
        """Constructor. Creates the temporary directory.

        Args:
            prefix: See tempfile.mkdtemp documentation.
            base_dir: The directory to place the temporary directory.
                      If None, will choose from system default tmp dir.
            delete: Whether the temporary dir should be deleted as part of cleanup.
        """
        self.delete = delete
        self.tempdir = tempfile.mkdtemp(prefix=prefix, dir=base_dir)
        # Restrict access to the owner only.
        os.chmod(self.tempdir, 0o700)

    def Cleanup(self):
        """Clean up the temporary directory.

        When |delete| is True, removes the directory tree and resets
        self.tempdir to None (even if removal fails). A directory that no
        longer exists is not treated as an error. Once self.tempdir is None,
        further calls do nothing.

        Raises:
            EnvironmentError: If removing the directory fails for any reason
                other than the path not existing.
        """
        # __init__ may have raised before self.tempdir was assigned (e.g. in
        # mkdtemp), and __del__ still calls this; hence getattr.
        tempdir = getattr(self, 'tempdir', None)
        if tempdir is not None and self.delete:
            try:
                shutil.rmtree(tempdir)
            except EnvironmentError as e:
                # Ignore error if directory or file does not exist.
                if e.errno != errno.ENOENT:
                    raise
            finally:
                self.tempdir = None

    def __enter__(self):
        """Return the temporary directory."""
        return self.tempdir

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Exit the context manager, cleaning up the temporary directory.

        If cleanup fails while an exception from the 'with' body is already
        propagating, the cleanup error and the directory contents are logged
        and the original exception is allowed to resume. Otherwise the
        cleanup error is raised.
        """
        # Cleanup() resets self.tempdir to None even when it fails, so keep a
        # reference here for the diagnostic listing below.
        tempdir = self.tempdir
        try:
            self.Cleanup()
        except Exception:  # pylint: disable=W0703
            if not exc_type:
                # If there was not an exception from the context, raise ours.
                raise
            # An exception from inside the context is already in progress:
            # log our cleanup exception, then allow the original to resume.
            logger.error('While exiting %s:', self, exc_info=True)
            if tempdir:
                # Log all files in tempdir at the time of the failure.
                try:
                    logger.error('Directory contents were:')
                    for name in os.listdir(tempdir):
                        logger.error('  %s', name)
                except OSError:
                    logger.error('  Directory did not exist.')

    def __del__(self):
        """Delete the object, cleaning up the directory if still present."""
        self.Cleanup()
113
def RetryOnException(retry_checker, max_retries, sleep_multiplier=0,
                     retry_backoff_factor=1):
  """Decorator which retries the function call if |retry_checker| returns true.

  The decorated function keeps its original name and docstring.

  Args:
    retry_checker: A callback function which should take an exception instance
                   and return True if functor(*args, **kwargs) should be retried
                   when such exception is raised, and return False if it should
                   not be retried.
    max_retries: Maximum number of retries allowed.
    sleep_multiplier: Will sleep sleep_multiplier * attempt_count seconds if
                      retry_backoff_factor is 1.  Will sleep
                      sleep_multiplier * (
                          retry_backoff_factor ** (attempt_count -  1))
                      if retry_backoff_factor != 1.
    retry_backoff_factor: See explanation of sleep_multiplier.

  Returns:
    The function wrapper.
  """
  def _Wrapper(func):
    # functools.wraps copies __name__, __doc__, etc. so the decorated function
    # stays introspectable (logging, help(), other decorators).
    @functools.wraps(func)
    def _FunctionWrapper(*args, **kwargs):
      return Retry(retry_checker, max_retries, func, sleep_multiplier,
                   retry_backoff_factor,
                   *args, **kwargs)
    return _FunctionWrapper
  return _Wrapper
141
142
def Retry(retry_checker, max_retries, functor, sleep_multiplier=0,
          retry_backoff_factor=1, *args, **kwargs):
  """Conditionally retry a function.

  The functor is called at least once and at most max_retries + 1 times.

  Args:
    retry_checker: A callback function which should take an exception instance
                   and return True if functor(*args, **kwargs) should be retried
                   when such exception is raised, and return False if it should
                   not be retried.
    max_retries: Maximum number of retries allowed. None or a negative value
                 is treated as 0 (call once, never retry).
    functor: The function to call, will call functor(*args, **kwargs).
    sleep_multiplier: Will sleep sleep_multiplier * attempt_count seconds if
                      retry_backoff_factor is 1.  Will sleep
                      sleep_multiplier * (
                          retry_backoff_factor ** (attempt_count -  1))
                      if retry_backoff_factor != 1. None is treated as 0.
    retry_backoff_factor: See explanation of sleep_multiplier. None is treated
                          as 1.
    *args: Arguments to pass to the functor.
    **kwargs: Key-val based arguments to pass to the functor.

  Returns:
    The return value of the functor.

  Raises:
    Exception: The exception that functor(*args, **kwargs) throws, once it is
               not retriable or the retries are exhausted.
  """
  # Callers may forward None for unset options (BatchHttpRequestExecutor does
  # so by default). Fall back to the defaults instead of failing with a
  # TypeError in the comparison / sleep arithmetic below, which would mask the
  # functor's real exception.
  if max_retries is None or max_retries < 0:
    max_retries = 0
  if sleep_multiplier is None:
    sleep_multiplier = 0
  if retry_backoff_factor is None:
    retry_backoff_factor = 1

  attempt_count = 0
  while True:
    attempt_count += 1
    try:
      return functor(*args, **kwargs)
    except Exception as e:  # pylint: disable=W0703
      # retry_checker is consulted before the attempt limit so it sees every
      # failure (a checker may have side effects such as logging).
      if not (retry_checker(e) and attempt_count <= max_retries):
        raise
      if retry_backoff_factor != 1:
        sleep = sleep_multiplier * (
            retry_backoff_factor ** (attempt_count - 1))
      else:
        sleep = sleep_multiplier * attempt_count
      time.sleep(sleep)
185
186
def RetryExceptionType(exception_types, max_retries, functor, *args, **kwargs):
  """Call functor through Retry, retrying only on the given exception types.

  Args:
    exception_types: A tuple of exception types, e.g. (ValueError, KeyError).
    max_retries: Max number of retries allowed.
    functor: The function to call. It is retried when it raises an exception
             that is an instance of one of exception_types.
    *args: Forwarded to Retry after functor, i.e. sleep_multiplier,
           retry_backoff_factor, then functor's own positional arguments.
    **kwargs: Key-val based arguments forwarded to Retry.

  Returns:
    The value returned by calling functor.
  """
  def _MatchesExpectedType(exc):
    return isinstance(exc, exception_types)

  return Retry(_MatchesExpectedType, max_retries, functor, *args, **kwargs)
203
204
def PollAndWait(func, expected_return, timeout_exception, timeout_secs,
                sleep_interval_secs, *args, **kwargs):
    """Call a function until the function returns expected value or times out.

    Args:
        func: Function to call.
        expected_return: The expected return value.
        timeout_exception: Exception to raise when it hits timeout.
        timeout_secs: Timeout seconds.
                      If 0 or less than zero, the function will run once and
                      we will not wait on it.
        sleep_interval_secs: Time to sleep between two attempts. No sleep
                             happens if it is 0 or less than zero.
        *args: list of args to pass to func.
        **kwargs: dictionary of keyword based args to pass to func.

    Raises:
        timeout_exception: if func has not returned expected_return by the
                           time timeout_secs have elapsed.
    """
    # TODO(fdeng): Currently this method does not kill
    # |func|, if |func| takes longer than |timeout_secs|.
    # We can use a more robust version from chromite.
    start = time.time()
    while True:
        if func(*args, **kwargs) == expected_return:
            return
        # ">=" (not ">") guarantees a non-positive timeout gives up after a
        # single call even when the clock has not advanced since |start|.
        if time.time() - start >= timeout_secs:
            raise timeout_exception
        if sleep_interval_secs > 0:
            time.sleep(sleep_interval_secs)
236
237
def GenerateUniqueName(prefix=None, suffix=None):
    """Generate a random unique name based on uuid4.

    The name is the 32-character hex form of a random uuid4, optionally
    preceded by "<prefix>-" and followed by "-<suffix>". Empty or None
    prefix/suffix values are left out together with their separator.

    Args:
        prefix: String, desired prefix to prepend to the generated name.
        suffix: String, desired suffix to append to the generated name.

    Returns:
        String, a random name.
    """
    pieces = [uuid.uuid4().hex]
    if prefix:
        pieces.insert(0, prefix)
    if suffix:
        pieces.append(suffix)
    return "-".join(pieces)
254
255
def MakeTarFile(src_dict, dest):
    """Archive files in tar.gz format to a file named as |dest|.

    Args:
        src_dict: A dictionary that maps a path to be archived
                  to the corresponding name that appears in the archive.
        dest: String, path to output file, e.g. /tmp/myfile.tar.gz
    """
    # list() yields a plain list in the log on both Python 2 and 3 (a Python 3
    # keys view would render as "dict_keys([...])").
    logger.info("Compressing %s into %s.", list(src_dict), dest)
    with tarfile.open(dest, "w:gz") as tar:
        # items() works on Python 2 and 3; iteritems() is Python 2 only.
        for src, arcname in src_dict.items():
            tar.add(src, arcname=arcname)
268
269
def CreateSshKeyPairIfNotExist(private_key_path, public_key_path):
    """Create the ssh key pair if they don't exist.

    Check if the public and private key pairs exist at the given places.
    If neither exists, create them with ssh-keygen. If either one already
    exists, nothing is created.

    Args:
        private_key_path: Path to the private key file; "~" is expanded.
                          e.g. ~/.ssh/acloud_rsa
        public_key_path: Path to the public key file; "~" is expanded.
                         e.g. ~/.ssh/acloud_rsa.pub

    Raises:
        errors.DriverError: If failed to create the key pair, or failed to
                            move the generated public key to public_key_path.
    """
    public_key_path = os.path.expanduser(public_key_path)
    private_key_path = os.path.expanduser(private_key_path)
    create_key = (
            not os.path.exists(public_key_path) and
            not os.path.exists(private_key_path))
    if not create_key:
        logger.debug("The ssh private key (%s) or public key (%s) already exist, "
                     "will not automatically create the key pairs.",
                     private_key_path, public_key_path)
        return
    cmd = SSH_KEYGEN_CMD + ["-C", getpass.getuser(), "-f", private_key_path]
    logger.info("The ssh private key (%s) and public key (%s) do not exist, "
                "automatically creating key pair, calling: %s",
                private_key_path, public_key_path, " ".join(cmd))
    try:
        # NOTE(review): ssh-keygen's stdout goes to our stderr and its stderr
        # to our stdout; presumably deliberate to keep its chatter off this
        # tool's stdout -- confirm before changing.
        subprocess.check_call(cmd, stdout=sys.stderr, stderr=sys.stdout)
    except subprocess.CalledProcessError as e:
        raise errors.DriverError(
                "Failed to create ssh key pair: %s" % str(e))
    except OSError as e:
        # Raised when the ssh-keygen executable cannot be launched.
        raise errors.DriverError(
                "Failed to create ssh key pair, please make sure "
                "'ssh-keygen' is installed: %s" % str(e))

    # By default ssh-keygen will create a public key file
    # by append .pub to the private key file name. Rename it
    # to what's requested by public_key_path.
    default_pub_key_path = "%s.pub" % private_key_path
    try:
        if default_pub_key_path != public_key_path:
            os.rename(default_pub_key_path, public_key_path)
    except OSError as e:
        raise errors.DriverError(
                "Failed to rename %s to %s: %s" %
                (default_pub_key_path, public_key_path, str(e)))

    logger.info("Created ssh private key (%s) and public key (%s)",
                private_key_path, public_key_path)
322
323
def VerifyRsaPubKey(rsa):
    """Verify the format of rsa public key.

    Args:
        rsa: content of rsa public key. It should follow the format of
             ssh-rsa AAAAB3NzaC1yc2EA.... test@test.com

    Raises:
        DriverError if the format is not correct.
    """
    if not rsa or not all(ord(c) < 128 for c in rsa):
        raise errors.DriverError(
            "rsa key is empty or contains non-ascii character: %s" % rsa)

    elements = rsa.split()
    if len(elements) != 3:
        raise errors.DriverError("rsa key is invalid, wrong format: %s" % rsa)

    key_type, data, _ = elements
    try:
        # binascii.a2b_base64 is what base64.decodestring wrapped; it accepts
        # ASCII text on both Python 2 and 3 and returns bytes, whereas
        # base64.decodestring was removed in Python 3.9.
        binary_data = binascii.a2b_base64(data)
        # number of bytes of int type
        int_length = 4
        # binary_data is like "7ssh-key..." in a binary format.
        # The first 4 bytes should represent 7, which should be
        # the length of the following string "ssh-key".
        # And the next 7 bytes should be string "ssh-key".
        # We will verify that the rsa conforms to this format.
        # ">I" in the following line means "big-endian unsigned integer".
        type_length = struct.unpack(">I", binary_data[:int_length])[0]
        # Compare bytes with bytes: on Python 3 a bytes slice never equals a
        # str. key_type is pure ASCII (checked above), so encoding is safe.
        encoded_key_type = key_type.encode("ascii")
        if binary_data[int_length:int_length + type_length] != encoded_key_type:
            raise errors.DriverError("rsa key is invalid: %s" % rsa)
    except (struct.error, binascii.Error) as e:
        raise errors.DriverError("rsa key is invalid: %s, error: %s" %
                                 (rsa, str(e)))
359
360
class BatchHttpRequestExecutor(object):
    """A helper class that executes requests in batch with retry.

    This executor executes http requests in a batch and retry
    those that have failed. It iteratively updates the dictionary
    self._final_results with latest results, which can be retrieved
    via GetResults.
    """

    def __init__(self,
                 execute_once_functor,
                 requests,
                 retry_http_codes=None,
                 max_retry=None,
                 sleep=None,
                 backoff_factor=None,
                 other_retriable_errors=None):
        """Initializes the executor.

        Args:
            execute_once_functor: A function that execute requests in batch once.
                                  It should return a dictionary like
                                  {request_id: (response, exception)}
            requests: A dictionary where key is request id picked by caller,
                      and value is a apiclient.http.HttpRequest.
            retry_http_codes: A list of http codes to retry. None means no
                              http code is retried.
            max_retry: See utils.Retry. None means 0 (execute once, no retry).
            sleep: See utils.Retry. None means 0 (no sleep between retries).
            backoff_factor: See utils.Retry. None means 1.
            other_retriable_errors: A tuple of error types that should be retried
                                    other than errors.HttpError. None means
                                    no other error type is retried.
        """
        self._execute_once_functor = execute_once_functor
        self._requests = requests
        # A dictionary that maps request id to pending request.
        self._pending_requests = {}
        # A dictionary that maps request id to a tuple (response, exception).
        self._final_results = {}
        # Normalize the None defaults: isinstance() rejects None as its type
        # argument, "x in None" raises TypeError, and Retry's attempt/sleep
        # arithmetic does not accept None either.
        self._retry_http_codes = (
            retry_http_codes if retry_http_codes is not None else ())
        self._max_retry = max_retry if max_retry is not None else 0
        self._sleep = sleep if sleep is not None else 0
        self._backoff_factor = (
            backoff_factor if backoff_factor is not None else 1)
        self._other_retriable_errors = (
            other_retriable_errors if other_retriable_errors is not None
            else ())

    def _ShoudRetry(self, exception):
        """Check if an exception is retriable.

        Args:
            exception: The exception recorded for a single request.

        Returns:
            True if it is one of other_retriable_errors, or an
            errors.HttpError whose code is in retry_http_codes; False otherwise.
        """
        if isinstance(exception, self._other_retriable_errors):
            return True

        if (isinstance(exception, errors.HttpError) and
                exception.code in self._retry_http_codes):
            return True
        return False

    def _ExecuteOnce(self):
        """Executes pending requests and update it with failed, retriable ones.

        Raises:
            HasRetriableRequestsError: if some requests fail and are retriable.
        """
        results = self._execute_once_functor(self._pending_requests)
        # Update final_results with latest results.
        self._final_results.update(results)
        # Clear pending_requests
        self._pending_requests.clear()
        # items() works on Python 2 and 3; iteritems() is Python 2 only.
        for request_id, result in results.items():
            exception = result[1]
            if exception is not None and self._ShoudRetry(exception):
                # If this is a retriable exception, put it in pending_requests
                self._pending_requests[request_id] = self._requests[request_id]
        if self._pending_requests:
            # If there is still retriable requests pending, raise an error
            # so that Retry will retry this function with pending_requests.
            raise errors.HasRetriableRequestsError(
                "Retriable errors: %s" % [str(results[rid][1])
                                          for rid in self._pending_requests])

    def Execute(self):
        """Executes the requests and retry if necessary.

        Will populate self._final_results. Requests that still fail after the
        last retry keep their latest (response, exception) in the results;
        no exception is raised for them.
        """
        def _ShouldRetryHandler(exc):
            """Check if |exc| is a retriable exception.

            Args:
                exc: An exception.

            Returns:
                True if exception is of type HasRetriableRequestsError; False otherwise.
            """
            should_retry = isinstance(exc, errors.HasRetriableRequestsError)
            if should_retry:
                logger.info("Will retry failed requests.", exc_info=True)
                logger.info("%s", exc)
            return should_retry

        try:
            self._pending_requests = self._requests.copy()
            Retry(
                _ShouldRetryHandler, max_retries=self._max_retry,
                functor=self._ExecuteOnce,
                sleep_multiplier=self._sleep,
                retry_backoff_factor=self._backoff_factor)
        except errors.HasRetriableRequestsError:
            logger.debug("Some requests did not succeed after retry.")

    def GetResults(self):
        """Returns final results.

        Returns:
            results, a dictionary in the following format
            {request_id: (response, exception)}
            request_ids are those from requests; response
            is the http response for the request or None on error;
            exception is an instance of DriverError or None if no error.
        """
        return self._final_results
479