1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Create threads to run multiple enqueue ops."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import threading
22import weakref
23
24from tensorflow.core.protobuf import queue_runner_pb2
25from tensorflow.python.client import session
26from tensorflow.python.eager import context
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import ops
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.util import deprecation
31from tensorflow.python.util.tf_export import tf_export
32
# Migration hint handed to `deprecation.deprecated(...)` by the deprecated
# queue-runner entry points defined in this file.
_DEPRECATION_INSTRUCTION = (
    "To construct input pipelines, use the `tf.data` module.")
35
36
@tf_export(v1=["train.queue_runner.QueueRunner", "train.QueueRunner"])
class QueueRunner(object):
  """Holds a list of enqueue operations for a queue, each to be run in a thread.

  Queues are a convenient TensorFlow mechanism to compute tensors
  asynchronously using multiple threads. For example in the canonical 'Input
  Reader' setup one set of threads generates filenames in a queue; a second set
  of threads read records from the files, processes them, and enqueues tensors
  on a second queue; a third set of threads dequeues these input records to
  construct batches and runs them through training operations.

  There are several delicate issues when running multiple threads that way:
  closing the queues in sequence as the input is exhausted, correctly catching
  and reporting exceptions, etc.

  The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.

  @compatibility(eager)
  QueueRunners are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  @deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
  def __init__(self, queue=None, enqueue_ops=None, close_op=None,
               cancel_op=None, queue_closed_exception_types=None,
               queue_runner_def=None, import_scope=None):
    """Create a QueueRunner.

    On construction the `QueueRunner` adds an op to close the queue.  That op
    will be run if the enqueue ops raise exceptions.

    When you later call the `create_threads()` method, the `QueueRunner` will
    create one thread for each op in `enqueue_ops`.  Each thread will run its
    enqueue op in parallel with the other threads.  The enqueue ops do not have
    to all be the same op, but it is expected that they all enqueue tensors in
    `queue`.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Optional tuple of Exception types that
        indicate that the queue has been closed when raised during an enqueue
        operation.  Defaults to `(tf.errors.OutOfRangeError,)`.  Another common
        case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
        when some of the enqueue ops may dequeue from other Queues.
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
        recreates the QueueRunner from its contents. `queue_runner_def` and the
        other arguments are mutually exclusive.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.

    Raises:
      ValueError: If `queue_runner_def` is specified together with `queue` or
        `enqueue_ops`.
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided, but is not
        a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
      RuntimeError: If eager execution is enabled.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "QueueRunners are not supported when eager execution is enabled. "
          "Instead, please use tf.data to get data into your model.")

    if queue_runner_def:
      # Only `queue` and `enqueue_ops` are checked for conflicts here; the
      # remaining arguments are simply not consulted on this path.
      if queue or enqueue_ops:
        raise ValueError("queue_runner_def and queue are mutually exclusive.")
      self._init_from_proto(queue_runner_def,
                            import_scope=import_scope)
    else:
      self._init_from_args(
          queue=queue, enqueue_ops=enqueue_ops,
          close_op=close_op, cancel_op=cancel_op,
          queue_closed_exception_types=queue_closed_exception_types)
    # Protect the count of runs to wait for.
    self._lock = threading.Lock()
    # A map from a session object to the number of outstanding queue runner
    # threads for that session.  Keys are held weakly so this map does not
    # keep sessions alive.
    self._runs_per_session = weakref.WeakKeyDictionary()
    # List of exceptions raised by the running threads.
    self._exceptions_raised = []

  def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
                      cancel_op=None, queue_closed_exception_types=None):
    """Create a QueueRunner from arguments.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
        When `None`, `queue.close()` is used.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
        When `None`, `queue.close(cancel_pending_enqueues=True)` is used.
      queue_closed_exception_types: Tuple of exception types, which indicate
        the queue has been safely closed.  When `None`, defaults to
        `(errors.OutOfRangeError,)`.

    Raises:
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided, but is not
        a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
    """
    if not queue or not enqueue_ops:
      raise ValueError("Must provide queue and enqueue_ops.")
    self._queue = queue
    self._enqueue_ops = enqueue_ops
    self._close_op = close_op
    self._cancel_op = cancel_op
    if queue_closed_exception_types is not None:
      # The `isinstance(t, type)` guard keeps `issubclass` from raising its own
      # (less descriptive) TypeError when an element is not a class at all.
      if (not isinstance(queue_closed_exception_types, tuple)
          or not queue_closed_exception_types
          or not all(isinstance(t, type) and issubclass(t, errors.OpError)
                     for t in queue_closed_exception_types)):
        # Wrap the offending value in a 1-tuple: the value is usually a tuple
        # itself, and a bare tuple on the right of `%` would be unpacked as the
        # format arguments, producing a misleading formatting error instead of
        # this message.
        raise TypeError(
            "queue_closed_exception_types, when provided, "
            "must be a tuple of tf.error types, but saw: %s"
            % (queue_closed_exception_types,))
    self._queue_closed_exception_types = queue_closed_exception_types
    # Close when no more will be produced, but pending enqueues should be
    # preserved.
    if self._close_op is None:
      self._close_op = self._queue.close()
    # Close and cancel pending enqueues since there was an error and we want
    # to unblock everything so we can cleanly exit.
    if self._cancel_op is None:
      self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)
    else:
      self._queue_closed_exception_types = tuple(
          self._queue_closed_exception_types)

  def _init_from_proto(self, queue_runner_def, import_scope=None):
    """Create a QueueRunner from `QueueRunnerDef`.

    The queue and ops are looked up by name in the current default graph.

    Args:
      queue_runner_def: A `QueueRunnerDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
    g = ops.get_default_graph()
    self._queue = g.as_graph_element(
        ops.prepend_name_scope(queue_runner_def.queue_name, import_scope))
    self._enqueue_ops = [g.as_graph_element(
        ops.prepend_name_scope(op, import_scope))
                         for op in queue_runner_def.enqueue_op_name]
    self._close_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.close_op_name, import_scope))
    self._cancel_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.cancel_op_name, import_scope))
    self._queue_closed_exception_types = tuple(
        errors.exception_type_from_error_code(code)
        for code in queue_runner_def.queue_closed_exception_types)
    # Legacy support for old QueueRunnerDefs created before this field
    # was added.
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)

  @property
  def queue(self):
    """The queue this runner's enqueue ops feed."""
    return self._queue

  @property
  def enqueue_ops(self):
    """The enqueue ops; `create_threads()` makes one thread per op."""
    return self._enqueue_ops

  @property
  def close_op(self):
    """Op run to close the queue once all enqueue threads saw it closed."""
    return self._close_op

  @property
  def cancel_op(self):
    """Op run to close the queue when the coordinator requests a stop."""
    return self._cancel_op

  @property
  def queue_closed_exception_types(self):
    """Tuple of exception types treated as "the queue was closed"."""
    return self._queue_closed_exception_types

  @property
  def exceptions_raised(self):
    """Exceptions raised but not handled by the `QueueRunner` threads.

    Exceptions raised in queue runner threads are handled in one of two ways
    depending on whether or not a `Coordinator` was passed to
    `create_threads()`:

    * With a `Coordinator`, exceptions are reported to the coordinator and
      forgotten by the `QueueRunner`.
    * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
      made available in this `exceptions_raised` property.

    Returns:
      A list of Python `Exception` objects.  The list is empty if no exception
      was captured.  (No exceptions are captured when using a Coordinator.)
    """
    return self._exceptions_raised

  @property
  def name(self):
    """The string name of the underlying Queue."""
    return self._queue.name

  # pylint: disable=broad-except
  def _run(self, sess, enqueue_op, coord=None):
    """Execute the enqueue op in a loop, close the queue in case of error.

    The loop ends when the coordinator (if any) asks to stop, when one of the
    `queue_closed_exception_types` is raised, or when any other exception is
    raised.  In every case the per-session count of outstanding threads is
    decremented exactly once.

    Args:
      sess: A Session.
      enqueue_op: The Operation to run.
      coord: Optional Coordinator object for reporting errors and checking
        for stop conditions.
    """
    # Tracks whether the queue-closed handler already decremented the count,
    # so the `finally` clause does not decrement it a second time.
    decremented = False
    try:
      # Make a cached callable from the `enqueue_op` to decrease the
      # Python overhead in the queue-runner loop.
      enqueue_callable = sess.make_callable(enqueue_op)
      while True:
        if coord and coord.should_stop():
          break
        try:
          enqueue_callable()
        except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
          # This exception indicates that a queue was closed.
          with self._lock:
            self._runs_per_session[sess] -= 1
            decremented = True
            # The last thread to finish for this session closes the queue.
            if self._runs_per_session[sess] == 0:
              try:
                sess.run(self._close_op)
              except Exception as e:
                # Intentionally ignore errors from close_op.
                logging.vlog(1, "Ignored exception: %s", str(e))
            return
    except Exception as e:
      # This catches all other exceptions.
      if coord:
        coord.request_stop(e)
      else:
        # Without a coordinator, record the exception and let it propagate
        # out of the thread.
        logging.error("Exception in QueueRunner: %s", str(e))
        with self._lock:
          self._exceptions_raised.append(e)
        raise
    finally:
      # Make sure we account for all terminations: normal or errors.
      if not decremented:
        with self._lock:
          self._runs_per_session[sess] -= 1

  def _close_on_stop(self, sess, cancel_op, coord):
    """Close the queue when the Coordinator requests stop.

    Blocks until the coordinator signals a stop, then runs `cancel_op`.

    Args:
      sess: A Session.
      cancel_op: The Operation to run.
      coord: Coordinator.
    """
    coord.wait_for_stop()
    try:
      sess.run(cancel_op)
    except Exception as e:
      # Intentionally ignore errors from cancel_op.
      logging.vlog(1, "Ignored exception: %s", str(e))
  # pylint: enable=broad-except

  def create_threads(self, sess, coord=None, daemon=False, start=False):
    """Create threads to run the enqueue ops for the given session.

    This method requires a session in which the graph was launched.  It creates
    a list of threads, optionally starting them.  There is one thread for each
    op passed in `enqueue_ops`.

    The `coord` argument is an optional coordinator that the threads will use
    to terminate together and report exceptions.  If a coordinator is given,
    this method starts an additional thread to close the queue when the
    coordinator requests a stop.

    If previously created threads for the given session are still running, no
    new threads will be created.

    Args:
      sess: A `Session`.
      coord: Optional `Coordinator` object for reporting errors and checking
        stop conditions.
      daemon: Boolean.  If `True` make the threads daemon threads.
      start: Boolean.  If `True` starts the threads.  If `False` the
        caller must call the `start()` method of the returned threads.

    Returns:
      A list of threads.
    """
    with self._lock:
      # A session we have not seen yet counts as having no outstanding runs.
      if self._runs_per_session.get(sess, 0) > 0:
        # Already started: no new threads to return.
        return []
      self._runs_per_session[sess] = len(self._enqueue_ops)
      # Starting a fresh batch of threads discards previously captured errors.
      self._exceptions_raised = []

    ret_threads = []
    for op in self._enqueue_ops:
      name = "QueueRunnerThread-{}-{}".format(self.name, op.name)
      ret_threads.append(threading.Thread(target=self._run,
                                          args=(sess, op, coord),
                                          name=name))
    if coord:
      name = "QueueRunnerThread-{}-close_on_stop".format(self.name)
      ret_threads.append(threading.Thread(target=self._close_on_stop,
                                          args=(sess, self._cancel_op, coord),
                                          name=name))
    for t in ret_threads:
      if coord:
        coord.register_thread(t)
      if daemon:
        t.daemon = True
      if start:
        t.start()
    return ret_threads

  def to_proto(self, export_scope=None):
    """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `QueueRunnerDef` protocol buffer, or `None` if the `QueueRunner`'s
      queue is not in the specified name scope.
    """
    if (export_scope is not None and
        not self.queue.name.startswith(export_scope)):
      return None

    queue_runner_def = queue_runner_pb2.QueueRunnerDef()
    queue_runner_def.queue_name = ops.strip_name_scope(
        self.queue.name, export_scope)
    for enqueue_op in self.enqueue_ops:
      queue_runner_def.enqueue_op_name.append(
          ops.strip_name_scope(enqueue_op.name, export_scope))
    queue_runner_def.close_op_name = ops.strip_name_scope(
        self.close_op.name, export_scope)
    queue_runner_def.cancel_op_name = ops.strip_name_scope(
        self.cancel_op.name, export_scope)
    # Exception classes are stored in the proto as their error codes.
    queue_runner_def.queue_closed_exception_types.extend([
        errors.error_code_from_exception_type(cls)
        for cls in self._queue_closed_exception_types])
    return queue_runner_def

  @staticmethod
  def from_proto(queue_runner_def, import_scope=None):
    """Returns a `QueueRunner` object created from `queue_runner_def`."""
    return QueueRunner(queue_runner_def=queue_runner_def,
                       import_scope=import_scope)
392
393
@tf_export(v1=["train.queue_runner.add_queue_runner", "train.add_queue_runner"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Adds a `QueueRunner` to a collection in the graph.

  When building a complex model that uses many queues it is often difficult to
  gather all the queue runners that need to be run.  This convenience function
  allows you to add a queue runner to a well known collection in the graph.

  The companion method `start_queue_runners()` can be used to start threads for
  all the collected queue runners.

  This function only registers `qr`; it does not create or start any threads.

  Args:
    qr: A `QueueRunner`.
    collection: A `GraphKey` specifying the graph collection to add
      the queue runner to.  Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  # Delegates to the module-level helper in `ops`; nothing is returned.
  ops.add_to_collection(collection, qr)
412
413
@tf_export(v1=["train.queue_runner.start_queue_runners",
               "train.start_queue_runners"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`.  It just starts
  threads for all queue runners collected in the graph.  It returns
  the list of all threads.

  The queue runners are read from the `collection` of `sess.graph`.

  Args:
    sess: `Session` used to run the queue ops.  Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from.  Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Returns:
    A list of threads.  The list is empty when `sess` is a
    `MonitoredSession` or `SingularMonitoredSession`.

  Raises:
    RuntimeError: If called with eager execution enabled.
    ValueError: if `sess` is None and there isn't any default session.
    TypeError: if `sess` is not a `tf.Session` object.

  @compatibility(eager)
  Not compatible with eager execution. To ingest data under eager execution,
  use the `tf.data` API instead.
  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError("Queues are not compatible with eager execution.")
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")

  if not isinstance(sess, session.SessionInterface):
    # Following check is due to backward compatibility. (b/62061352)
    if sess.__class__.__name__ in [
        "MonitoredSession", "SingularMonitoredSession"]:
      return []
    raise TypeError("sess must be a `tf.Session` object. "
                    "Given class: {}".format(sess.__class__))

  with sess.graph.as_default():
    # Look the collection up once, inside the session's graph, so that the
    # "nothing to start" warning below is about the same graph whose queue
    # runners are actually started.
    queue_runners = ops.get_collection(collection)
    if not queue_runners:
      logging.warning(
          "`tf.train.start_queue_runners()` was called when no queue runners "
          "were defined. You can safely remove the call to this deprecated "
          "function.")
    threads = []
    for qr in queue_runners:
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads
481
482
# Associate the QUEUE_RUNNERS collection key with the `QueueRunnerDef` proto
# type and with the `QueueRunner.to_proto` / `QueueRunner.from_proto`
# converters.
# NOTE(review): presumably consulted when graph collections are serialized or
# restored -- confirm against `ops.register_proto_function`.
ops.register_proto_function(ops.GraphKeys.QUEUE_RUNNERS,
                            proto_type=queue_runner_pb2.QueueRunnerDef,
                            to_proto=QueueRunner.to_proto,
                            from_proto=QueueRunner.from_proto)
487