1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""A utility to run functions with timeouts and retries."""
6# pylint: disable=W0702
7
8import logging
9import threading
10import time
11import traceback
12
13from devil.utils import reraiser_thread
14from devil.utils import watchdog_timer
15
16
class TimeoutRetryThreadGroup(reraiser_thread.ReraiserThreadGroup):
  """A ReraiserThreadGroup paired with a watchdog timer for its timeout."""

  def __init__(self, timeout, threads=None):
    """Initializes the group and its watchdog timer.

    Args:
      timeout: value handed to watchdog_timer.WatchdogTimer.
      threads: optional threads forwarded to the parent constructor.
    """
    super(TimeoutRetryThreadGroup, self).__init__(threads)
    self._watcher = watchdog_timer.WatchdogTimer(timeout)

  def GetWatcher(self):
    """Returns the watchdog timer created for this thread group."""
    return self._watcher

  def GetElapsedTime(self):
    """Returns the elapsed time as reported by the watchdog timer."""
    return self._watcher.GetElapsed()

  def GetRemainingTime(self, required=0, msg=None):
    """Reports how much time is left before this group times out.

    The result is suitable as the |timeout| argument of async IO operations.

    Args:
      required: minimum amount of time the caller needs for its next step
        (e.g. a sleep or an IO operation).
      msg: error message to use when timing out; a generic one is used if
        omitted.

    Returns:
      The remaining time reported by the watchdog timer, or None when the
      timer reports no limit.

    Raises:
      reraiser_thread.TimeoutError: if less than |required| time remains.
    """
    time_left = self._watcher.GetRemaining()
    out_of_time = time_left is not None and time_left < required
    if not out_of_time:
      return time_left

    error = 'Timeout expired' if msg is None else msg
    if time_left > 0:
      # Some time is left, just not enough: say how much was needed.
      error += (', wait of %.1f secs required but only %.1f secs left'
                % (required, time_left))
    raise reraiser_thread.TimeoutError(error)
57
58
def CurrentTimeoutThreadGroup():
  """Finds the TimeoutRetryThreadGroup tracking the current thread.

  Starts at reraiser_thread.CurrentThreadGroup() and follows the
  |blocked_parent_thread_group| links until a TimeoutRetryThreadGroup is
  found or the chain ends.

  Returns:
    The first TimeoutRetryThreadGroup found along the chain, or None if
    there is none.
  """
  group = reraiser_thread.CurrentThreadGroup()
  while True:
    if not group:
      # Reached the end of the chain without finding a match.
      return None
    if isinstance(group, TimeoutRetryThreadGroup):
      return group
    group = group.blocked_parent_thread_group
71
72
def WaitFor(condition, wait_period=5, max_tries=None):
  """Wait for a condition to become true.

  Repeatedly call the function condition(), with no arguments, until it returns
  a true value.

  If called within a TimeoutRetryThreadGroup, it cooperates nicely with it.

  Args:
    condition: callable with the condition to check. Its __name__ is used in
      log and error messages; callables without one (e.g. functools.partial
      objects) are identified by their repr() instead.
    wait_period: number of seconds to wait before retrying to check the
      condition
    max_tries: maximum number of checks to make, the default tries forever
      or until the TimeoutRetryThreadGroup expires.

  Returns:
    The true value returned by the condition, or None if the condition was
    not met after max_tries.

  Raises:
    reraiser_thread.TimeoutError: if the current thread is a
      TimeoutRetryThreadGroup and there is not enough time left to wait for
      another check.
  """
  # Not every callable has a __name__ (functools.partial objects and callable
  # instances do not), so fall back to repr() rather than raising.
  condition_name = getattr(condition, '__name__', None)
  if condition_name is None:
    condition_name = repr(condition)
  timeout_thread_group = CurrentTimeoutThreadGroup()
  while max_tries is None or max_tries > 0:
    result = condition()
    if max_tries is not None:
      max_tries -= 1
    msg = ['condition', repr(condition_name), 'met' if result else 'not met']
    if timeout_thread_group:
      # pylint: disable=no-member
      msg.append('(%.1fs)' % timeout_thread_group.GetElapsedTime())
    logging.info(' '.join(msg))
    if result:
      return result
    if max_tries is not None and max_tries <= 0:
      # That was the last permitted check: do not sleep (or time out) waiting
      # for a retry that will never happen.
      break
    if timeout_thread_group:
      # Fail now if the enclosing timeout would expire during the sleep.
      # pylint: disable=no-member
      timeout_thread_group.GetRemainingTime(
          wait_period, msg='Timed out waiting for %r' % condition_name)
    time.sleep(wait_period)
  return None
115
116
117def _LogLastException(thread_name, attempt, max_attempts, log_func):
118  log_func('*' * 80)
119  log_func('Exception on thread %s (attempt %d of %d)', thread_name,
120                   attempt, max_attempts)
121  log_func('*' * 80)
122  fmt_exc = ''.join(traceback.format_exc())
123  for line in fmt_exc.splitlines():
124    log_func(line.rstrip())
125  log_func('*' * 80)
126
127
def AlwaysRetry(_exception):
  """Retry predicate that accepts any exception.

  Args:
    _exception: the exception that was raised; ignored.

  Returns:
    True, unconditionally.
  """
  return True
130
131
def Run(func, timeout, retries, args=None, kwargs=None, desc=None,
        error_log_func=logging.critical, retry_if_func=AlwaysRetry):
  """Runs the passed function in a separate thread with timeouts and retries.

  Each attempt runs |func| on a fresh ReraiserThread inside its own
  TimeoutRetryThreadGroup. At most |retries| + 1 attempts are made.

  Args:
    func: the function to be wrapped.
    timeout: the timeout in seconds for each try.
    retries: the number of retries.
    args: list of positional args to pass to |func|.
    kwargs: dictionary of keyword args to pass to |func|.
    desc: An optional description of |func| used in logging. If omitted,
      |func.__name__| will be used, or repr(func) when |func| has no
      __name__ (e.g. a functools.partial object).
    error_log_func: Logging function when logging errors.
    retry_if_func: Unary callable that takes an exception and returns
      whether |func| should be retried. Defaults to always retrying.

  Returns:
    The return value of func(*args, **kwargs).

  Raises:
    reraiser_thread.TimeoutError: if an attempt timed out and either no
      retries are left or |retry_if_func| declined to retry.
    Exception: any other exception from an attempt, under the same
      conditions.
  """
  if not args:
    args = []
  if not kwargs:
    kwargs = {}
  if not desc:
    # Resolve the description once, up front. Not every callable has a
    # __name__, and failing inside the progress log below would abort an
    # otherwise healthy run.
    desc = getattr(func, '__name__', None) or repr(func)

  num_try = 1
  while True:
    thread_name = 'TimeoutThread-%d-for-%s' % (num_try,
                                               threading.current_thread().name)
    child_thread = reraiser_thread.ReraiserThread(lambda: func(*args, **kwargs),
                                                  name=thread_name)
    try:
      thread_group = TimeoutRetryThreadGroup(timeout, threads=[child_thread])
      thread_group.StartAll(will_block=True)
      while True:
        # NOTE(review): presumably JoinAll returns after at most 60 seconds so
        # progress can be logged while the child is still running -- confirm
        # against reraiser_thread.
        thread_group.JoinAll(watcher=thread_group.GetWatcher(), timeout=60,
                             error_log_func=error_log_func)
        if thread_group.IsAlive():
          logging.info('Still working on %s', desc)
        else:
          return thread_group.GetAllReturnValues()[0]
    except reraiser_thread.TimeoutError as e:
      # Timeouts already get their stacks logged.
      if num_try > retries or not retry_if_func(e):
        raise
    # Catch Exception rather than using a bare except so that
    # KeyboardInterrupt is not caught.
    except Exception as e:  # pylint: disable=broad-except
      if num_try > retries or not retry_if_func(e):
        raise
      _LogLastException(thread_name, num_try, retries + 1, error_log_func)
    num_try += 1
182