1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Coordinator to help multiple threads stop when requested."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import contextlib
21import sys
22import threading
23import time
24
25import six
26
27from tensorflow.python.framework import errors
28from tensorflow.python.platform import tf_logging as logging
29from tensorflow.python.util import compat
30from tensorflow.python.util.tf_export import tf_export
31
32
@tf_export("train.Coordinator")
class Coordinator(object):
  """A coordinator for threads.

  This class implements a simple mechanism to coordinate the termination of a
  set of threads.

  #### Usage:

  ```python
  # Create a coordinator.
  coord = Coordinator()
  # Start a number of threads, passing the coordinator to each of them.
  ...start thread 1...(coord, ...)
  ...start thread N...(coord, ...)
  # Wait for all the threads to terminate.
  coord.join(threads)
  ```

  Any of the threads can call `coord.request_stop()` to ask for all the threads
  to stop.  To cooperate with the requests, each thread must check for
  `coord.should_stop()` on a regular basis.  `coord.should_stop()` returns
  `True` as soon as `coord.request_stop()` has been called.

  A typical thread running with a coordinator will do something like:

  ```python
  while not coord.should_stop():
    ...do some work...
  ```

  #### Exception handling:

  A thread can report an exception to the coordinator as part of the
  `request_stop()` call.  The exception will be re-raised from the
  `coord.join()` call.

  Thread code:

  ```python
  try:
    while not coord.should_stop():
      ...do some work...
  except Exception as e:
    coord.request_stop(e)
  ```

  Main code:

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate.
    coord.join(threads)
  except Exception as e:
    ...exception that was passed to coord.request_stop()
  ```

  To simplify the thread implementation, the Coordinator provides a
  context manager `stop_on_exception()` that automatically requests a stop if
  an exception is raised.  Using the context manager the thread code above
  can be written as:

  ```python
  with coord.stop_on_exception():
    while not coord.should_stop():
      ...do some work...
  ```

  #### Grace period for stopping:

  After a thread has called `coord.request_stop()` the other threads have a
  fixed time to stop, this is called the 'stop grace period' and defaults to 2
  minutes.  If any of the threads is still alive after the grace period expires
  `coord.join()` raises a RuntimeError reporting the laggards.

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate, give them 10s grace period
    coord.join(threads, stop_grace_period_secs=10)
  except RuntimeError:
    ...one of the threads took more than 10s to stop after request_stop()
    ...was called.
  except Exception:
    ...exception that was passed to coord.request_stop()
  ```
  """

  def __init__(self, clean_stop_exception_types=None):
    """Create a new Coordinator.

    Args:
      clean_stop_exception_types: Optional tuple of Exception types that should
        cause a clean stop of the coordinator. If an exception of one of these
        types is reported to `request_stop(ex)` the coordinator will behave as
        if `request_stop(None)` was called.  Defaults to
        `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
        the end of input. When feeding training data from a Python iterator it
        is common to add `StopIteration` to this list.
    """
    if clean_stop_exception_types is None:
      clean_stop_exception_types = (errors.OutOfRangeError,)
    # Stored as a tuple so it can be passed directly to isinstance().
    self._clean_stop_exception_types = tuple(clean_stop_exception_types)
    # Protects all attributes.
    self._lock = threading.Lock()
    # Event set when threads must stop.
    self._stop_event = threading.Event()
    # Python exc_info to report.
    # If not None, it should hold the returned value of sys.exc_info(), which is
    # a tuple containing exception (type, value, traceback).
    self._exc_info_to_raise = None
    # True if we have called join() already.
    self._joined = False
    # Set of threads registered for joining when join() is called.  These
    # threads will be joined in addition to the threads passed to the join()
    # call.  It's ok if threads are both registered and passed to the join()
    # call.
    self._registered_threads = set()

  def _filter_exception(self, ex):
    """Check if the exception indicated in 'ex' should be ignored.

    This method examines `ex` to check if it is an exception that should be
    reported to the users.  If yes, it returns `ex` as is, otherwise it returns
    None.

    The code returns None for exception types listed in
    `_clean_stop_exception_types`.

    Args:
      ex: None, an `Exception`, or a Python `exc_info` tuple as returned by
        `sys.exc_info()`.

    Returns:
      ex or None.
    """
    # For an exc_info tuple the exception instance is the second element.
    if isinstance(ex, tuple):
      ex2 = ex[1]
    else:
      ex2 = ex
    if isinstance(ex2, self._clean_stop_exception_types):
      # Ignore the exception.
      ex = None
    return ex

  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Note: If an exception is being passed in, it must be in the context of
    handling the exception (i.e. `try: ... except Exception as ex: ...`) and not
    a newly created one.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.

    Raises:
      Exception: If `join()` has already been called (and `clear_stop()` has
        not been called since), a non-filtered `ex` is re-raised immediately,
        because `join()` will not get another chance to report it.
    """
    with self._lock:
      ex = self._filter_exception(ex)
      # If we have already joined the coordinator the exception will not have a
      # chance to be reported, so just raise it normally.  This can happen if
      # you continue to use a session after having stopped and joined the
      # coordinator threads.
      if self._joined:
        if isinstance(ex, tuple):
          six.reraise(*ex)
        elif ex is not None:
          exc_info = sys.exc_info()
          if exc_info[1] is ex:
            # Called from the handler of `ex`: keep its original traceback.
            six.reraise(*exc_info)
          # Not inside the handler of `ex` (possibly not inside any handler,
          # where sys.exc_info() is (None, None, None) and re-raising it would
          # fail with an unrelated TypeError): raise `ex` itself.
          raise ex
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str_any(ex[1]),
                         exc_info=ex)
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex),
                         compat.as_str_any(ex))
            # NOTE(touts): This is bogus if request_stop() is not called
            # from the exception handler that raised ex.
            self._exc_info_to_raise = sys.exc_info()
          # self._exc_info_to_raise should contain a tuple containing exception
          # (type, value, traceback)
          if (len(self._exc_info_to_raise) != 3 or
              not self._exc_info_to_raise[0] or
              not self._exc_info_to_raise[1]):
            # Raise, catch and record the exception here so that error happens
            # where expected.
            try:
              # The value is wrapped in a 1-tuple: formatting a bare tuple
              # against a single %s raises TypeError ("not all arguments
              # converted"), which would escape the `except ValueError` below
              # and leave the stop event unset.
              raise ValueError(
                  "ex must be a tuple or sys.exc_info must return the current "
                  "exception: %s"
                  % (self._exc_info_to_raise,))
            except ValueError:
              # Record this error so it kills the coordinator properly.
              # NOTE(touts): As above, this is bogus if request_stop() is not
              # called from the exception handler that raised ex.
              self._exc_info_to_raise = sys.exc_info()

        self._stop_event.set()

  def clear_stop(self):
    """Clears the stop flag.

    After this is called, calls to `should_stop()` will return `False`.  The
    recorded exception (if any) and the joined state are also reset, so the
    coordinator can be reused.
    """
    with self._lock:
      self._joined = False
      self._exc_info_to_raise = None
      if self._stop_event.is_set():
        self._stop_event.clear()

  def should_stop(self):
    """Check if stop was requested.

    Returns:
      True if a stop was requested.
    """
    return self._stop_event.is_set()

  @contextlib.contextmanager
  def stop_on_exception(self):
    """Context manager to request stop when an Exception is raised.

    Code that uses a coordinator must catch exceptions and pass
    them to the `request_stop()` method to stop the other threads
    managed by the coordinator.

    This context manager simplifies the exception handling.
    Use it as follows:

    ```python
    with coord.stop_on_exception():
      # Any exception raised in the body of the with
      # clause is reported to the coordinator before terminating
      # the execution of the body.
      ...body...
    ```

    This is completely equivalent to the slightly longer code:

    ```python
    try:
      ...body...
    except:
      coord.request_stop(sys.exc_info())
    ```

    The exception is not propagated out of the `with` statement unless
    `request_stop()` itself re-raises it (which happens when the coordinator
    has already been joined).

    Yields:
      nothing.
    """
    try:
      yield
    except:  # pylint: disable=bare-except
      # Deliberately bare: mirrors the documented equivalence above, so every
      # error (not only Exception subclasses) is reported to the coordinator.
      self.request_stop(ex=sys.exc_info())

  def wait_for_stop(self, timeout=None):
    """Wait till the Coordinator is told to stop.

    Args:
      timeout: Float.  Sleep for up to that many seconds waiting for
        should_stop() to become True.  If None, wait without a timeout.

    Returns:
      True if the Coordinator is told to stop, False if the timeout expired.
    """
    return self._stop_event.wait(timeout)

  def register_thread(self, thread):
    """Register a thread to join.

    Registered threads are joined by `join()` in addition to any threads
    passed to it explicitly.

    Args:
      thread: A Python thread to join.
    """
    with self._lock:
      self._registered_threads.add(thread)

  def join(self, threads=None, stop_grace_period_secs=120,
           ignore_live_threads=False):
    """Wait for threads to terminate.

    This call blocks until a set of threads have terminated.  The set of threads
    is the union of the threads passed in the `threads` argument and the list
    of threads that registered with the coordinator by calling
    `Coordinator.register_thread()`.

    After the threads stop, if an `exc_info` was passed to `request_stop`, that
    exception is re-raised.

    Grace period handling: When `request_stop()` is called, threads are given
    'stop_grace_period_secs' seconds to terminate.  If any of them is still
    alive after that period expires, a `RuntimeError` is raised.  Note that if
    an `exc_info` was passed to `request_stop()` then it is raised instead of
    that `RuntimeError`.

    Args:
      threads: List of `threading.Threads`. The started threads to join in
        addition to the registered threads.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `request_stop()` has been called.
      ignore_live_threads: If `False`, raises an error if any of the threads are
        still alive after `stop_grace_period_secs`.

    Raises:
      RuntimeError: If any thread is still alive after `request_stop()`
        is called and the grace period expires.
    """
    # Threads registered after this call will not be joined.
    with self._lock:
      # Snapshot into a list so that threads registered while we are waiting
      # do not change the collection being iterated.  set.union() accepts any
      # iterable, so no intermediate set() is needed.
      if threads is None:
        threads = list(self._registered_threads)
      else:
        threads = list(self._registered_threads.union(threads))

    # Wait for all threads to stop or for request_stop() to be called.
    while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
      pass

    # If any thread is still alive, wait for the grace period to expire.
    # By the time this check is executed, threads may still be shutting down,
    # so we add a sleep of increasing duration to give them a chance to shut
    # down without losing too many cycles.
    # The sleep duration is limited to the remaining grace duration.
    stop_wait_secs = 0.001
    while any(t.is_alive() for t in threads) and stop_grace_period_secs >= 0.0:
      time.sleep(stop_wait_secs)
      stop_grace_period_secs -= stop_wait_secs
      stop_wait_secs = 2 * stop_wait_secs
      # Keep the waiting period within sane bounds.
      # The minimum value is to avoid decreasing stop_wait_secs to a value
      # that could cause stop_grace_period_secs to remain unchanged.
      stop_wait_secs = max(min(stop_wait_secs, stop_grace_period_secs), 0.001)

    # List the threads still alive after the grace period.
    stragglers = [t.name for t in threads if t.is_alive()]

    # Terminate with an exception if appropriate.
    with self._lock:
      self._joined = True
      self._registered_threads = set()
      if self._exc_info_to_raise:
        six.reraise(*self._exc_info_to_raise)
      elif stragglers:
        if ignore_live_threads:
          logging.info("Coordinator stopped with threads still running: %s",
                       " ".join(stragglers))
        else:
          raise RuntimeError(
              "Coordinator stopped with threads still running: %s" %
              " ".join(stragglers))

  @property
  def joined(self):
    """True once `join()` has finished waiting; reset by `clear_stop()`."""
    return self._joined

  def raise_requested_exception(self):
    """If an exception has been passed to `request_stop`, this raises it."""
    with self._lock:
      if self._exc_info_to_raise:
        six.reraise(*self._exc_info_to_raise)
408
409
410# Threads for the standard services.
@tf_export(v1=["train.LooperThread"])
class LooperThread(threading.Thread):
  """A thread that runs code repeatedly, optionally on a timer.

  This thread class is intended to be used with a `Coordinator`.  It repeatedly
  runs code specified either as `target` and `args` or by the `run_loop()`
  method.

  Before each run the thread checks if the coordinator has requested stop.  In
  that case the looper thread terminates immediately.

  If the code being run raises an exception, that exception is reported to the
  coordinator and the thread terminates.  The coordinator will then request all
  the other threads it coordinates to stop.

  A looper thread registers itself with its coordinator when it is created, so
  `Coordinator.join()` waits for it even if it is not passed in explicitly.
  """

  def __init__(self, coord, timer_interval_secs, target=None, args=None,
               kwargs=None):
    """Create a LooperThread.

    The thread is created as a daemon thread and registered with `coord`.  It
    is not started; call `start()`, or use `LooperThread.loop()` which creates
    and starts it.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Number of seconds between scheduled calls to
        `run_loop()`, or None if it should be called back to back.
      target: Optional callable object that will be executed in the thread.
      args: Optional positional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Raises:
      ValueError: If `coord` is not a `Coordinator`, or if `args`/`kwargs` are
        given without a `target`.
    """
    if not isinstance(coord, Coordinator):
      # Wrap `coord` in a 1-tuple: a bare `% coord` raises TypeError instead
      # of the intended ValueError when the bad value is itself a tuple.
      raise ValueError("'coord' argument must be a Coordinator: %s" % (coord,))
    super(LooperThread, self).__init__()
    # Daemon so a looper that is never stopped cannot keep the process alive.
    self.daemon = True
    self._coord = coord
    self._timer_interval_secs = timer_interval_secs
    self._target = target
    if self._target:
      self._args = args or ()
      self._kwargs = kwargs or {}
    elif args or kwargs:
      raise ValueError("'args' and 'kwargs' argument require that you also "
                       "pass 'target'")
    self._coord.register_thread(self)

  @staticmethod
  def loop(coord, timer_interval_secs, target, args=None, kwargs=None):
    """Start a LooperThread that calls a function periodically.

    If `timer_interval_secs` is None the thread calls
    `target(*args, **kwargs)` repeatedly.  Otherwise `target(*args, **kwargs)`
    is called every `timer_interval_secs` seconds.  The thread terminates when
    a stop of the coordinator is requested.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Number. Time boundaries at which to call `target`.
      target: A callable object.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Returns:
      The started thread.
    """
    looper = LooperThread(coord, timer_interval_secs, target=target, args=args,
                          kwargs=kwargs)
    looper.start()
    return looper

  def run(self):
    """Thread body: `start_loop()`, repeated `run_loop()`, then `stop_loop()`.

    Runs until the coordinator requests a stop.  Any exception is reported to
    the coordinator through `stop_on_exception()`; in that case `stop_loop()`
    is not called.
    """
    with self._coord.stop_on_exception():
      self.start_loop()
      if self._timer_interval_secs is None:
        # Call back-to-back.
        while not self._coord.should_stop():
          self.run_loop()
      else:
        # Next time at which to call run_loop(), starts as 'now'.  The schedule
        # advances by a fixed interval, so if run_loop() overruns, the wait
        # timeout goes non-positive and missed calls run back to back.
        next_timer_time = time.time()
        while not self._coord.wait_for_stop(next_timer_time - time.time()):
          next_timer_time += self._timer_interval_secs
          self.run_loop()
      self.stop_loop()

  def start_loop(self):
    """Called when the thread starts."""
    pass

  def stop_loop(self):
    """Called when the thread stops."""
    pass

  def run_loop(self):
    """Called at 'timer_interval_secs' boundaries.

    By default calls `target(*args, **kwargs)` if a target was given; subclasses
    may override this instead of passing a target.
    """
    if self._target:
      self._target(*self._args, **self._kwargs)
510