1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Thread and ThreadGroup that reraise exceptions on the main thread."""
6# pylint: disable=W0212
7
8import logging
9import sys
10import threading
11import time
12import traceback
13
14from devil.utils import watchdog_timer
15
16
class TimeoutError(Exception):
  """Timeout exception specific to this module.

  Raised, for example, when joining a ReraiserThreadGroup exceeds the
  watchdog's time budget.
  """
20
21
def LogThreadStack(thread, error_log_func=logging.critical):
  """Log the stack for the given thread.

  If the thread has no live frame (it already exited, or was never started),
  a short notice is logged in place of the stack instead of raising.

  Args:
    thread: a threading.Thread instance.
    error_log_func: Logging function when logging errors. It is called with a
      %-style format string optionally followed by its arguments.
  """
  # A thread can exit between the caller deciding to dump it and this lookup.
  # Using .get() avoids a KeyError that would otherwise mask whatever error the
  # caller is in the middle of reporting.
  stack = sys._current_frames().get(thread.ident)
  error_log_func('*' * 80)
  error_log_func('Stack dump for thread %r', thread.name)
  error_log_func('*' * 80)
  if stack is None:
    error_log_func('No stack available; thread is not running.')
  else:
    for filename, lineno, name, line in traceback.extract_stack(stack):
      error_log_func('File: "%s", line %d, in %s', filename, lineno, name)
      if line:
        error_log_func('  %s', line.strip())
  error_log_func('*' * 80)
38
39
40class ReraiserThread(threading.Thread):
41  """Thread class that can reraise exceptions."""
42
43  def __init__(self, func, args=None, kwargs=None, name=None):
44    """Initialize thread.
45
46    Args:
47      func: callable to call on a new thread.
48      args: list of positional arguments for callable, defaults to empty.
49      kwargs: dictionary of keyword arguments for callable, defaults to empty.
50      name: thread name, defaults to the function name.
51    """
52    if not name:
53      if hasattr(func, '__name__') and func.__name__ != '<lambda>':
54        name = func.__name__
55      else:
56        name = 'anonymous'
57    super(ReraiserThread, self).__init__(name=name)
58    if not args:
59      args = []
60    if not kwargs:
61      kwargs = {}
62    self.daemon = True
63    self._func = func
64    self._args = args
65    self._kwargs = kwargs
66    self._ret = None
67    self._exc_info = None
68    self._thread_group = None
69
70  def ReraiseIfException(self):
71    """Reraise exception if an exception was raised in the thread."""
72    if self._exc_info:
73      raise self._exc_info[0], self._exc_info[1], self._exc_info[2]
74
75  def GetReturnValue(self):
76    """Reraise exception if present, otherwise get the return value."""
77    self.ReraiseIfException()
78    return self._ret
79
80  # override
81  def run(self):
82    """Overrides Thread.run() to add support for reraising exceptions."""
83    try:
84      self._ret = self._func(*self._args, **self._kwargs)
85    except:  # pylint: disable=W0702
86      self._exc_info = sys.exc_info()
87
88
class ReraiserThreadGroup(object):
  """A group of ReraiserThread objects."""

  def __init__(self, threads=None):
    """Initialize thread group.

    Args:
      threads: an iterable of ReraiserThread objects; defaults to empty.
    """
    self._threads = []
    # Set when a thread from one group has called JoinAll on another. It is used
    # to detect when there is a TimeoutRetryThread active that links to the
    # current thread.
    self.blocked_parent_thread_group = None
    if threads:
      for thread in threads:
        self.Add(thread)

  def Add(self, thread):
    """Add a thread to the group.

    Args:
      thread: a ReraiserThread object that does not yet belong to a group.
    """
    assert thread._thread_group is None
    thread._thread_group = self
    self._threads.append(thread)

  def StartAll(self, will_block=False):
    """Start all threads.

    Args:
      will_block: Whether the calling thread will subsequently block on this
        thread group. Causes the active ReraiserThreadGroup (if there is one)
        to be marked as blocking on this thread group.
    """
    if will_block:
      # Multiple threads blocking on the same outer thread should not happen in
      # practice.
      assert not self.blocked_parent_thread_group
      self.blocked_parent_thread_group = CurrentThreadGroup()
    for thread in self._threads:
      thread.start()

  def _JoinAll(self, watcher=None, timeout=None):
    """Join all threads without stack dumps.

    Reraises exceptions raised by the child threads and supports breaking
    immediately on exceptions raised on the main thread.

    Args:
      watcher: Watchdog object providing the thread timeout. If none is
          provided, the thread will never be timed out.
      timeout: An optional number of seconds to wait before timing out the join
          operation. This will not time out the threads. A falsy value (None
          or 0) waits indefinitely. When it expires, this method returns
          without raising TimeoutError, and some threads may still be alive.

    Raises:
      TimeoutError: if |watcher| times out while threads are still alive.
      Any exception captured by a child thread; the first one in group order
      is reraised.
    """
    if watcher is None:
      watcher = watchdog_timer.WatchdogTimer(None)
    alive_threads = self._threads[:]
    end_time = (time.time() + timeout) if timeout else None
    try:
      while alive_threads and (end_time is None or end_time > time.time()):
        for thread in alive_threads[:]:
          if watcher.IsTimedOut():
            raise TimeoutError('Timed out waiting for %d of %d threads.' %
                               (len(alive_threads), len(self._threads)))
          # Allow the main thread to periodically check for interrupts.
          thread.join(0.1)
          # is_alive() exists since Python 2.6; the isAlive() alias was removed
          # in Python 3.9.
          if not thread.is_alive():
            alive_threads.remove(thread)
      # All threads are allowed to complete before reraising exceptions.
      for thread in self._threads:
        thread.ReraiseIfException()
    finally:
      # The calling thread is no longer blocked on this group.
      self.blocked_parent_thread_group = None

  def IsAlive(self):
    """Check whether any of the threads are still alive.

    Returns:
      Whether any of the threads are still alive.
    """
    return any(t.is_alive() for t in self._threads)

  def JoinAll(self, watcher=None, timeout=None,
              error_log_func=logging.critical):
    """Join all threads.

    Reraises exceptions raised by the child threads and supports breaking
    immediately on exceptions raised on the main thread. Unfinished threads'
    stacks will be logged on watchdog timeout.

    Args:
      watcher: Watchdog object providing the thread timeout. If none is
          provided, the thread will never be timed out.
      timeout: An optional number of seconds to wait before timing out the join
          operation. This will not time out the threads.
      error_log_func: Logging function when logging errors.

    Raises:
      TimeoutError: if |watcher| times out; stacks of the threads that are
          still alive are logged first.
      Any exception captured by a child thread (see _JoinAll).
    """
    try:
      self._JoinAll(watcher, timeout)
    except TimeoutError:
      error_log_func('Timed out. Dumping threads.')
      for thread in (t for t in self._threads if t.is_alive()):
        LogThreadStack(thread, error_log_func=error_log_func)
      raise

  def GetAllReturnValues(self, watcher=None):
    """Get all return values, joining all threads if necessary.

    Args:
      watcher: same as in |JoinAll|. Only used if threads are alive.

    Returns:
      A list of the threads' return values, in the order the threads were
      added to the group.

    Raises:
      TimeoutError: if joining is needed and |watcher| times out.
      The first exception captured by a child thread, in group order.
    """
    if any(t.is_alive() for t in self._threads):
      self.JoinAll(watcher)
    return [t.GetReturnValue() for t in self._threads]
205
206
def CurrentThreadGroup():
  """Look up the ReraiserThreadGroup owning the calling thread.

  Returns:
    The group of the running thread if that thread is a ReraiserThread (this
    may itself be None if the thread was never added to a group); None for any
    other kind of thread.
  """
  running = threading.current_thread()
  if not isinstance(running, ReraiserThread):
    return None
  return running._thread_group  # pylint: disable=no-member
217
218
def RunAsync(funcs, watcher=None):
  """Run each function on its own thread and collect their results.

  The calling thread blocks until every function has finished (or |watcher|
  times out). An exception raised by any function is reraised here.

  Args:
    funcs: List of zero-argument functions, each run on its own thread.
    watcher: Watchdog object providing timeout, by default waits forever.

  Returns:
    A list of return values in the order of the given functions.
  """
  group = ReraiserThreadGroup([ReraiserThread(fn) for fn in funcs])
  group.StartAll(will_block=True)
  return group.GetAllReturnValues(watcher=watcher)
232