1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#==============================================================================
15"""Data Flow Operations."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import hashlib
22import threading
23
24import six
25
26from tensorflow.python.eager import context
27from tensorflow.python.framework import dtypes as _dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import random_seed
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.framework import tensor_util
32from tensorflow.python.lib.io import python_io
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import gen_data_flow_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import resource_variable_ops
38# go/tf-wildcard-import
39# pylint: disable=wildcard-import
40from tensorflow.python.ops.gen_data_flow_ops import *
41from tensorflow.python.util import deprecation
42from tensorflow.python.util.compat import collections_abc
43from tensorflow.python.util.tf_export import tf_export
44
45# pylint: enable=wildcard-import
46
47
48def _as_type_list(dtypes):
49  """Convert dtypes to a list of types."""
50  assert dtypes is not None
51  if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)):
52    # We have a single type.
53    return [dtypes]
54  else:
55    # We have a list or tuple of types.
56    return list(dtypes)
57
58
59def _as_shape_list(shapes,
60                   dtypes,
61                   unknown_dim_allowed=False,
62                   unknown_rank_allowed=False):
63  """Convert shapes to a list of tuples of int (or None)."""
64  del dtypes
65  if unknown_dim_allowed:
66    if (not isinstance(shapes, collections_abc.Sequence) or not shapes or
67        any(shape is None or isinstance(shape, int) for shape in shapes)):
68      raise ValueError(
69          "When providing partial shapes, a list of shapes must be provided.")
70  if shapes is None:
71    return None
72  if isinstance(shapes, tensor_shape.TensorShape):
73    shapes = [shapes]
74  if not isinstance(shapes, (tuple, list)):
75    raise TypeError(
76        "shapes must be a TensorShape or a list or tuple of TensorShapes.")
77  if all(shape is None or isinstance(shape, int) for shape in shapes):
78    # We have a single shape.
79    shapes = [shapes]
80  shapes = [tensor_shape.as_shape(shape) for shape in shapes]
81  if not unknown_dim_allowed:
82    if any(not shape.is_fully_defined() for shape in shapes):
83      raise ValueError("All shapes must be fully defined: %s" % shapes)
84  if not unknown_rank_allowed:
85    if any(shape.dims is None for shape in shapes):
86      raise ValueError("All shapes must have a defined rank: %s" % shapes)
87
88  return shapes
89
90
91def _as_name_list(names, dtypes):
92  if names is None:
93    return None
94  if not isinstance(names, (list, tuple)):
95    names = [names]
96  if len(names) != len(dtypes):
97    raise ValueError("List of names must have the same length as the list "
98                     "of dtypes")
99  return list(names)
100
101
102def _shape_common(s1, s2):
103  """The greatest lower bound (ordered by specificity) TensorShape."""
104  s1 = tensor_shape.TensorShape(s1)
105  s2 = tensor_shape.TensorShape(s2)
106  if s1.ndims is None or s2.ndims is None or s1.ndims != s2.ndims:
107    return tensor_shape.unknown_shape()
108  d = [
109      d1 if d1 is not None and d1 == d2 else None
110      for (d1, d2) in zip(s1.as_list(), s2.as_list())
111  ]
112  return tensor_shape.TensorShape(d)
113
114
115# pylint: disable=protected-access
116@tf_export("queue.QueueBase",
117           v1=["queue.QueueBase", "io.QueueBase", "QueueBase"])
118@deprecation.deprecated_endpoints(["io.QueueBase", "QueueBase"])
119class QueueBase(object):
120  """Base class for queue implementations.
121
122  A queue is a TensorFlow data structure that stores tensors across
123  multiple steps, and exposes operations that enqueue and dequeue
124  tensors.
125
126  Each queue element is a tuple of one or more tensors, where each
127  tuple component has a static dtype, and may have a static shape. The
128  queue implementations support versions of enqueue and dequeue that
129  handle single elements, versions that support enqueuing and
130  dequeuing a batch of elements at once.
131
132  See `tf.queue.FIFOQueue` and
133  `tf.queue.RandomShuffleQueue` for concrete
134  implementations of this class, and instructions on how to create
135  them.
136  """
137
138  def __init__(self, dtypes, shapes, names, queue_ref):
139    """Constructs a queue object from a queue reference.
140
141    The two optional lists, `shapes` and `names`, must be of the same length
142    as `dtypes` if provided.  The values at a given index `i` indicate the
143    shape and name to use for the corresponding queue component in `dtypes`.
144
145    Args:
146      dtypes:  A list of types.  The length of dtypes must equal the number
147        of tensors in each element.
148      shapes: Constraints on the shapes of tensors in an element:
149        A list of shape tuples or None. This list is the same length
150        as dtypes.  If the shape of any tensors in the element are constrained,
151        all must be; shapes can be None if the shapes should not be constrained.
152      names: Optional list of names.  If provided, the `enqueue()` and
153        `dequeue()` methods will use dictionaries with these names as keys.
154        Must be None or a list or tuple of the same length as `dtypes`.
155      queue_ref: The queue reference, i.e. the output of the queue op.
156
157    Raises:
158      ValueError: If one of the arguments is invalid.
159    """
160    self._dtypes = dtypes
161    if shapes is not None:
162      if len(shapes) != len(dtypes):
163        raise ValueError("Queue shapes must have the same length as dtypes")
164      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
165    else:
166      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
167    if names is not None:
168      if len(names) != len(dtypes):
169        raise ValueError("Queue names must have the same length as dtypes")
170      self._names = names
171    else:
172      self._names = None
173    self._queue_ref = queue_ref
174    if isinstance(queue_ref, ops.EagerTensor):
175      if context.context().scope_name:
176        self._name = context.context().scope_name
177      else:
178        self._name = "Empty"
179      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
180          queue_ref, None)
181    else:
182      self._name = self._queue_ref.op.name.split("/")[-1]
183
184  @staticmethod
185  def from_list(index, queues):
186    """Create a queue using the queue reference from `queues[index]`.
187
188    Args:
189      index: An integer scalar tensor that determines the input that gets
190        selected.
191      queues: A list of `QueueBase` objects.
192
193    Returns:
194      A `QueueBase` object.
195
196    Raises:
197      TypeError: When `queues` is not a list of `QueueBase` objects,
198        or when the data types of `queues` are not all the same.
199    """
200    if ((not queues) or (not isinstance(queues, list)) or
201        (not all(isinstance(x, QueueBase) for x in queues))):
202      raise TypeError("A list of queues expected")
203
204    dtypes = queues[0].dtypes
205    if not all(dtypes == q.dtypes for q in queues[1:]):
206      raise TypeError("Queues do not have matching component dtypes.")
207
208    names = queues[0].names
209    if not all(names == q.names for q in queues[1:]):
210      raise TypeError("Queues do not have matching component names.")
211
212    queue_shapes = [q.shapes for q in queues]
213    reduced_shapes = [
214        six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)
215    ]
216
217    queue_refs = array_ops.stack([x.queue_ref for x in queues])
218    selected_queue = array_ops.gather(queue_refs, index)
219    return QueueBase(
220        dtypes=dtypes,
221        shapes=reduced_shapes,
222        names=names,
223        queue_ref=selected_queue)
224
225  @property
226  def queue_ref(self):
227    """The underlying queue reference."""
228    return self._queue_ref
229
230  @property
231  def name(self):
232    """The name of the underlying queue."""
233    if context.executing_eagerly():
234      return self._name
235    return self._queue_ref.op.name
236
237  @property
238  def dtypes(self):
239    """The list of dtypes for each component of a queue element."""
240    return self._dtypes
241
242  @property
243  def shapes(self):
244    """The list of shapes for each component of a queue element."""
245    return self._shapes
246
247  @property
248  def names(self):
249    """The list of names for each component of a queue element."""
250    return self._names
251
252  def _check_enqueue_dtypes(self, vals):
253    """Validate and convert `vals` to a list of `Tensor`s.
254
255    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
256    dictionary with tensor values.
257
258    If it is a dictionary, the queue must have been constructed with a
259    `names` attribute and the dictionary keys must match the queue names.
260    If the queue was constructed with a `names` attribute, `vals` must
261    be a dictionary.
262
263    Args:
264      vals: A tensor, a list or tuple of tensors, or a dictionary..
265
266    Returns:
267      A list of `Tensor` objects.
268
269    Raises:
270      ValueError: If `vals` is invalid.
271    """
272    if isinstance(vals, dict):
273      if not self._names:
274        raise ValueError("Queue must have names to enqueue a dictionary")
275      if sorted(self._names, key=str) != sorted(vals.keys(), key=str):
276        raise ValueError("Keys in dictionary to enqueue do not match "
277                         "names of Queue.  Dictionary: (%s), Queue: (%s)" %
278                         (sorted(vals.keys()), sorted(self._names)))
279      # The order of values in `self._names` indicates the order in which the
280      # tensors in the dictionary `vals` must be listed.
281      vals = [vals[k] for k in self._names]
282    else:
283      if self._names:
284        raise ValueError("You must enqueue a dictionary in a Queue with names")
285      if not isinstance(vals, (list, tuple)):
286        vals = [vals]
287
288    tensors = []
289    for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
290      tensors.append(
291          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
292
293    return tensors
294
295  def _scope_vals(self, vals):
296    """Return a list of values to pass to `name_scope()`.
297
298    Args:
299      vals: A tensor, a list or tuple of tensors, or a dictionary.
300
301    Returns:
302      The values in vals as a list.
303    """
304    if isinstance(vals, (list, tuple)):
305      return vals
306    elif isinstance(vals, dict):
307      return vals.values()
308    else:
309      return [vals]
310
311  def enqueue(self, vals, name=None):
312    """Enqueues one element to this queue.
313
314    If the queue is full when this operation executes, it will block
315    until the element has been enqueued.
316
317    At runtime, this operation may raise an error if the queue is
318    `tf.QueueBase.close` before or during its execution. If the
319    queue is closed before this operation runs,
320    `tf.errors.CancelledError` will be raised. If this operation is
321    blocked, and either (i) the queue is closed by a close operation
322    with `cancel_pending_enqueues=True`, or (ii) the session is
323    `tf.Session.close`,
324    `tf.errors.CancelledError` will be raised.
325
326    Args:
327      vals: A tensor, a list or tuple of tensors, or a dictionary containing
328        the values to enqueue.
329      name: A name for the operation (optional).
330
331    Returns:
332      The operation that enqueues a new tuple of tensors to the queue.
333    """
334    with ops.name_scope(name, "%s_enqueue" % self._name,
335                        self._scope_vals(vals)) as scope:
336      vals = self._check_enqueue_dtypes(vals)
337
338      # NOTE(mrry): Not using a shape function because we need access to
339      # the `QueueBase` object.
340      for val, shape in zip(vals, self._shapes):
341        val.get_shape().assert_is_compatible_with(shape)
342
343      if self._queue_ref.dtype == _dtypes.resource:
344        return gen_data_flow_ops.queue_enqueue_v2(
345            self._queue_ref, vals, name=scope)
346      else:
347        return gen_data_flow_ops.queue_enqueue(
348            self._queue_ref, vals, name=scope)
349
350  def enqueue_many(self, vals, name=None):
351    """Enqueues zero or more elements to this queue.
352
353    This operation slices each component tensor along the 0th dimension to
354    make multiple queue elements. All of the tensors in `vals` must have the
355    same size in the 0th dimension.
356
357    If the queue is full when this operation executes, it will block
358    until all of the elements have been enqueued.
359
360    At runtime, this operation may raise an error if the queue is
361    `tf.QueueBase.close` before or during its execution. If the
362    queue is closed before this operation runs,
363    `tf.errors.CancelledError` will be raised. If this operation is
364    blocked, and either (i) the queue is closed by a close operation
365    with `cancel_pending_enqueues=True`, or (ii) the session is
366    `tf.Session.close`,
367    `tf.errors.CancelledError` will be raised.
368
369    Args:
370      vals: A tensor, a list or tuple of tensors, or a dictionary
371        from which the queue elements are taken.
372      name: A name for the operation (optional).
373
374    Returns:
375      The operation that enqueues a batch of tuples of tensors to the queue.
376    """
377    with ops.name_scope(name, "%s_EnqueueMany" % self._name,
378                        self._scope_vals(vals)) as scope:
379      vals = self._check_enqueue_dtypes(vals)
380
381      # NOTE(mrry): Not using a shape function because we need access to
382      # the `QueueBase` object.
383      # NOTE(fchollet): the code that follow is verbose because it needs to be
384      # compatible with both TF v1 TensorShape behavior and TF v2 behavior.
385      batch_dim = tensor_shape.dimension_value(
386          vals[0].get_shape().with_rank_at_least(1)[0])
387      batch_dim = tensor_shape.Dimension(batch_dim)
388      for val, shape in zip(vals, self._shapes):
389        val_batch_dim = tensor_shape.dimension_value(
390            val.get_shape().with_rank_at_least(1)[0])
391        val_batch_dim = tensor_shape.Dimension(val_batch_dim)
392        batch_dim = batch_dim.merge_with(val_batch_dim)
393        val.get_shape()[1:].assert_is_compatible_with(shape)
394
395      return gen_data_flow_ops.queue_enqueue_many_v2(
396          self._queue_ref, vals, name=scope)
397
398  def _dequeue_return_value(self, tensors):
399    """Return the value to return from a dequeue op.
400
401    If the queue has names, return a dictionary with the
402    names as keys.  Otherwise return either a single tensor
403    or a list of tensors depending on the length of `tensors`.
404
405    Args:
406      tensors: List of tensors from the dequeue op.
407
408    Returns:
409      A single tensor, a list of tensors, or a dictionary
410      of tensors.
411    """
412    if self._names:
413      # The returned values in `tensors` are in the same order as
414      # the names in `self._names`.
415      return {n: tensors[i] for i, n in enumerate(self._names)}
416    elif len(tensors) == 1:
417      return tensors[0]
418    else:
419      return tensors
420
421  def dequeue(self, name=None):
422    """Dequeues one element from this queue.
423
424    If the queue is empty when this operation executes, it will block
425    until there is an element to dequeue.
426
427    At runtime, this operation may raise an error if the queue is
428    `tf.QueueBase.close` before or during its execution. If the
429    queue is closed, the queue is empty, and there are no pending
430    enqueue operations that can fulfill this request,
431    `tf.errors.OutOfRangeError` will be raised. If the session is
432    `tf.Session.close`,
433    `tf.errors.CancelledError` will be raised.
434
435    Args:
436      name: A name for the operation (optional).
437
438    Returns:
439      The tuple of tensors that was dequeued.
440    """
441    if name is None:
442      name = "%s_Dequeue" % self._name
443    if self._queue_ref.dtype == _dtypes.resource:
444      ret = gen_data_flow_ops.queue_dequeue_v2(
445          self._queue_ref, self._dtypes, name=name)
446    else:
447      ret = gen_data_flow_ops.queue_dequeue(
448          self._queue_ref, self._dtypes, name=name)
449
450    # NOTE(mrry): Not using a shape function because we need access to
451    # the `QueueBase` object.
452    if not context.executing_eagerly():
453      op = ret[0].op
454      for output, shape in zip(op.values(), self._shapes):
455        output.set_shape(shape)
456
457    return self._dequeue_return_value(ret)
458
459  def dequeue_many(self, n, name=None):
460    """Dequeues and concatenates `n` elements from this queue.
461
462    This operation concatenates queue-element component tensors along
463    the 0th dimension to make a single component tensor.  All of the
464    components in the dequeued tuple will have size `n` in the 0th dimension.
465
466    If the queue is closed and there are less than `n` elements left, then an
467    `OutOfRange` exception is raised.
468
469    At runtime, this operation may raise an error if the queue is
470    `tf.QueueBase.close` before or during its execution. If the
471    queue is closed, the queue contains fewer than `n` elements, and
472    there are no pending enqueue operations that can fulfill this
473    request, `tf.errors.OutOfRangeError` will be raised. If the
474    session is `tf.Session.close`,
475    `tf.errors.CancelledError` will be raised.
476
477    Args:
478      n: A scalar `Tensor` containing the number of elements to dequeue.
479      name: A name for the operation (optional).
480
481    Returns:
482      The list of concatenated tensors that was dequeued.
483    """
484    if name is None:
485      name = "%s_DequeueMany" % self._name
486
487    ret = gen_data_flow_ops.queue_dequeue_many_v2(
488        self._queue_ref, n=n, component_types=self._dtypes, name=name)
489
490    # NOTE(mrry): Not using a shape function because we need access to
491    # the Queue object.
492    if not context.executing_eagerly():
493      op = ret[0].op
494      batch_dim = tensor_shape.Dimension(
495          tensor_util.constant_value(op.inputs[1]))
496      for output, shape in zip(op.values(), self._shapes):
497        output.set_shape(
498            tensor_shape.TensorShape([batch_dim]).concatenate(shape))
499
500    return self._dequeue_return_value(ret)
501
502  def dequeue_up_to(self, n, name=None):
503    """Dequeues and concatenates `n` elements from this queue.
504
505    **Note** This operation is not supported by all queues.  If a queue does not
506    support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised.
507
508    This operation concatenates queue-element component tensors along
509    the 0th dimension to make a single component tensor. If the queue
510    has not been closed, all of the components in the dequeued tuple
511    will have size `n` in the 0th dimension.
512
513    If the queue is closed and there are more than `0` but fewer than
514    `n` elements remaining, then instead of raising a
515    `tf.errors.OutOfRangeError` like `tf.QueueBase.dequeue_many`,
516    less than `n` elements are returned immediately.  If the queue is
517    closed and there are `0` elements left in the queue, then a
518    `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
519    Otherwise the behavior is identical to `dequeue_many`.
520
521    Args:
522      n: A scalar `Tensor` containing the number of elements to dequeue.
523      name: A name for the operation (optional).
524
525    Returns:
526      The tuple of concatenated tensors that was dequeued.
527    """
528    if name is None:
529      name = "%s_DequeueUpTo" % self._name
530
531    ret = gen_data_flow_ops.queue_dequeue_up_to_v2(
532        self._queue_ref, n=n, component_types=self._dtypes, name=name)
533
534    # NOTE(mrry): Not using a shape function because we need access to
535    # the Queue object.
536    if not context.executing_eagerly():
537      op = ret[0].op
538      for output, shape in zip(op.values(), self._shapes):
539        output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
540
541    return self._dequeue_return_value(ret)
542
543  def close(self, cancel_pending_enqueues=False, name=None):
544    """Closes this queue.
545
546    This operation signals that no more elements will be enqueued in
547    the given queue. Subsequent `enqueue` and `enqueue_many`
548    operations will fail. Subsequent `dequeue` and `dequeue_many`
549    operations will continue to succeed if sufficient elements remain
550    in the queue. Subsequently dequeue and dequeue_many operations
551    that would otherwise block waiting for more elements (if close
552    hadn't been called) will now fail immediately.
553
554    If `cancel_pending_enqueues` is `True`, all pending requests will also
555    be canceled.
556
557    Args:
558      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
559        `False` (described above).
560      name: A name for the operation (optional).
561
562    Returns:
563      The operation that closes the queue.
564    """
565    if name is None:
566      name = "%s_Close" % self._name
567    if self._queue_ref.dtype == _dtypes.resource:
568      return gen_data_flow_ops.queue_close_v2(
569          self._queue_ref,
570          cancel_pending_enqueues=cancel_pending_enqueues,
571          name=name)
572    else:
573      return gen_data_flow_ops.queue_close(
574          self._queue_ref,
575          cancel_pending_enqueues=cancel_pending_enqueues,
576          name=name)
577
578  def is_closed(self, name=None):
579    """Returns true if queue is closed.
580
581    This operation returns true if the queue is closed and false if the queue
582    is open.
583
584    Args:
585      name: A name for the operation (optional).
586
587    Returns:
588      True if the queue is closed and false if the queue is open.
589    """
590    if name is None:
591      name = "%s_Is_Closed" % self._name
592    if self._queue_ref.dtype == _dtypes.resource:
593      return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name)
594    else:
595      return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name)
596
597  def size(self, name=None):
598    """Compute the number of elements in this queue.
599
600    Args:
601      name: A name for the operation (optional).
602
603    Returns:
604      A scalar tensor containing the number of elements in this queue.
605    """
606    if name is None:
607      name = "%s_Size" % self._name
608    if self._queue_ref.dtype == _dtypes.resource:
609      return gen_data_flow_ops.queue_size_v2(self._queue_ref, name=name)
610    else:
611      return gen_data_flow_ops.queue_size(self._queue_ref, name=name)
612
613def _shared_name(shared_name):
614  if context.executing_eagerly():
615    return str(ops.uid())
616  return shared_name
617
618
619@tf_export(
620    "queue.RandomShuffleQueue",
621    v1=["queue.RandomShuffleQueue",
622        "io.RandomShuffleQueue", "RandomShuffleQueue"])
623@deprecation.deprecated_endpoints(
624    ["io.RandomShuffleQueue", "RandomShuffleQueue"])
625class RandomShuffleQueue(QueueBase):
626  """A queue implementation that dequeues elements in a random order.
627
628  See `tf.queue.QueueBase` for a description of the methods on
629  this class.
630  """
631
632  def __init__(self,
633               capacity,
634               min_after_dequeue,
635               dtypes,
636               shapes=None,
637               names=None,
638               seed=None,
639               shared_name=None,
640               name="random_shuffle_queue"):
641    """Create a queue that dequeues elements in a random order.
642
643    A `RandomShuffleQueue` has bounded capacity; supports multiple
644    concurrent producers and consumers; and provides exactly-once
645    delivery.
646
647    A `RandomShuffleQueue` holds a list of up to `capacity`
648    elements. Each element is a fixed-length tuple of tensors whose
649    dtypes are described by `dtypes`, and whose shapes are optionally
650    described by the `shapes` argument.
651
652    If the `shapes` argument is specified, each component of a queue
653    element must have the respective fixed shape. If it is
654    unspecified, different queue elements may have different shapes,
655    but the use of `dequeue_many` is disallowed.
656
657    The `min_after_dequeue` argument allows the caller to specify a
658    minimum number of elements that will remain in the queue after a
659    `dequeue` or `dequeue_many` operation completes, to ensure a
660    minimum level of mixing of elements. This invariant is maintained
661    by blocking those operations until sufficient elements have been
662    enqueued. The `min_after_dequeue` argument is ignored after the
663    queue has been closed.
664
665    Args:
666      capacity: An integer. The upper bound on the number of elements
667        that may be stored in this queue.
668      min_after_dequeue: An integer (described above).
669      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
670        the number of tensors in each queue element.
671      shapes: (Optional.) A list of fully-defined `TensorShape` objects
672        with the same length as `dtypes`, or `None`.
673      names: (Optional.) A list of string naming the components in the queue
674        with the same length as `dtypes`, or `None`.  If specified the dequeue
675        methods return a dictionary with the names as keys.
676      seed: A Python integer. Used to create a random seed. See
677        `tf.compat.v1.set_random_seed`
678        for behavior.
679      shared_name: (Optional.) If non-empty, this queue will be shared under
680        the given name across multiple sessions.
681      name: Optional name for the queue operation.
682    """
683    dtypes = _as_type_list(dtypes)
684    shapes = _as_shape_list(shapes, dtypes)
685    names = _as_name_list(names, dtypes)
686    seed1, seed2 = random_seed.get_seed(seed)
687    if seed1 is None and seed2 is None:
688      seed1, seed2 = 0, 0
689    elif seed is None and shared_name is not None:
690      # This means that graph seed is provided but op seed is not provided.
691      # If shared_name is also provided, make seed2 depend only on the graph
692      # seed and shared_name. (seed2 from get_seed() is generally dependent on
693      # the id of the last op created.)
694      string = (str(seed1) + shared_name).encode("utf-8")
695      seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
696    queue_ref = gen_data_flow_ops.random_shuffle_queue_v2(
697        component_types=dtypes,
698        shapes=shapes,
699        capacity=capacity,
700        min_after_dequeue=min_after_dequeue,
701        seed=seed1,
702        seed2=seed2,
703        shared_name=_shared_name(shared_name),
704        name=name)
705
706    super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
707
708
709@tf_export("queue.FIFOQueue", v1=["queue.FIFOQueue", "FIFOQueue"])
710@deprecation.deprecated_endpoints("FIFOQueue")
711class FIFOQueue(QueueBase):
712  """A queue implementation that dequeues elements in first-in first-out order.
713
714  See `tf.queue.QueueBase` for a description of the methods on
715  this class.
716  """
717
718  def __init__(self,
719               capacity,
720               dtypes,
721               shapes=None,
722               names=None,
723               shared_name=None,
724               name="fifo_queue"):
725    """Creates a queue that dequeues elements in a first-in first-out order.
726
727    A `FIFOQueue` has bounded capacity; supports multiple concurrent
728    producers and consumers; and provides exactly-once delivery.
729
730    A `FIFOQueue` holds a list of up to `capacity` elements. Each
731    element is a fixed-length tuple of tensors whose dtypes are
732    described by `dtypes`, and whose shapes are optionally described
733    by the `shapes` argument.
734
735    If the `shapes` argument is specified, each component of a queue
736    element must have the respective fixed shape. If it is
737    unspecified, different queue elements may have different shapes,
738    but the use of `dequeue_many` is disallowed.
739
740    Args:
741      capacity: An integer. The upper bound on the number of elements
742        that may be stored in this queue.
743      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
744        the number of tensors in each queue element.
745      shapes: (Optional.) A list of fully-defined `TensorShape` objects
746        with the same length as `dtypes`, or `None`.
747      names: (Optional.) A list of string naming the components in the queue
748        with the same length as `dtypes`, or `None`.  If specified the dequeue
749        methods return a dictionary with the names as keys.
750      shared_name: (Optional.) If non-empty, this queue will be shared under
751        the given name across multiple sessions.
752      name: Optional name for the queue operation.
753    """
754    dtypes = _as_type_list(dtypes)
755    shapes = _as_shape_list(shapes, dtypes)
756    names = _as_name_list(names, dtypes)
757    with ops.init_scope(), ops.device("CPU"):
758      queue_ref = gen_data_flow_ops.fifo_queue_v2(
759          component_types=dtypes,
760          shapes=shapes,
761          capacity=capacity,
762          shared_name=_shared_name(shared_name),
763          name=name)
764
765    super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
766
767
768# TODO(allenl): If GPU-compatible queues turn out to be useful, we should
769# implement GPU kernels for EnqueueMany and DequeueMany so we can make the
770# public FIFOQueue GPU-compatible and remove this internal version.
771class GPUCompatibleFIFOQueue(QueueBase):
772  """A queue implementation that dequeues elements in first-in first-out order.
773
774  GPUCompatibleFIFOQueue is like FIFOQueue, but the queue resource may be placed
775  either on a CPU or on a GPU. It is not cross-device: enqueues and dequeues
776  will be colocated with the queue resource. GPUCompatibleFIFOQueue only
777  supports enqueue and dequeue at the moment, not enqueue_many or dequeue_many.
778
779  See `tf.queue.QueueBase` for a description of the methods on this class.
780  """
781
782  def __init__(self,
783               capacity,
784               dtypes,
785               shapes=None,
786               names=None,
787               shared_name=None,
788               name="fifo_queue"):
789    """Creates a queue that dequeues elements in a first-in first-out order.
790
791    A `FIFOQueue` has bounded capacity; supports multiple concurrent
792    producers and consumers; and provides exactly-once delivery.
793
794    A `FIFOQueue` holds a list of up to `capacity` elements. Each
795    element is a fixed-length tuple of tensors whose dtypes are
796    described by `dtypes`, and whose shapes are optionally described
797    by the `shapes` argument.
798
799    If the `shapes` argument is specified, each component of a queue
800    element must have the respective fixed shape. If it is
801    unspecified, different queue elements may have different shapes,
802    but the use of `dequeue_many` is disallowed.
803
804    Args:
805      capacity: An integer. The upper bound on the number of elements
806        that may be stored in this queue.
807      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
808        the number of tensors in each queue element.
809      shapes: (Optional.) A list of fully-defined `TensorShape` objects
810        with the same length as `dtypes`, or `None`.
811      names: (Optional.) A list of string naming the components in the queue
812        with the same length as `dtypes`, or `None`.  If specified the dequeue
813        methods return a dictionary with the names as keys.
814      shared_name: (Optional.) If non-empty, this queue will be shared under
815        the given name across multiple sessions.
816      name: Optional name for the queue operation.
817    """
818    dtypes = _as_type_list(dtypes)
819    shapes = _as_shape_list(shapes, dtypes)
820    names = _as_name_list(names, dtypes)
821    with ops.init_scope():
822      queue_ref = gen_data_flow_ops.fifo_queue_v2(
823          component_types=dtypes,
824          shapes=shapes,
825          capacity=capacity,
826          shared_name=_shared_name(shared_name),
827          name=name)
828
829    super(GPUCompatibleFIFOQueue, self).__init__(
830        dtypes, shapes, names, queue_ref)
831
832  def enqueue_many(self, vals, name=None):
833    """enqueue_many is not supported on GPUCompatibleFIFOQueue."""
834    raise NotImplementedError(
835        "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, "
836        "only enqueue and dequeue.")
837
838  def dequeue_many(self, n, name=None):
839    """dequeue_many is not supported on GPUCompatibleFIFOQueue."""
840    raise NotImplementedError(
841        "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, "
842        "only enqueue and dequeue.")
843
844
845@tf_export(
846    "queue.PaddingFIFOQueue",
847    v1=["queue.PaddingFIFOQueue", "io.PaddingFIFOQueue", "PaddingFIFOQueue"])
848@deprecation.deprecated_endpoints(["io.PaddingFIFOQueue", "PaddingFIFOQueue"])
849class PaddingFIFOQueue(QueueBase):
850  """A FIFOQueue that supports batching variable-sized tensors by padding.
851
852  A `PaddingFIFOQueue` may contain components with dynamic shape, while also
853  supporting `dequeue_many`.  See the constructor for more details.
854
855  See `tf.queue.QueueBase` for a description of the methods on
856  this class.
857  """
858
859  def __init__(self,
860               capacity,
861               dtypes,
862               shapes,
863               names=None,
864               shared_name=None,
865               name="padding_fifo_queue"):
866    """Creates a queue that dequeues elements in a first-in first-out order.
867
868    A `PaddingFIFOQueue` has bounded capacity; supports multiple concurrent
869    producers and consumers; and provides exactly-once delivery.
870
871    A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each
872    element is a fixed-length tuple of tensors whose dtypes are
873    described by `dtypes`, and whose shapes are described by the `shapes`
874    argument.
875
876    The `shapes` argument must be specified; each component of a queue
877    element must have the respective shape.  Shapes of fixed
878    rank but variable size are allowed by setting any shape dimension to None.
879    In this case, the inputs' shape may vary along the given dimension, and
880    `dequeue_many` will pad the given dimension with zeros up to the maximum
881    shape of all elements in the given batch.
882
883    Args:
884      capacity: An integer. The upper bound on the number of elements
885        that may be stored in this queue.
886      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
887        the number of tensors in each queue element.
888      shapes: A list of `TensorShape` objects, with the same length as
889        `dtypes`.  Any dimension in the `TensorShape` containing value
890        `None` is dynamic and allows values to be enqueued with
891         variable size in that dimension.
892      names: (Optional.) A list of string naming the components in the queue
893        with the same length as `dtypes`, or `None`.  If specified the dequeue
894        methods return a dictionary with the names as keys.
895      shared_name: (Optional.) If non-empty, this queue will be shared under
896        the given name across multiple sessions.
897      name: Optional name for the queue operation.
898
899    Raises:
900      ValueError: If shapes is not a list of shapes, or the lengths of dtypes
901        and shapes do not match, or if names is specified and the lengths of
902        dtypes and names do not match.
903    """
904    dtypes = _as_type_list(dtypes)
905    shapes = _as_shape_list(shapes, dtypes, unknown_dim_allowed=True)
906    names = _as_name_list(names, dtypes)
907    if len(dtypes) != len(shapes):
908      raise ValueError("Shapes must be provided for all components, "
909                       "but received %d dtypes and %d shapes." % (len(dtypes),
910                                                                  len(shapes)))
911
912    queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
913        component_types=dtypes,
914        shapes=shapes,
915        capacity=capacity,
916        shared_name=_shared_name(shared_name),
917        name=name)
918
919    super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
920
921
922@tf_export("queue.PriorityQueue",
923           v1=["queue.PriorityQueue", "io.PriorityQueue", "PriorityQueue"])
924@deprecation.deprecated_endpoints(["io.PriorityQueue", "PriorityQueue"])
925class PriorityQueue(QueueBase):
926  """A queue implementation that dequeues elements in prioritized order.
927
928  See `tf.queue.QueueBase` for a description of the methods on
929  this class.
930  """
931
932  def __init__(self,
933               capacity,
934               types,
935               shapes=None,
936               names=None,
937               shared_name=None,
938               name="priority_queue"):
939    """Creates a queue that dequeues elements in a first-in first-out order.
940
941    A `PriorityQueue` has bounded capacity; supports multiple concurrent
942    producers and consumers; and provides exactly-once delivery.
943
944    A `PriorityQueue` holds a list of up to `capacity` elements. Each
945    element is a fixed-length tuple of tensors whose dtypes are
946    described by `types`, and whose shapes are optionally described
947    by the `shapes` argument.
948
949    If the `shapes` argument is specified, each component of a queue
950    element must have the respective fixed shape. If it is
951    unspecified, different queue elements may have different shapes,
952    but the use of `dequeue_many` is disallowed.
953
954    Enqueues and Dequeues to the `PriorityQueue` must include an additional
955    tuple entry at the beginning: the `priority`.  The priority must be
956    an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`).
957
958    Args:
959      capacity: An integer. The upper bound on the number of elements
960        that may be stored in this queue.
961      types:  A list of `DType` objects. The length of `types` must equal
962        the number of tensors in each queue element, except the first priority
963        element.  The first tensor in each element is the priority,
964        which must be type int64.
965      shapes: (Optional.) A list of fully-defined `TensorShape` objects,
966        with the same length as `types`, or `None`.
967      names: (Optional.) A list of strings naming the components in the queue
968        with the same length as `dtypes`, or `None`.  If specified, the dequeue
969        methods return a dictionary with the names as keys.
970      shared_name: (Optional.) If non-empty, this queue will be shared under
971        the given name across multiple sessions.
972      name: Optional name for the queue operation.
973    """
974    types = _as_type_list(types)
975    shapes = _as_shape_list(shapes, types)
976
977    queue_ref = gen_data_flow_ops.priority_queue_v2(
978        component_types=types,
979        shapes=shapes,
980        capacity=capacity,
981        shared_name=_shared_name(shared_name),
982        name=name)
983
984    priority_dtypes = [_dtypes.int64] + types
985    priority_shapes = [()] + shapes if shapes else shapes
986
987    super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names,
988                                        queue_ref)
989
990
991# TODO(josh11b): class BatchQueue(QueueBase):
992
993
994class Barrier(object):
995  """Represents a key-value map that persists across graph executions."""
996
997  def __init__(self, types, shapes=None, shared_name=None, name="barrier"):
998    """Creates a barrier that persists across different graph executions.
999
1000    A barrier represents a key-value map, where each key is a string, and
1001    each value is a tuple of tensors.
1002
1003    At runtime, the barrier contains 'complete' and 'incomplete'
1004    elements. A complete element has defined tensors for all
1005    components of its value tuple, and may be accessed using
1006    take_many. An incomplete element has some undefined components in
1007    its value tuple, and may be updated using insert_many.
1008
1009    The barrier call `take_many` outputs values in a particular order.
1010    First, it only outputs completed values.  Second, the order in which
1011    completed values are returned matches the order in which their very
1012    first component was inserted into the barrier.  So, for example, for this
1013    sequence of insertions and removals:
1014
1015      barrier = Barrier((tf.string, tf.int32), shapes=((), ()))
1016      barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run()
1017      barrier.insert_many(1, keys=["k1"], values=[1]).run()
1018      barrier.insert_many(0, keys=["k3"], values=["c"]).run()
1019      barrier.insert_many(1, keys=["k3"], values=[3]).run()
1020      barrier.insert_many(1, keys=["k2"], values=[2]).run()
1021
1022      (indices, keys, values) = barrier.take_many(2)
1023      (indices_val, keys_val, values0_val, values1_val) =
1024         session.run([indices, keys, values[0], values[1]])
1025
1026    The output will be (up to permutation of "k1" and "k2"):
1027
1028      indices_val == (-2**63, -2**63)
1029      keys_val == ("k1", "k2")
1030      values0_val == ("a", "b")
1031      values1_val == (1, 2)
1032
1033    Note the key "k2" was inserted into the barrier before "k3".  Even though
1034    "k3" was completed first, both are complete by the time
1035    take_many is called.  As a result, "k2" is prioritized and "k1" and "k2"
1036    are returned first.  "k3" remains in the barrier until the next execution
1037    of `take_many`.  Since "k1" and "k2" had their first insertions into
1038    the barrier together, their indices are the same (-2**63).  The index
1039    of "k3" will be -2**63 + 1, because it was the next new inserted key.
1040
1041    Args:
1042      types: A single dtype or a tuple of dtypes, corresponding to the
1043        dtypes of the tensor elements that comprise a value in this barrier.
1044      shapes: Optional. Constraints on the shapes of tensors in the values:
1045        a single tensor shape tuple; a tuple of tensor shape tuples
1046        for each barrier-element tuple component; or None if the shape should
1047        not be constrained.
1048      shared_name: Optional. If non-empty, this barrier will be shared under
1049        the given name across multiple sessions.
1050      name: Optional name for the barrier op.
1051
1052    Raises:
1053      ValueError: If one of the `shapes` indicate no elements.
1054    """
1055    self._types = _as_type_list(types)
1056
1057    if shapes is not None:
1058      shapes = _as_shape_list(shapes, self._types)
1059      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
1060      for i, shape in enumerate(self._shapes):
1061        if shape.num_elements() == 0:
1062          raise ValueError("Empty tensors are not supported, but received "
1063                           "shape '%s' at index %d" % (shape, i))
1064    else:
1065      self._shapes = [tensor_shape.unknown_shape() for _ in self._types]
1066
1067    self._barrier_ref = gen_data_flow_ops.barrier(
1068        component_types=self._types,
1069        shapes=self._shapes,
1070        shared_name=shared_name,
1071        name=name)
1072    if context.executing_eagerly():
1073      self._name = context.context().scope_name
1074    else:
1075      self._name = self._barrier_ref.op.name.split("/")[-1]
1076
1077  @property
1078  def barrier_ref(self):
1079    """Get the underlying barrier reference."""
1080    return self._barrier_ref
1081
1082  @property
1083  def name(self):
1084    """The name of the underlying barrier."""
1085    if context.executing_eagerly():
1086      return self._name
1087    return self._barrier_ref.op.name
1088
1089  def insert_many(self, component_index, keys, values, name=None):
1090    """For each key, assigns the respective value to the specified component.
1091
1092    This operation updates each element at component_index.
1093
1094    Args:
1095      component_index: The component of the value that is being assigned.
1096      keys: A vector of keys, with length n.
1097      values: An any-dimensional tensor of values, which are associated with the
1098        respective keys. The first dimension must have length n.
1099      name: Optional name for the op.
1100
1101    Returns:
1102      The operation that performs the insertion.
1103    Raises:
1104      InvalidArgumentsError: If inserting keys and values without elements.
1105    """
1106    if name is None:
1107      name = "%s_BarrierInsertMany" % self._name
1108    return gen_data_flow_ops.barrier_insert_many(
1109        self._barrier_ref, keys, values, component_index, name=name)
1110
1111  def take_many(self,
1112                num_elements,
1113                allow_small_batch=False,
1114                timeout=None,
1115                name=None):
1116    """Takes the given number of completed elements from this barrier.
1117
1118    This operation concatenates completed-element component tensors along
1119    the 0th dimension to make a single component tensor.
1120
1121    If barrier has no completed elements, this operation will block
1122    until there are 'num_elements' elements to take.
1123
1124    TODO(b/25743580): the semantics of `allow_small_batch` are experimental
1125    and may be extended to other cases in the future.
1126
1127    TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
1128    already when the barrier is closed, it will block for ever. Fix this
1129    by using asynchronous operations.
1130
1131    Args:
1132      num_elements: The number of elements to take.
1133      allow_small_batch: If the barrier is closed, don't block if there are less
1134        completed elements than requested, but instead return all available
1135        completed elements.
1136      timeout: This specifies the number of milliseconds to block
1137        before returning with DEADLINE_EXCEEDED. (This option is not
1138        supported yet.)
1139      name: A name for the operation (optional).
1140
1141    Returns:
1142      A tuple of (index, key, value_list).
1143      "index" is a int64 tensor of length num_elements containing the
1144        index of the insert_many call for which the very first component of
1145        the given element was inserted into the Barrier, starting with
1146        the value -2**63.  Note, this value is different from the
1147        index of the insert_many call for which the element was completed.
1148      "key" is a string tensor of length num_elements containing the keys.
1149      "value_list" is a tuple of tensors, each one with size num_elements
1150        in the 0th dimension for each component in the barrier's values.
1151
1152    """
1153    if name is None:
1154      name = "%s_BarrierTakeMany" % self._name
1155    ret = gen_data_flow_ops.barrier_take_many(
1156        self._barrier_ref,
1157        num_elements,
1158        self._types,
1159        allow_small_batch,
1160        timeout,
1161        name=name)
1162
1163    # NOTE(mrry): Not using a shape function because we need access to
1164    # the Barrier object.
1165    if not context.executing_eagerly():
1166      op = ret[0].op
1167      if allow_small_batch:
1168        batch_dim = None
1169      else:
1170        batch_dim = tensor_shape.Dimension(
1171            tensor_util.constant_value(op.inputs[1]))
1172      op.outputs[0].set_shape(tensor_shape.TensorShape([batch_dim]))  # indices
1173      op.outputs[1].set_shape(tensor_shape.TensorShape([batch_dim]))  # keys
1174      for output, shape in zip(op.outputs[2:], self._shapes):  # value_list
1175        output.set_shape(
1176            tensor_shape.TensorShape([batch_dim]).concatenate(shape))
1177
1178    return ret
1179
1180  def close(self, cancel_pending_enqueues=False, name=None):
1181    """Closes this barrier.
1182
1183    This operation signals that no more new key values will be inserted in the
1184    given barrier. Subsequent InsertMany operations with new keys will fail.
1185    InsertMany operations that just complement already existing keys with other
1186    components, will continue to succeed. Subsequent TakeMany operations will
1187    continue to succeed if sufficient elements remain in the barrier. Subsequent
1188    TakeMany operations that would block will fail immediately.
1189
1190    If `cancel_pending_enqueues` is `True`, all pending requests to the
1191    underlying queue will also be canceled, and completing of already
1192    started values is also not acceptable anymore.
1193
1194    Args:
1195      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
1196        `False` (described above).
1197      name: Optional name for the op.
1198
1199    Returns:
1200      The operation that closes the barrier.
1201    """
1202    if name is None:
1203      name = "%s_BarrierClose" % self._name
1204    return gen_data_flow_ops.barrier_close(
1205        self._barrier_ref,
1206        cancel_pending_enqueues=cancel_pending_enqueues,
1207        name=name)
1208
1209  def ready_size(self, name=None):
1210    """Compute the number of complete elements in the given barrier.
1211
1212    Args:
1213      name: A name for the operation (optional).
1214
1215    Returns:
1216      A single-element tensor containing the number of complete elements in the
1217      given barrier.
1218    """
1219    if name is None:
1220      name = "%s_BarrierReadySize" % self._name
1221    return gen_data_flow_ops.barrier_ready_size(self._barrier_ref, name=name)
1222
1223  def incomplete_size(self, name=None):
1224    """Compute the number of incomplete elements in the given barrier.
1225
1226    Args:
1227      name: A name for the operation (optional).
1228
1229    Returns:
1230      A single-element tensor containing the number of incomplete elements in
1231      the given barrier.
1232    """
1233    if name is None:
1234      name = "%s_BarrierIncompleteSize" % self._name
1235    return gen_data_flow_ops.barrier_incomplete_size(
1236        self._barrier_ref, name=name)
1237
1238
1239@tf_export(v1=["ConditionalAccumulatorBase"])
1240class ConditionalAccumulatorBase(object):
1241  """A conditional accumulator for aggregating gradients.
1242
1243  Up-to-date gradients (i.e., time step at which gradient was computed is
1244  equal to the accumulator's time step) are added to the accumulator.
1245
1246  Extraction of the average gradient is blocked until the required number of
1247  gradients has been accumulated.
1248  """
1249
1250  def __init__(self, dtype, shape, accumulator_ref):
1251    """Creates a new ConditionalAccumulator.
1252
1253    Args:
1254      dtype: Datatype of the accumulated gradients.
1255      shape: Shape of the accumulated gradients.
1256      accumulator_ref: A handle to the conditional accumulator, created by sub-
1257        classes
1258    """
1259    self._dtype = dtype
1260    if shape is not None:
1261      self._shape = tensor_shape.TensorShape(shape)
1262    else:
1263      self._shape = tensor_shape.unknown_shape()
1264    self._accumulator_ref = accumulator_ref
1265    if context.executing_eagerly():
1266      self._name = context.context().scope_name
1267    else:
1268      self._name = self._accumulator_ref.op.name.split("/")[-1]
1269
1270  @property
1271  def accumulator_ref(self):
1272    """The underlying accumulator reference."""
1273    return self._accumulator_ref
1274
1275  @property
1276  def name(self):
1277    """The name of the underlying accumulator."""
1278    return self._name
1279
1280  @property
1281  def dtype(self):
1282    """The datatype of the gradients accumulated by this accumulator."""
1283    return self._dtype
1284
1285  def num_accumulated(self, name=None):
1286    """Number of gradients that have currently been aggregated in accumulator.
1287
1288    Args:
1289      name: Optional name for the operation.
1290
1291    Returns:
1292      Number of accumulated gradients currently in accumulator.
1293    """
1294    if name is None:
1295      name = "%s_NumAccumulated" % self._name
1296
1297    return gen_data_flow_ops.resource_accumulator_num_accumulated(
1298        self._accumulator_ref, name=name)
1299
1300  def set_global_step(self, new_global_step, name=None):
1301    """Sets the global time step of the accumulator.
1302
1303    The operation logs a warning if we attempt to set to a time step that is
1304    lower than the accumulator's own time step.
1305
1306    Args:
1307      new_global_step: Value of new time step. Can be a variable or a constant
1308      name: Optional name for the operation.
1309
1310    Returns:
1311      Operation that sets the accumulator's time step.
1312    """
1313    return gen_data_flow_ops.resource_accumulator_set_global_step(
1314        self._accumulator_ref,
1315        math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
1316        name=name)
1317
1318
1319@tf_export(v1=["ConditionalAccumulator"])
1320class ConditionalAccumulator(ConditionalAccumulatorBase):
1321  """A conditional accumulator for aggregating gradients.
1322
1323  Up-to-date gradients (i.e., time step at which gradient was computed is
1324  equal to the accumulator's time step) are added to the accumulator.
1325
1326  Extraction of the average gradient is blocked until the required number of
1327  gradients has been accumulated.
1328  """
1329
1330  def __init__(self,
1331               dtype,
1332               shape=None,
1333               shared_name=None,
1334               name="conditional_accumulator",
1335               reduction_type="MEAN"):
1336    """Creates a new ConditionalAccumulator.
1337
1338    Args:
1339      dtype: Datatype of the accumulated gradients.
1340      shape: Shape of the accumulated gradients.
1341      shared_name: Optional. If non-empty, this accumulator will be shared under
1342        the given name across multiple sessions.
1343      name: Optional name for the accumulator.
1344      reduction_type: Reduction type to use when taking the gradient.
1345    """
1346    accumulator_ref = gen_data_flow_ops.resource_conditional_accumulator(
1347        dtype=dtype,
1348        shape=shape,
1349        shared_name=shared_name,
1350        name=name,
1351        reduction_type=reduction_type)
1352    if context.executing_eagerly():
1353      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
1354          handle=accumulator_ref, handle_device=context.context().device_name)
1355
1356    super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
1357
1358  def apply_grad(self, grad, local_step=0, name=None):
1359    """Attempts to apply a gradient to the accumulator.
1360
1361    The attempt is silently dropped if the gradient is stale, i.e., local_step
1362    is less than the accumulator's global time step.
1363
1364    Args:
1365      grad: The gradient tensor to be applied.
1366      local_step: Time step at which the gradient was computed.
1367      name: Optional name for the operation.
1368
1369    Returns:
1370      The operation that (conditionally) applies a gradient to the accumulator.
1371
1372    Raises:
1373      ValueError: If grad is of the wrong shape
1374    """
1375    grad = ops.convert_to_tensor(grad, self._dtype)
1376    grad.get_shape().assert_is_compatible_with(self._shape)
1377    local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
1378
1379    return gen_data_flow_ops.resource_accumulator_apply_gradient(
1380        self._accumulator_ref, local_step=local_step, gradient=grad, name=name)
1381
1382  def take_grad(self, num_required, name=None):
1383    """Attempts to extract the average gradient from the accumulator.
1384
1385    The operation blocks until sufficient number of gradients have been
1386    successfully applied to the accumulator.
1387
1388    Once successful, the following actions are also triggered:
1389
1390    - Counter of accumulated gradients is reset to 0.
1391    - Aggregated gradient is reset to 0 tensor.
1392    - Accumulator's internal time step is incremented by 1.
1393
1394    Args:
1395      num_required: Number of gradients that needs to have been aggregated
1396      name: Optional name for the operation
1397
1398    Returns:
1399      A tensor holding the value of the average gradient.
1400
1401    Raises:
1402      InvalidArgumentError: If num_required < 1
1403    """
1404    out = gen_data_flow_ops.resource_accumulator_take_gradient(
1405        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1406    out.set_shape(self._shape)
1407    return out
1408
1409
1410@tf_export(
1411    v1=["sparse.SparseConditionalAccumulator", "SparseConditionalAccumulator"])
1412class SparseConditionalAccumulator(ConditionalAccumulatorBase):
1413  """A conditional accumulator for aggregating sparse gradients.
1414
1415  Sparse gradients are represented by `IndexedSlices`.
1416
1417  Up-to-date gradients (i.e., time step at which gradient was computed is
1418  equal to the accumulator's time step) are added to the accumulator.
1419
1420  Extraction of the average gradient is blocked until the required number of
1421  gradients has been accumulated.
1422
1423  Args:
1424    dtype: Datatype of the accumulated gradients.
1425    shape: Shape of the accumulated gradients.
1426    shared_name: Optional. If non-empty, this accumulator will be shared under
1427      the given name across multiple sessions.
1428    name: Optional name for the accumulator.
1429    reduction_type: Reduction type to use when taking the gradient.
1430  """
1431
1432  def __init__(self,
1433               dtype,
1434               shape=None,
1435               shared_name=None,
1436               name="sparse_conditional_accumulator",
1437               reduction_type="MEAN"):
1438    accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
1439        dtype=dtype,
1440        shape=shape,
1441        shared_name=shared_name,
1442        name=name,
1443        reduction_type=reduction_type)
1444    super(SparseConditionalAccumulator, self).__init__(dtype, shape,
1445                                                       accumulator_ref)
1446
1447  def apply_indexed_slices_grad(self, grad, local_step=0, name=None):
1448    """Attempts to apply a gradient to the accumulator.
1449
1450    The attempt is silently dropped if the gradient is stale, i.e., `local_step`
1451    is less than the accumulator's global time step.
1452
1453    Args:
1454      grad: The gradient `IndexedSlices` to be applied.
1455      local_step: Time step at which the gradient was computed.
1456      name: Optional name for the operation.
1457
1458    Returns:
1459      The operation that (conditionally) applies a gradient to the accumulator.
1460
1461    Raises:
1462      InvalidArgumentError: If grad is of the wrong shape
1463    """
1464    return self.apply_grad(
1465        grad_indices=grad.indices,
1466        grad_values=grad.values,
1467        grad_shape=grad.dense_shape,
1468        local_step=local_step,
1469        name=name)
1470
1471  def apply_grad(self,
1472                 grad_indices,
1473                 grad_values,
1474                 grad_shape=None,
1475                 local_step=0,
1476                 name=None):
1477    """Attempts to apply a sparse gradient to the accumulator.
1478
1479    The attempt is silently dropped if the gradient is stale, i.e., `local_step`
1480    is less than the accumulator's global time step.
1481
1482    A sparse gradient is represented by its indices, values and possibly empty
1483    or None shape. Indices must be a vector representing the locations of
1484    non-zero entries in the tensor. Values are the non-zero slices of the
1485    gradient, and must have the same first dimension as indices, i.e., the nnz
1486    represented by indices and values must be consistent. Shape, if not empty or
1487    None, must be consistent with the accumulator's shape (if also provided).
1488
1489    Example:
1490      A tensor [[0, 0], [0, 1], [2, 3]] can be represented
1491        indices: [1,2]
1492        values: [[0,1],[2,3]]
1493        shape: [3, 2]
1494
1495    Args:
1496      grad_indices: Indices of the sparse gradient to be applied.
1497      grad_values: Values of the sparse gradient to be applied.
1498      grad_shape: Shape of the sparse gradient to be applied.
1499      local_step: Time step at which the gradient was computed.
1500      name: Optional name for the operation.
1501
1502    Returns:
1503      The operation that (conditionally) applies a gradient to the accumulator.
1504
1505    Raises:
1506      InvalidArgumentError: If grad is of the wrong shape
1507    """
1508    local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
1509    return gen_data_flow_ops.sparse_accumulator_apply_gradient(
1510        self._accumulator_ref,
1511        local_step=local_step,
1512        gradient_indices=math_ops.cast(grad_indices, _dtypes.int64),
1513        gradient_values=grad_values,
1514        gradient_shape=math_ops.cast(
1515            [] if grad_shape is None else grad_shape, _dtypes.int64),
1516        has_known_shape=(grad_shape is not None),
1517        name=name)
1518
1519  def take_grad(self, num_required, name=None):
1520    """Attempts to extract the average gradient from the accumulator.
1521
1522    The operation blocks until sufficient number of gradients have been
1523    successfully applied to the accumulator.
1524
1525    Once successful, the following actions are also triggered:
1526    - Counter of accumulated gradients is reset to 0.
1527    - Aggregated gradient is reset to 0 tensor.
1528    - Accumulator's internal time step is incremented by 1.
1529
1530    Args:
1531      num_required: Number of gradients that needs to have been aggregated
1532      name: Optional name for the operation
1533
1534    Returns:
1535      A tuple of indices, values, and shape representing the average gradient.
1536
1537    Raises:
1538      InvalidArgumentError: If `num_required` < 1
1539    """
1540    return gen_data_flow_ops.sparse_accumulator_take_gradient(
1541        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1542
1543  def take_indexed_slices_grad(self, num_required, name=None):
1544    """Attempts to extract the average gradient from the accumulator.
1545
1546    The operation blocks until sufficient number of gradients have been
1547    successfully applied to the accumulator.
1548
1549    Once successful, the following actions are also triggered:
1550    - Counter of accumulated gradients is reset to 0.
1551    - Aggregated gradient is reset to 0 tensor.
1552    - Accumulator's internal time step is incremented by 1.
1553
1554    Args:
1555      num_required: Number of gradients that needs to have been aggregated
1556      name: Optional name for the operation
1557
1558    Returns:
1559      An `IndexedSlices` holding the value of the average gradient.
1560
1561    Raises:
1562      InvalidArgumentError: If `num_required` < 1
1563    """
1564    return_val = gen_data_flow_ops.sparse_accumulator_take_gradient(
1565        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1566    return ops.IndexedSlices(
1567        indices=return_val.indices,
1568        values=return_val.values,
1569        dense_shape=return_val.shape)
1570
1571  # SparseConditionalAccumulator is not switched to resource. Use old kernels.
1572  def num_accumulated(self, name=None):
1573    """Number of gradients that have currently been aggregated in accumulator.
1574
1575    Args:
1576      name: Optional name for the operation.
1577
1578    Returns:
1579      Number of accumulated gradients currently in accumulator.
1580    """
1581    if name is None:
1582      name = "%s_NumAccumulated" % self._name
1583
1584    return gen_data_flow_ops.accumulator_num_accumulated(
1585        self._accumulator_ref, name=name)
1586
1587  def set_global_step(self, new_global_step, name=None):
1588    """Sets the global time step of the accumulator.
1589
1590    The operation logs a warning if we attempt to set to a time step that is
1591    lower than the accumulator's own time step.
1592
1593    Args:
1594      new_global_step: Value of new time step. Can be a variable or a constant
1595      name: Optional name for the operation.
1596
1597    Returns:
1598      Operation that sets the accumulator's time step.
1599    """
1600    return gen_data_flow_ops.accumulator_set_global_step(
1601        self._accumulator_ref,
1602        math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
1603        name=name)
1604
1605
1606class BaseStagingArea(object):
1607  """Base class for Staging Areas."""
1608  _identifier = 0
1609  _lock = threading.Lock()
1610
1611  def __init__(self,
1612               dtypes,
1613               shapes=None,
1614               names=None,
1615               shared_name=None,
1616               capacity=0,
1617               memory_limit=0):
1618    if shared_name is None:
1619      self._name = (
1620          ops.get_default_graph().unique_name(self.__class__.__name__))
1621    elif isinstance(shared_name, six.string_types):
1622      self._name = shared_name
1623    else:
1624      raise ValueError("shared_name must be a string")
1625
1626    self._dtypes = dtypes
1627
1628    if shapes is not None:
1629      if len(shapes) != len(dtypes):
1630        raise ValueError("StagingArea shapes must be the same length as dtypes")
1631      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
1632    else:
1633      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
1634
1635    if names is not None:
1636      if len(names) != len(dtypes):
1637        raise ValueError("StagingArea names must be the same length as dtypes")
1638      self._names = names
1639    else:
1640      self._names = None
1641
1642    self._capacity = capacity
1643    self._memory_limit = memory_limit
1644
1645    # all get and put ops must colocate with this op
1646    with ops.name_scope("%s_root" % self._name):
1647      self._coloc_op = control_flow_ops.no_op()
1648
1649  @property
1650  def name(self):
1651    """The name of the staging area."""
1652    return self._name
1653
1654  @property
1655  def dtypes(self):
1656    """The list of dtypes for each component of a staging area element."""
1657    return self._dtypes
1658
1659  @property
1660  def shapes(self):
1661    """The list of shapes for each component of a staging area element."""
1662    return self._shapes
1663
1664  @property
1665  def names(self):
1666    """The list of names for each component of a staging area element."""
1667    return self._names
1668
1669  @property
1670  def capacity(self):
1671    """The maximum number of elements of this staging area."""
1672    return self._capacity
1673
1674  @property
1675  def memory_limit(self):
1676    """The maximum number of bytes of this staging area."""
1677    return self._memory_limit
1678
1679  def _check_put_dtypes(self, vals, indices=None):
1680    """Validate and convert `vals` to a list of `Tensor`s.
1681
1682    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
1683    dictionary with tensor values.
1684
1685    If `vals` is a list, then the appropriate indices associated with the
1686    values must be provided.
1687
1688    If it is a dictionary, the staging area must have been constructed with a
1689    `names` attribute and the dictionary keys must match the staging area names.
1690    `indices` will be inferred from the dictionary keys.
1691    If the staging area was constructed with a `names` attribute, `vals` must
1692    be a dictionary.
1693
1694    Checks that the dtype and shape of each value matches that
1695    of the staging area.
1696
1697    Args:
1698      vals: A tensor, a list or tuple of tensors, or a dictionary.
1699
1700    Returns:
1701      A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects
1702      and `indices` is a list of indices associated with the tensors.
1703
1704    Raises:
1705      ValueError: If `vals` or `indices` is invalid.
1706    """
1707    if isinstance(vals, dict):
1708      if not self._names:
1709        raise ValueError(
1710            "Staging areas must have names to enqueue a dictionary")
1711      if not set(vals.keys()).issubset(self._names):
1712        raise ValueError("Keys in dictionary to put do not match names "
1713                         "of staging area. Dictionary: (%s), Queue: (%s)" %
1714                         (sorted(vals.keys()), sorted(self._names)))
1715      # The order of values in `self._names` indicates the order in which the
1716      # tensors in the dictionary `vals` must be listed.
1717      vals, indices, _ = zip(*[(vals[k], i, k)
1718                               for i, k in enumerate(self._names)
1719                               if k in vals])
1720    else:
1721      if self._names:
1722        raise ValueError("You must enqueue a dictionary in a staging area "
1723                         "with names")
1724
1725      if indices is None:
1726        raise ValueError("Indices must be supplied when inserting a list "
1727                         "of tensors")
1728
1729      if len(indices) != len(vals):
1730        raise ValueError("Number of indices '%s' doesn't match "
1731                         "number of values '%s'")
1732
1733      if not isinstance(vals, (list, tuple)):
1734        vals = [vals]
1735        indices = [0]
1736
1737    # Sanity check number of values
1738    if not len(vals) <= len(self._dtypes):
1739      raise ValueError("Unexpected number of inputs '%s' vs '%s'" %
1740                       (len(vals), len(self._dtypes)))
1741
1742    tensors = []
1743
1744    for val, i in zip(vals, indices):
1745      dtype, shape = self._dtypes[i], self._shapes[i]
1746      # Check dtype
1747      if val.dtype != dtype:
1748        raise ValueError("Datatypes do not match. '%s' != '%s'" %
1749                         (str(val.dtype), str(dtype)))
1750
1751      # Check shape
1752      val.get_shape().assert_is_compatible_with(shape)
1753
1754      tensors.append(
1755          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
1756
1757    return tensors, indices
1758
1759  def _create_device_transfers(self, tensors):
1760    """Encode inter-device transfers if the current device
1761    is not the same as the Staging Area's device.
1762    """
1763
1764    if not isinstance(tensors, (tuple, list)):
1765      tensors = [tensors]
1766
1767    curr_device_scope = control_flow_ops.no_op().device
1768
1769    if curr_device_scope != self._coloc_op.device:
1770      tensors = [array_ops.identity(t) for t in tensors]
1771
1772    return tensors
1773
1774  def _get_return_value(self, tensors, indices):
1775    """Return the value to return from a get op.
1776
1777    If the staging area has names, return a dictionary with the
1778    names as keys.  Otherwise return either a single tensor
1779    or a list of tensors depending on the length of `tensors`.
1780
1781    Args:
1782      tensors: List of tensors from the get op.
1783      indices: Indices of associated names and shapes
1784
1785    Returns:
1786      A single tensor, a list of tensors, or a dictionary
1787      of tensors.
1788    """
1789
1790    tensors = self._create_device_transfers(tensors)
1791
1792    # Sets shape
1793    for output, i in zip(tensors, indices):
1794      output.set_shape(self._shapes[i])
1795
1796    if self._names:
1797      # The returned values in `tensors` are in the same order as
1798      # the names in `self._names`.
1799      return {self._names[i]: t for t, i in zip(tensors, indices)}
1800    return tensors
1801
1802  def _scope_vals(self, vals):
1803    """Return a list of values to pass to `name_scope()`.
1804
1805    Args:
1806      vals: A tensor, a list or tuple of tensors, or a dictionary.
1807
1808    Returns:
1809      The values in vals as a list.
1810    """
1811    if isinstance(vals, (list, tuple)):
1812      return vals
1813    elif isinstance(vals, dict):
1814      return vals.values()
1815    else:
1816      return [vals]
1817
1818
1819class StagingArea(BaseStagingArea):
1820  """Class for staging inputs. No ordering guarantees.
1821
1822  A `StagingArea` is a TensorFlow data structure that stores tensors across
1823  multiple steps, and exposes operations that can put and get tensors.
1824
1825  Each `StagingArea` element is a tuple of one or more tensors, where each
1826  tuple component has a static dtype, and may have a static shape.
1827
1828  The capacity of a `StagingArea` may be bounded or unbounded.
1829  It supports multiple concurrent producers and consumers; and
1830  provides exactly-once delivery.
1831
1832  Each element of a `StagingArea` is a fixed-length tuple of tensors whose
1833  dtypes are described by `dtypes`, and whose shapes are optionally described
1834  by the `shapes` argument.
1835
1836  If the `shapes` argument is specified, each component of a staging area
1837  element must have the respective fixed shape. If it is
1838  unspecified, different elements may have different shapes,
1839
1840  It can be configured with a capacity in which case
1841  put(values) will block until space becomes available.
1842
1843  Similarly, it can be configured with a memory limit which
1844  will block put(values) until space is available.
1845  This is mostly useful for limiting the number of tensors on
1846  devices such as GPUs.
1847
1848  All get() and peek() commands block if the requested data
1849  is not present in the Staging Area.
1850
1851  """
1852
1853  def __init__(self,
1854               dtypes,
1855               shapes=None,
1856               names=None,
1857               shared_name=None,
1858               capacity=0,
1859               memory_limit=0):
1860    """Constructs a staging area object.
1861
1862    The two optional lists, `shapes` and `names`, must be of the same length
1863    as `dtypes` if provided.  The values at a given index `i` indicate the
1864    shape and name to use for the corresponding queue component in `dtypes`.
1865
1866    The device scope at the time of object creation determines where the
1867    storage for the `StagingArea` will reside.  Calls to `put` will incur a copy
1868    to this memory space, if necessary.  Tensors returned by `get` will be
1869    placed according to the device scope when `get` is called.
1870
1871    Args:
1872      dtypes:  A list of types.  The length of dtypes must equal the number
1873        of tensors in each element.
1874      shapes: (Optional.) Constraints on the shapes of tensors in an element.
1875        A list of shape tuples or None. This list is the same length
1876        as dtypes.  If the shape of any tensors in the element are constrained,
1877        all must be; shapes can be None if the shapes should not be constrained.
1878      names: (Optional.) If provided, the `get()` and
1879        `put()` methods will use dictionaries with these names as keys.
1880        Must be None or a list or tuple of the same length as `dtypes`.
1881      shared_name: (Optional.) A name to be used for the shared object. By
1882        passing the same name to two different python objects they will share
1883        the underlying staging area. Must be a string.
1884      capacity: (Optional.) Maximum number of elements.
1885        An integer. If zero, the Staging Area is unbounded
1886      memory_limit: (Optional.) Maximum number of bytes of all tensors
1887        in the Staging Area.
1888        An integer. If zero, the Staging Area is unbounded
1889
1890    Raises:
1891      ValueError: If one of the arguments is invalid.
1892    """
1893
1894    super(StagingArea, self).__init__(dtypes, shapes, names, shared_name,
1895                                      capacity, memory_limit)
1896
1897  def put(self, values, name=None):
1898    """Create an op that places a value into the staging area.
1899
1900    This operation will block if the `StagingArea` has reached
1901    its capacity.
1902
1903    Args:
1904      values: A single tensor, a list or tuple of tensors, or a dictionary with
1905        tensor values. The number of elements must match the length of the
1906        list provided to the dtypes argument when creating the StagingArea.
1907      name: A name for the operation (optional).
1908
1909    Returns:
1910        The created op.
1911
1912    Raises:
1913      ValueError: If the number or type of inputs don't match the staging area.
1914    """
1915    with ops.name_scope(name, "%s_put" % self._name,
1916                        self._scope_vals(values)) as scope:
1917
1918      if not isinstance(values, (list, tuple, dict)):
1919        values = [values]
1920
1921      # Hard-code indices for this staging area
1922      indices = list(six.moves.range(len(values)))
1923      vals, _ = self._check_put_dtypes(values, indices)
1924
1925      with ops.colocate_with(self._coloc_op):
1926        op = gen_data_flow_ops.stage(
1927            values=vals,
1928            shared_name=self._name,
1929            name=scope,
1930            capacity=self._capacity,
1931            memory_limit=self._memory_limit)
1932
1933      return op
1934
1935  def __internal_get(self, get_fn, name):
1936    with ops.colocate_with(self._coloc_op):
1937      ret = get_fn()
1938
1939    indices = list(six.moves.range(len(self._dtypes)))  # Hard coded
1940    return self._get_return_value(ret, indices)
1941
1942  def get(self, name=None):
1943    """Gets one element from this staging area.
1944
1945    If the staging area is empty when this operation executes, it will block
1946    until there is an element to dequeue.
1947
1948    Note that unlike others ops that can block, like the queue Dequeue
1949    operations, this can stop other work from happening.  To avoid this, the
1950    intended use is for this to be called only when there will be an element
1951    already available.  One method for doing this in a training loop would be to
1952    run a `put()` call during a warmup session.run call, and then call both
1953    `get()` and `put()` in each subsequent step.
1954
1955    The placement of the returned tensor will be determined by the current
1956    device scope when this function is called.
1957
1958    Args:
1959      name: A name for the operation (optional).
1960
1961    Returns:
1962      The tuple of tensors that was gotten.
1963    """
1964    if name is None:
1965      name = "%s_get" % self._name
1966
1967    # pylint: disable=bad-continuation
1968    fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes,
1969                    shared_name=self._name, name=name,
1970                    capacity=self._capacity,
1971                    memory_limit=self._memory_limit)
1972    # pylint: enable=bad-continuation
1973
1974    return self.__internal_get(fn, name)
1975
1976  def peek(self, index, name=None):
1977    """Peeks at an element in the staging area.
1978
1979    If the staging area is too small to contain the element at
1980    the specified index, it will block until enough elements
1981    are inserted to complete the operation.
1982
1983    The placement of the returned tensor will be determined by
1984    the current device scope when this function is called.
1985
1986    Args:
1987      index: The index of the tensor within the staging area
1988              to look up.
1989      name: A name for the operation (optional).
1990
1991    Returns:
1992      The tuple of tensors that was gotten.
1993    """
1994    if name is None:
1995      name = "%s_peek" % self._name
1996
1997    # pylint: disable=bad-continuation
1998    fn = lambda: gen_data_flow_ops.stage_peek(index,
1999                    dtypes=self._dtypes, shared_name=self._name,
2000                    name=name, capacity=self._capacity,
2001                    memory_limit=self._memory_limit)
2002    # pylint: enable=bad-continuation
2003
2004    return self.__internal_get(fn, name)
2005
2006  def size(self, name=None):
2007    """Returns the number of elements in the staging area.
2008
2009    Args:
2010        name: A name for the operation (optional)
2011
2012    Returns:
2013        The created op
2014    """
2015    if name is None:
2016      name = "%s_size" % self._name
2017
2018    return gen_data_flow_ops.stage_size(
2019        name=name,
2020        shared_name=self._name,
2021        dtypes=self._dtypes,
2022        capacity=self._capacity,
2023        memory_limit=self._memory_limit)
2024
2025  def clear(self, name=None):
2026    """Clears the staging area.
2027
2028    Args:
2029        name: A name for the operation (optional)
2030
2031    Returns:
2032        The created op
2033    """
2034    if name is None:
2035      name = "%s_clear" % self._name
2036
2037    return gen_data_flow_ops.stage_clear(
2038        name=name,
2039        shared_name=self._name,
2040        dtypes=self._dtypes,
2041        capacity=self._capacity,
2042        memory_limit=self._memory_limit)
2043
2044
2045class MapStagingArea(BaseStagingArea):
2046  """A `MapStagingArea` is a TensorFlow data structure that stores tensors
2047  across multiple steps, and exposes operations that can put and get tensors.
2048
2049  Each `MapStagingArea` element is a (key, value) pair.
2050  Only int64 keys are supported, other types should be
2051  hashed to produce a key.
2052  Values are a tuple of one or more tensors.
2053  Each tuple component has a static dtype,
2054  and may have a static shape.
2055
2056  The capacity of a `MapStagingArea` may be bounded or unbounded.
2057  It supports multiple concurrent producers and consumers; and
2058  provides exactly-once delivery.
2059
2060  Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors
2061  whose
2062  dtypes are described by `dtypes`, and whose shapes are optionally described
2063  by the `shapes` argument.
2064
2065  If the `shapes` argument is specified, each component of a staging area
2066  element must have the respective fixed shape. If it is
2067  unspecified, different elements may have different shapes,
2068
2069  It behaves like an associative container with support for:
2070
2071   - put(key, values)
2072   - peek(key)         like dict.get(key)
2073   - get(key)          like dict.pop(key)
2074   - get(key=None)     like dict.popitem()
2075   - size()
2076   - clear()
2077
2078  If ordered a tree structure ordered by key will be used and
2079  get(key=None) will remove (key, value) pairs in increasing key order.
2080  Otherwise a hashtable
2081
2082  It can be configured with a capacity in which case
2083  put(key, values) will block until space becomes available.
2084
2085  Similarly, it can be configured with a memory limit which
2086  will block put(key, values) until space is available.
2087  This is mostly useful for limiting the number of tensors on
2088  devices such as GPUs.
2089
2090  All get() and peek() commands block if the requested
2091  (key, value) pair is not present in the staging area.
2092
2093  Partial puts are supported and will be placed in an incomplete
2094  map until such time as all values associated with the key have
2095  been inserted. Once completed, this (key, value) pair will be
2096  inserted into the map. Data in the incomplete map
2097  counts towards the memory limit, but not towards capacity limit.
2098
2099  Partial gets from the map are also supported.
2100  This removes the partially requested tensors from the entry,
2101  but the entry is only removed from the map once all tensors
2102  associated with it are removed.
2103  """
2104
2105  def __init__(self,
2106               dtypes,
2107               shapes=None,
2108               names=None,
2109               shared_name=None,
2110               ordered=False,
2111               capacity=0,
2112               memory_limit=0):
2113    """Args:
2114
2115      dtypes:  A list of types.  The length of dtypes must equal the number
2116        of tensors in each element.
2117      capacity: (Optional.) Maximum number of elements.
2118        An integer. If zero, the Staging Area is unbounded
2119      memory_limit: (Optional.) Maximum number of bytes of all tensors
2120        in the Staging Area (excluding keys).
2121        An integer. If zero, the Staging Area is unbounded
2122      ordered: (Optional.) If True the underlying data structure
2123        is a tree ordered on key. Otherwise assume a hashtable.
2124      shapes: (Optional.) Constraints on the shapes of tensors in an element.
2125        A list of shape tuples or None. This list is the same length
2126        as dtypes.  If the shape of any tensors in the element are constrained,
2127        all must be; shapes can be None if the shapes should not be constrained.
2128      names: (Optional.) If provided, the `get()` and
2129        `put()` methods will use dictionaries with these names as keys.
2130        Must be None or a list or tuple of the same length as `dtypes`.
2131      shared_name: (Optional.) A name to be used for the shared object. By
2132        passing the same name to two different python objects they will share
2133        the underlying staging area. Must be a string.
2134
2135    Raises:
2136      ValueError: If one of the arguments is invalid.
2137
2138    """
2139
2140    super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name,
2141                                         capacity, memory_limit)
2142
2143    # Defer to different methods depending if the map is ordered
2144    self._ordered = ordered
2145
2146    if ordered:
2147      self._put_fn = gen_data_flow_ops.ordered_map_stage
2148      self._pop_fn = gen_data_flow_ops.ordered_map_unstage
2149      self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key
2150      self._peek_fn = gen_data_flow_ops.ordered_map_peek
2151      self._size_fn = gen_data_flow_ops.ordered_map_size
2152      self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size
2153      self._clear_fn = gen_data_flow_ops.ordered_map_clear
2154    else:
2155      self._put_fn = gen_data_flow_ops.map_stage
2156      self._pop_fn = gen_data_flow_ops.map_unstage
2157      self._popitem_fn = gen_data_flow_ops.map_unstage_no_key
2158      self._peek_fn = gen_data_flow_ops.map_peek
2159      self._size_fn = gen_data_flow_ops.map_size
2160      self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size
2161      self._clear_fn = gen_data_flow_ops.map_clear
2162
2163  def put(self, key, vals, indices=None, name=None):
2164    """Create an op that stores the (key, vals) pair in the staging area.
2165
2166    Incomplete puts are possible, preferably using a dictionary for vals
2167    as the appropriate dtypes and shapes can be inferred from the value names
2168    dictionary key values. If vals is a list or tuple, indices must
2169    also be specified so that the op knows at which element position
2170    to perform the insert.
2171
2172    This operation will block if the capacity or memory limit of this
2173    container is reached.
2174
2175    Args:
2176        key: Key associated with the data
2177        vals: Tensor (or a dict/tuple of Tensors) to place
2178                into the staging area.
2179        indices: (Optional) if vals is a tuple/list, this is required.
2180        name: A name for the operation (optional)
2181
2182    Returns:
2183        The created op
2184
2185    Raises:
2186        ValueError: If the number or type of inputs don't match the staging
2187        area.
2188    """
2189
2190    with ops.name_scope(name, "%s_put" % self._name,
2191                        self._scope_vals(vals)) as scope:
2192
2193      vals, indices = self._check_put_dtypes(vals, indices)
2194
2195      with ops.colocate_with(self._coloc_op):
2196        op = self._put_fn(
2197            key,
2198            indices,
2199            vals,
2200            dtypes=self._dtypes,
2201            shared_name=self._name,
2202            name=scope,
2203            capacity=self._capacity,
2204            memory_limit=self._memory_limit)
2205    return op
2206
2207  def _get_indices_and_dtypes(self, indices=None):
2208    if indices is None:
2209      indices = list(six.moves.range(len(self._dtypes)))
2210
2211    if not isinstance(indices, (tuple, list)):
2212      raise TypeError("Invalid indices type '%s'" % type(indices))
2213
2214    if len(indices) == 0:
2215      raise ValueError("Empty indices")
2216
2217    if all(isinstance(i, str) for i in indices):
2218      if self._names is None:
2219        raise ValueError("String indices provided '%s', but this Staging Area "
2220                         "was not created with names." % indices)
2221
2222      try:
2223        indices = [self._names.index(n) for n in indices]
2224      except ValueError:
2225        raise ValueError("Named index '%s' not in "
2226                         "Staging Area names '%s'" % (n, self._names))
2227    elif all(isinstance(i, int) for i in indices):
2228      pass
2229    else:
2230      raise TypeError("Mixed types in indices '%s'. "
2231                      "May only be str or int" % indices)
2232
2233    dtypes = [self._dtypes[i] for i in indices]
2234
2235    return indices, dtypes
2236
2237  def peek(self, key, indices=None, name=None):
2238    """Peeks at staging area data associated with the key.
2239
2240    If the key is not in the staging area, it will block
2241    until the associated (key, value) is inserted.
2242
2243    Args:
2244        key: Key associated with the required data
2245        indices: Partial list of tensors to retrieve (optional).
2246                A list of integer or string indices.
2247                String indices are only valid if the Staging Area
2248                has names associated with it.
2249        name: A name for the operation (optional)
2250
2251    Returns:
2252        The created op
2253    """
2254
2255    if name is None:
2256      name = "%s_pop" % self._name
2257
2258    indices, dtypes = self._get_indices_and_dtypes(indices)
2259
2260    with ops.colocate_with(self._coloc_op):
2261      result = self._peek_fn(
2262          key,
2263          shared_name=self._name,
2264          indices=indices,
2265          dtypes=dtypes,
2266          name=name,
2267          capacity=self._capacity,
2268          memory_limit=self._memory_limit)
2269
2270    return self._get_return_value(result, indices)
2271
2272  def get(self, key=None, indices=None, name=None):
2273    """If the key is provided, the associated (key, value) is returned from the staging area.
2274
2275    If the key is not in the staging area, this method will block until
2276    the associated (key, value) is inserted.
2277    If no key is provided and the staging area is ordered,
2278    the (key, value) with the smallest key will be returned.
2279    Otherwise, a random (key, value) will be returned.
2280
2281    If the staging area is empty when this operation executes,
2282    it will block until there is an element to dequeue.
2283
2284    Args:
2285        key: Key associated with the required data (Optional)
2286        indices: Partial list of tensors to retrieve (optional).
2287                A list of integer or string indices.
2288                String indices are only valid if the Staging Area
2289                has names associated with it.
2290        name: A name for the operation (optional)
2291
2292    Returns:
2293        The created op
2294    """
2295    if key is None:
2296      return self._popitem(indices=indices, name=name)
2297    else:
2298      return self._pop(key, indices=indices, name=name)
2299
2300  def _pop(self, key, indices=None, name=None):
2301    """Remove and return the associated (key, value) is returned from the staging area.
2302
2303    If the key is not in the staging area, this method will block until
2304    the associated (key, value) is inserted.
2305    Args:
2306        key: Key associated with the required data
2307        indices: Partial list of tensors to retrieve (optional).
2308                A list of integer or string indices.
2309                String indices are only valid if the Staging Area
2310                has names associated with it.
2311        name: A name for the operation (optional)
2312
2313    Returns:
2314        The created op
2315    """
2316    if name is None:
2317      name = "%s_get" % self._name
2318
2319    indices, dtypes = self._get_indices_and_dtypes(indices)
2320
2321    with ops.colocate_with(self._coloc_op):
2322      result = self._pop_fn(
2323          key,
2324          shared_name=self._name,
2325          indices=indices,
2326          dtypes=dtypes,
2327          name=name,
2328          capacity=self._capacity,
2329          memory_limit=self._memory_limit)
2330
2331    return key, self._get_return_value(result, indices)
2332
2333  def _popitem(self, indices=None, name=None):
2334    """If the staging area is ordered, the (key, value) with the smallest key will be returned.
2335
2336    Otherwise, a random (key, value) will be returned.
2337    If the staging area is empty when this operation executes,
2338    it will block until there is an element to dequeue.
2339
2340    Args:
2341        key: Key associated with the required data
2342        indices: Partial list of tensors to retrieve (optional).
2343                A list of integer or string indices.
2344                String indices are only valid if the Staging Area
2345                has names associated with it.
2346        name: A name for the operation (optional)
2347
2348    Returns:
2349        The created op
2350    """
2351    if name is None:
2352      name = "%s_get_nokey" % self._name
2353
2354    indices, dtypes = self._get_indices_and_dtypes(indices)
2355
2356    with ops.colocate_with(self._coloc_op):
2357      key, result = self._popitem_fn(
2358          shared_name=self._name,
2359          indices=indices,
2360          dtypes=dtypes,
2361          name=name,
2362          capacity=self._capacity,
2363          memory_limit=self._memory_limit)
2364
2365    # Separate keys and results out from
2366    # underlying namedtuple
2367    key = self._create_device_transfers(key)[0]
2368    result = self._get_return_value(result, indices)
2369
2370    return key, result
2371
2372  def size(self, name=None):
2373    """Returns the number of elements in the staging area.
2374
2375    Args:
2376        name: A name for the operation (optional)
2377
2378    Returns:
2379        The created op
2380    """
2381    if name is None:
2382      name = "%s_size" % self._name
2383
2384    return self._size_fn(
2385        shared_name=self._name,
2386        name=name,
2387        dtypes=self._dtypes,
2388        capacity=self._capacity,
2389        memory_limit=self._memory_limit)
2390
2391  def incomplete_size(self, name=None):
2392    """Returns the number of incomplete elements in the staging area.
2393
2394    Args:
2395        name: A name for the operation (optional)
2396
2397    Returns:
2398        The created op
2399    """
2400    if name is None:
2401      name = "%s_incomplete_size" % self._name
2402
2403    return self._incomplete_size_fn(
2404        shared_name=self._name,
2405        name=name,
2406        dtypes=self._dtypes,
2407        capacity=self._capacity,
2408        memory_limit=self._memory_limit)
2409
2410  def clear(self, name=None):
2411    """Clears the staging area.
2412
2413    Args:
2414        name: A name for the operation (optional)
2415
2416    Returns:
2417        The created op
2418    """
2419    if name is None:
2420      name = "%s_clear" % self._name
2421
2422    return self._clear_fn(
2423        shared_name=self._name,
2424        name=name,
2425        dtypes=self._dtypes,
2426        capacity=self._capacity,
2427        memory_limit=self._memory_limit)
2428
2429
2430class RecordInput(object):
2431  """RecordInput asynchronously reads and randomly yields TFRecords.
2432
2433  A RecordInput Op will continuously read a batch of records asynchronously
2434  into a buffer of some fixed capacity. It can also asynchronously yield
2435  random records from this buffer.
2436
2437  It will not start yielding until at least `buffer_size / 2` elements have been
2438  placed into the buffer so that sufficient randomization can take place.
2439
2440  The order the files are read will be shifted each epoch by `shift_amount` so
2441  that the data is presented in a different order every epoch.
2442  """
2443
2444  def __init__(self,
2445               file_pattern,
2446               batch_size=1,
2447               buffer_size=1,
2448               parallelism=1,
2449               shift_ratio=0,
2450               seed=0,
2451               name=None,
2452               batches=None,
2453               compression_type=None):
2454    """Constructs a RecordInput Op.
2455
2456    Args:
2457      file_pattern: File path to the dataset, possibly containing wildcards.
2458        All matching files will be iterated over each epoch.
2459      batch_size: How many records to return at a time.
2460      buffer_size: The maximum number of records the buffer will contain.
2461      parallelism: How many reader threads to use for reading from files.
2462      shift_ratio: What percentage of the total number files to move the start
2463        file forward by each epoch.
2464      seed: Specify the random number seed used by generator that randomizes
2465        records.
2466      name: Optional name for the operation.
2467      batches: None by default, creating a single batch op. Otherwise specifies
2468        how many batches to create, which are returned as a list when
2469        `get_yield_op()` is called. An example use case is to split processing
2470        between devices on one computer.
2471      compression_type: The type of compression for the file. Currently ZLIB and
2472        GZIP are supported. Defaults to none.
2473
2474    Raises:
2475      ValueError: If one of the arguments is invalid.
2476    """
2477    self._batch_size = batch_size
2478    if batches is not None:
2479      self._batch_size *= batches
2480    self._batches = batches
2481    self._file_pattern = file_pattern
2482    self._buffer_size = buffer_size
2483    self._parallelism = parallelism
2484    self._shift_ratio = shift_ratio
2485    self._seed = seed
2486    self._name = name
2487    self._compression_type = python_io.TFRecordCompressionType.NONE
2488    if compression_type is not None:
2489      self._compression_type = compression_type
2490
2491  def get_yield_op(self):
2492    """Adds a node that yields a group of records every time it is executed.
2493    If RecordInput `batches` parameter is not None, it yields a list of
2494    record batches with the specified `batch_size`.
2495    """
2496    compression_type = python_io.TFRecordOptions.get_compression_type_string(
2497        python_io.TFRecordOptions(self._compression_type))
2498    records = gen_data_flow_ops.record_input(
2499        file_pattern=self._file_pattern,
2500        file_buffer_size=self._buffer_size,
2501        file_parallelism=self._parallelism,
2502        file_shuffle_shift_ratio=self._shift_ratio,
2503        batch_size=self._batch_size,
2504        file_random_seed=self._seed,
2505        compression_type=compression_type,
2506        name=self._name)
2507    if self._batches is None:
2508      return records
2509    else:
2510      with ops.name_scope(self._name):
2511        batch_list = [[] for _ in six.moves.range(self._batches)]
2512        records = array_ops.split(records, self._batch_size, 0)
2513        for index, protobuf in enumerate(records):
2514          batch_index = index % self._batches
2515          batch_list[batch_index].append(array_ops.reshape(protobuf, []))
2516        return batch_list
2517