1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Iterators."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
20import abc
21import threading
22import warnings
24import six
26from tensorflow.python.data.experimental.ops import distribute_options
27from tensorflow.python.data.ops import optional_ops
28from tensorflow.python.data.util import nest
29from tensorflow.python.data.util import structure
30from tensorflow.python.eager import context
31from tensorflow.python.framework import composite_tensor
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import errors
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import tensor_shape
36from tensorflow.python.framework import tensor_spec
37from tensorflow.python.framework import type_spec
38from tensorflow.python.ops import gen_dataset_ops
39from tensorflow.python.training.saver import BaseSaverBuilder
40from tensorflow.python.training.tracking import base as trackable
41from tensorflow.python.util import deprecation
42from tensorflow.python.util.compat import collections_abc
43from tensorflow.python.util.tf_export import tf_export
46# NOTE(mrry): It is legitimate to call `Iterator.get_next()` multiple
47# times, e.g. when you are distributing different elements to multiple
48# devices in a single step. However, a common pitfall arises when
49# users call `Iterator.get_next()` in each iteration of their training
50# loop. `Iterator.get_next()` adds ops to the graph, and executing
51# each op allocates resources (including threads); as a consequence,
52# invoking it in every iteration of a training loop causes slowdown
53# and eventual resource exhaustion. To guard against this outcome, we
54# log a warning when the number of uses crosses a threshold of suspicion.
58    "An unusually high number of `Iterator.get_next()` calls was detected. "
59    "This often indicates that `Iterator.get_next()` is being called inside "
60    "a training loop, which will cause gradual slowdown and eventual resource "
61    "exhaustion. If this is the case, restructure your code to call "
62    "`next_element = iterator.get_next()` once outside the loop, and use "
63    "`next_element` as the input to some computation that is invoked inside "
64    "the loop.")
# Name of the graph collection that holds all IteratorResources in the
# `Graph`. `Iterator.__init__` adds each iterator resource to this collection.
GLOBAL_ITERATORS = "iterators"
def _device_stack_is_empty():
  """Returns whether no device has been specified in the current context.

  When executing eagerly, this checks that the eager context's `device_name`
  is `None`. Otherwise it checks that the default graph's stack of device
  functions is empty.

  Returns:
    `True` if no device is set (eager) or the device-function stack of the
    default graph is empty (graph mode); `False` otherwise.
  """
  if not context.executing_eagerly():
    graph = ops.get_default_graph()
    return not graph._device_functions_outer_to_inner  # pylint: disable=protected-access
  return context.context().device_name is None
80class Iterator(trackable.Trackable):
81  """Represents the state of iterating through a `Dataset`."""
83  def __init__(self, iterator_resource, initializer, output_types,
84               output_shapes, output_classes):
85    """Creates a new iterator from the given iterator resource.
87    Note: Most users will not call this initializer directly, and will
88    instead use `Dataset.make_initializable_iterator()` or
89    `Dataset.make_one_shot_iterator()`.
91    Args:
92      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
93        iterator.
94      initializer: A `tf.Operation` that should be run to initialize this
95        iterator.
96      output_types: A nested structure of `tf.DType` objects corresponding to
97        each component of an element of this iterator.
98      output_shapes: A nested structure of `tf.TensorShape` objects
99        corresponding to each component of an element of this iterator.
100      output_classes: A nested structure of Python `type` objects corresponding
101        to each component of an element of this iterator.
102    """
103    self._iterator_resource = iterator_resource
104    self._initializer = initializer
106    if (output_types is None or output_shapes is None
107        or output_classes is None):
108      raise ValueError("If `structure` is not specified, all of "
109                       "`output_types`, `output_shapes`, and `output_classes`"
110                       " must be specified.")
111    self._element_spec = structure.convert_legacy_structure(
112        output_types, output_shapes, output_classes)
113    self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
114        self._element_spec)
115    self._flat_tensor_types = structure.get_flat_tensor_types(
116        self._element_spec)
118    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
119        self._iterator_resource)
120    self._get_next_call_count = 0
121    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)
123  @staticmethod
124  def from_structure(output_types,
125                     output_shapes=None,
126                     shared_name=None,
127                     output_classes=None):
128    """Creates a new, uninitialized `Iterator` with the given structure.
130    This iterator-constructing method can be used to create an iterator that
131    is reusable with many different datasets.
133    The returned iterator is not bound to a particular dataset, and it has
134    no `initializer`. To initialize the iterator, run the operation returned by
135    `Iterator.make_initializer(dataset)`.
137    The following is an example
139    ```python
140    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))
142    dataset_range = Dataset.range(10)
143    range_initializer = iterator.make_initializer(dataset_range)
145    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
146    evens_initializer = iterator.make_initializer(dataset_evens)
148    # Define a model based on the iterator; in this example, the model_fn
149    # is expected to take scalar tf.int64 Tensors as input (see
150    # the definition of 'iterator' above).
151    prediction, loss = model_fn(iterator.get_next())
153    # Train for `num_epochs`, where for each epoch, we first iterate over
154    # dataset_range, and then iterate over dataset_evens.
155    for _ in range(num_epochs):
156      # Initialize the iterator to `dataset_range`
157      sess.run(range_initializer)
158      while True:
159        try:
160          pred, loss_val = sess.run([prediction, loss])
161        except tf.errors.OutOfRangeError:
162          break
164      # Initialize the iterator to `dataset_evens`
165      sess.run(evens_initializer)
166      while True:
167        try:
168          pred, loss_val = sess.run([prediction, loss])
169        except tf.errors.OutOfRangeError:
170          break
171    ```
173    Args:
174      output_types: A nested structure of `tf.DType` objects corresponding to
175        each component of an element of this dataset.
176      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
177        corresponding to each component of an element of this dataset. If
178        omitted, each component will have an unconstrainted shape.
179      shared_name: (Optional.) If non-empty, this iterator will be shared under
180        the given name across multiple sessions that share the same devices
181        (e.g. when using a remote server).
182      output_classes: (Optional.) A nested structure of Python `type` objects
183        corresponding to each component of an element of this iterator. If
184        omitted, each component is assumed to be of type `tf.Tensor`.
186    Returns:
187      An `Iterator`.
189    Raises:
190      TypeError: If the structures of `output_shapes` and `output_types` are
191        not the same.
192    """
193    output_types = nest.map_structure(dtypes.as_dtype, output_types)
194    if output_shapes is None:
195      output_shapes = nest.map_structure(
196          lambda _: tensor_shape.TensorShape(None), output_types)
197    else:
198      output_shapes = nest.map_structure_up_to(output_types,
199                                               tensor_shape.as_shape,
200                                               output_shapes)
201    if output_classes is None:
202      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
203    nest.assert_same_structure(output_types, output_shapes)
204    output_structure = structure.convert_legacy_structure(
205        output_types, output_shapes, output_classes)
206    if shared_name is None:
207      shared_name = ""
208    iterator_resource = gen_dataset_ops.iterator_v2(
209        container="",
210        shared_name=shared_name,
211        output_types=structure.get_flat_tensor_types(output_structure),
212        output_shapes=structure.get_flat_tensor_shapes(
213            output_structure))
214    return Iterator(iterator_resource, None, output_types, output_shapes,
215                    output_classes)
217  @staticmethod
218  def from_string_handle(string_handle,
219                         output_types,
220                         output_shapes=None,
221                         output_classes=None):
222    """Creates a new, uninitialized `Iterator` based on the given handle.
224    This method allows you to define a "feedable" iterator where you can choose
225    between concrete iterators by feeding a value in a `tf.Session.run` call.
226    In that case, `string_handle` would be a `tf.compat.v1.placeholder`, and you
227    would
228    feed it with the value of `tf.data.Iterator.string_handle` in each step.
230    For example, if you had two iterators that marked the current position in
231    a training dataset and a test dataset, you could choose which to use in
232    each step as follows:
234    ```python
235    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
236    train_iterator_handle = sess.run(train_iterator.string_handle())
238    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
239    test_iterator_handle = sess.run(test_iterator.string_handle())
241    handle = tf.compat.v1.placeholder(tf.string, shape=[])
242    iterator = tf.data.Iterator.from_string_handle(
243        handle, train_iterator.output_types)
245    next_element = iterator.get_next()
246    loss = f(next_element)
248    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
249    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
250    ```
252    Args:
253      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to
254        a handle produced by the `Iterator.string_handle()` method.
255      output_types: A nested structure of `tf.DType` objects corresponding to
256        each component of an element of this dataset.
257      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
258        corresponding to each component of an element of this dataset. If
259        omitted, each component will have an unconstrainted shape.
260      output_classes: (Optional.) A nested structure of Python `type` objects
261        corresponding to each component of an element of this iterator. If
262        omitted, each component is assumed to be of type `tf.Tensor`.
264    Returns:
265      An `Iterator`.
266    """
267    output_types = nest.map_structure(dtypes.as_dtype, output_types)
268    if output_shapes is None:
269      output_shapes = nest.map_structure(
270          lambda _: tensor_shape.TensorShape(None), output_types)
271    else:
272      output_shapes = nest.map_structure_up_to(output_types,
273                                               tensor_shape.as_shape,
274                                               output_shapes)
275    if output_classes is None:
276      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
277    nest.assert_same_structure(output_types, output_shapes)
278    output_structure = structure.convert_legacy_structure(
279        output_types, output_shapes, output_classes)
280    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
281    iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
282        string_handle,
283        output_types=structure.get_flat_tensor_types(output_structure),
284        output_shapes=structure.get_flat_tensor_shapes(output_structure))
285    return Iterator(iterator_resource, None, output_types, output_shapes,
286                    output_classes)
288  @property
289  def initializer(self):
290    """A `tf.Operation` that should be run to initialize this iterator.
292    Returns:
293      A `tf.Operation` that should be run to initialize this iterator
295    Raises:
296      ValueError: If this iterator initializes itself automatically.
297    """
298    if self._initializer is not None:
299      return self._initializer
300    else:
301      # TODO(mrry): Consider whether one-shot iterators should have
302      # initializers that simply reset their state to the beginning.
303      raise ValueError("Iterator does not have an initializer.")
305  def make_initializer(self, dataset, name=None):
306    """Returns a `tf.Operation` that initializes this iterator on `dataset`.
308    Args:
309      dataset: A `Dataset` with compatible structure to this iterator.
310      name: (Optional.) A name for the created operation.
312    Returns:
313      A `tf.Operation` that can be run to initialize this iterator on the given
314      `dataset`.
316    Raises:
317      TypeError: If `dataset` and this iterator do not have a compatible
318        element structure.
319    """
320    with ops.name_scope(name, "make_initializer") as name:
321      # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due
322      # to that creating a circular dependency.
323      # pylint: disable=protected-access
324      dataset_output_types = nest.map_structure(
325          lambda component_spec: component_spec._to_legacy_output_types(),
326          dataset.element_spec)
327      dataset_output_shapes = nest.map_structure(
328          lambda component_spec: component_spec._to_legacy_output_shapes(),
329          dataset.element_spec)
330      dataset_output_classes = nest.map_structure(
331          lambda component_spec: component_spec._to_legacy_output_classes(),
332          dataset.element_spec)
333      # pylint: enable=protected-access
335      nest.assert_same_structure(self.output_types, dataset_output_types)
336      nest.assert_same_structure(self.output_shapes, dataset_output_shapes)
337      for iterator_class, dataset_class in zip(
338          nest.flatten(self.output_classes),
339          nest.flatten(dataset_output_classes)):
340        if iterator_class is not dataset_class:
341          raise TypeError(
342              "Expected output classes %r but got dataset with output class %r."
343              % (self.output_classes, dataset_output_classes))
344      for iterator_dtype, dataset_dtype in zip(
345          nest.flatten(self.output_types), nest.flatten(dataset_output_types)):
346        if iterator_dtype != dataset_dtype:
347          raise TypeError(
348              "Expected output types %r but got dataset with output types %r." %
349              (self.output_types, dataset_output_types))
350      for iterator_shape, dataset_shape in zip(
351          nest.flatten(self.output_shapes), nest.flatten(
352              dataset_output_shapes)):
353        if not iterator_shape.is_compatible_with(dataset_shape):
354          raise TypeError("Expected output shapes compatible with %r but got "
355                          "dataset with output shapes %r." %
356                          (self.output_shapes, dataset_output_shapes))
358    # TODO(b/169442955): Investigate the need for this colocation constraint.
359    with ops.colocate_with(self._iterator_resource):
360      # pylint: disable=protected-access
361      return gen_dataset_ops.make_iterator(
362          dataset._variant_tensor, self._iterator_resource, name=name)
364  def get_next(self, name=None):
365    """Returns a nested structure of `tf.Tensor`s representing the next element.
367    In graph mode, you should typically call this method *once* and use its
368    result as the input to another computation. A typical loop will then call
369    `tf.Session.run` on the result of that computation. The loop will terminate
370    when the `Iterator.get_next()` operation raises
371    `tf.errors.OutOfRangeError`. The following skeleton shows how to use
372    this method when building a training loop:
374    ```python
375    dataset = ...  # A `tf.data.Dataset` object.
376    iterator = dataset.make_initializable_iterator()
377    next_element = iterator.get_next()
379    # Build a TensorFlow graph that does something with each element.
380    loss = model_function(next_element)
381    optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
382    train_op = optimizer.minimize(loss)
384    with tf.compat.v1.Session() as sess:
385      try:
386        while True:
387          sess.run(train_op)
388      except tf.errors.OutOfRangeError:
389        pass
390    ```
392    NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g.
393    when you are distributing different elements to multiple devices in a single
394    step. However, a common pitfall arises when users call `Iterator.get_next()`
395    in each iteration of their training loop. `Iterator.get_next()` adds ops to
396    the graph, and executing each op allocates resources (including threads); as
397    a consequence, invoking it in every iteration of a training loop causes
398    slowdown and eventual resource exhaustion. To guard against this outcome, we
399    log a warning when the number of uses crosses a fixed threshold of
400    suspiciousness.
402    Args:
403      name: (Optional.) A name for the created operation.
405    Returns:
406      A nested structure of `tf.Tensor` objects.
407    """
408    self._get_next_call_count += 1
409    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
410      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
412    # TODO(b/169442955): Investigate the need for this colocation constraint.
413    with ops.colocate_with(self._iterator_resource):
414      # pylint: disable=protected-access
415      flat_ret = gen_dataset_ops.iterator_get_next(
416          self._iterator_resource,
417          output_types=self._flat_tensor_types,
418          output_shapes=self._flat_tensor_shapes,
419          name=name)
420      return structure.from_tensor_list(self._element_spec, flat_ret)
422  def get_next_as_optional(self):
423    # TODO(b/169442955): Investigate the need for this colocation constraint.
424    with ops.colocate_with(self._iterator_resource):
425      # pylint: disable=protected-access
426      return optional_ops._OptionalImpl(
427          gen_dataset_ops.iterator_get_next_as_optional(
428              self._iterator_resource,
429              output_types=structure.get_flat_tensor_types(self.element_spec),
430              output_shapes=structure.get_flat_tensor_shapes(
431                  self.element_spec)), self.element_spec)
433  def string_handle(self, name=None):
434    """Returns a string-valued `tf.Tensor` that represents this iterator.
436    Args:
437      name: (Optional.) A name for the created operation.
439    Returns:
440      A scalar `tf.Tensor` of type `tf.string`.
441    """
442    if name is None:
443      return self._string_handle
444    else:
445      return gen_dataset_ops.iterator_to_string_handle(
446          self._iterator_resource, name=name)
448  @property
449  @deprecation.deprecated(
450      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
451  def output_classes(self):
452    """Returns the class of each component of an element of this iterator.
454    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.
456    Returns:
457      A nested structure of Python `type` objects corresponding to each
458      component of an element of this dataset.
459    """
460    return nest.map_structure(
461        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
462        self._element_spec)
464  @property
465  @deprecation.deprecated(
466      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
467  def output_shapes(self):
468    """Returns the shape of each component of an element of this iterator.
470    Returns:
471      A nested structure of `tf.TensorShape` objects corresponding to each
472      component of an element of this dataset.
473    """
474    return nest.map_structure(
475        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
476        self._element_spec)
478  @property
479  @deprecation.deprecated(
480      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
481  def output_types(self):
482    """Returns the type of each component of an element of this iterator.
484    Returns:
485      A nested structure of `tf.DType` objects corresponding to each component
486      of an element of this dataset.
487    """
488    return nest.map_structure(
489        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
490        self._element_spec)
  @property
  def element_spec(self):
    """The type specification of an element of this iterator.

    Returns:
      The element spec that was built in `__init__` from the legacy
      `output_types`, `output_shapes` and `output_classes` arguments.
    """
    return self._element_spec
496  def _gather_saveables_for_checkpoint(self):
498    def _saveable_factory(name):
499      return _IteratorSaveable(self._iterator_resource, name)
501    return {"ITERATOR": _saveable_factory}
# Module-level counter handing out the numeric suffixes used by
# `_generate_shared_name`; every access is guarded by `_uid_lock`.
_uid_counter = 0
_uid_lock = threading.Lock()


def _generate_shared_name(prefix):
  """Returns `prefix` followed by the next value of a module-level counter.

  The counter is read and incremented while holding `_uid_lock`, so
  concurrent callers each receive a distinct suffix.

  Args:
    prefix: The value to put in front of the numeric suffix.

  Returns:
    A string of the form `"<prefix><uid>"`.
  """
  global _uid_counter
  with _uid_lock:
    uid, _uid_counter = _uid_counter, _uid_counter + 1
  return "{0}{1}".format(prefix, uid)
class IteratorResourceDeleter(object):
  """An object which cleans up an iterator resource handle.

  An alternative to defining a __del__ method on an object. Even if the parent
  object is part of a reference cycle, the cycle will be collectable.
  """

  __slots__ = ["_deleter", "_handle", "_eager_mode"]

  def __init__(self, handle, deleter):
    """Records the resource to delete and the mode it was created in.

    Args:
      handle: The iterator resource handle passed to the delete op.
      deleter: The deleter passed to the delete op alongside `handle`.
    """
    self._deleter = deleter
    self._handle = handle
    # Remember whether we were created eagerly so `__del__` can match it.
    self._eager_mode = context.executing_eagerly()

  def __del__(self):
    """Runs the `delete_iterator` op when this object is destroyed."""
    # Make sure the resource is deleted in the same mode as it was created in.
    mode_scope = context.eager_mode if self._eager_mode else context.graph_mode
    with mode_scope():
      gen_dataset_ops.delete_iterator(
          handle=self._handle, deleter=self._deleter)
542@tf_export("data.Iterator", v1=[])
class IteratorBase(collections_abc.Iterator, trackable.Trackable,
                   composite_tensor.CompositeTensor):
  """Represents an iterator of a `tf.data.Dataset`.

  `tf.data.Iterator` is the primary mechanism for enumerating elements of a
  `tf.data.Dataset`. It supports the Python Iterator protocol, which means
  it can be iterated over using a for-loop:

  >>> dataset = tf.data.Dataset.range(2)
  >>> for element in dataset:
  ...   print(element)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  or by fetching individual elements explicitly via `get_next()`:

  >>> dataset = tf.data.Dataset.range(2)
  >>> iterator = iter(dataset)
  >>> print(iterator.get_next())
  tf.Tensor(0, shape=(), dtype=int64)
  >>> print(iterator.get_next())
  tf.Tensor(1, shape=(), dtype=int64)

  In addition, non-raising iteration is supported via `get_next_as_optional()`,
  which returns the next element (if available) wrapped in a
  `tf.experimental.Optional`.

  >>> dataset = tf.data.Dataset.from_tensors(42)
  >>> iterator = iter(dataset)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)

  This class only declares the interface: `element_spec`, `get_next()` and
  `get_next_as_optional()` are declared abstract here and are expected to be
  implemented by subclasses.
  """

  @abc.abstractproperty
  def element_spec(self):
    """The type specification of an element of this iterator.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> iterator.element_spec
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    Returns:
      A nested structure of `tf.TypeSpec` objects matching the structure of an
      element of this iterator, specifying the type of individual components.
    """
    raise NotImplementedError("Iterator.element_spec")

  @abc.abstractmethod
  def get_next(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> print(iterator.get_next())
    tf.Tensor(42, shape=(), dtype=int32)

    Returns:
      A nested structure of `tf.Tensor` objects.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
    """
    raise NotImplementedError("Iterator.get_next()")

  @abc.abstractmethod
  def get_next_as_optional(self):
    """Returns a `tf.experimental.Optional` which contains the next element.

    If the iterator has reached the end of the sequence, the returned
    `tf.experimental.Optional` will have no value.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Returns:
      A `tf.experimental.Optional` object representing the next element.
    """
    raise NotImplementedError("Iterator.get_next_as_optional()")
637class OwnedIterator(IteratorBase):
638  """An iterator producing tf.Tensor objects from a tf.data.Dataset.
640  The iterator resource  created through `OwnedIterator` is owned by the Python
641  object and the life time of the underlying resource is tied to the life time
642  of the `OwnedIterator` object. This makes `OwnedIterator` appropriate for use
643  in eager mode and inside of tf.functions.
644  """
646  def __init__(self, dataset=None, components=None, element_spec=None):
647    """Creates a new iterator from the given dataset.
649    If `dataset` is not specified, the iterator will be created from the given
650    tensor components and element structure. In particular, the alternative for
651    constructing the iterator is used when the iterator is reconstructed from
652    it `CompositeTensor` representation.
654    Args:
655      dataset: A `tf.data.Dataset` object.
656      components: Tensor components to construct the iterator from.
657      element_spec: A nested structure of `TypeSpec` objects that
658        represents the type specification of elements of the iterator.
660    Raises:
661      ValueError: If `dataset` is not provided and either `components` or
662        `element_spec` is not provided. Or `dataset` is provided and either
663        `components` and `element_spec` is provided.
664    """
665    super(OwnedIterator, self).__init__()
666    error_message = ("Either `dataset` or both `components` and "
667                     "`element_spec` need to be provided.")
669    if dataset is None:
670      if (components is None or element_spec is None):
671        raise ValueError(error_message)
672      # pylint: disable=protected-access
673      self._element_spec = element_spec
674      self._flat_output_types = structure.get_flat_tensor_types(
675          self._element_spec)
676      self._flat_output_shapes = structure.get_flat_tensor_shapes(
677          self._element_spec)
678      self._iterator_resource, self._deleter = components
679    else:
680      if (components is not None or element_spec is not None):
681        raise ValueError(error_message)
682      self._create_iterator(dataset)
684  def _create_iterator(self, dataset):
685    # pylint: disable=protected-access
686    dataset = dataset._apply_options()
688    # Store dataset reference to ensure that dataset is alive when this iterator
689    # is being used. For example, `tf.data.Dataset.from_generator` registers
690    # a few py_funcs that are needed in `self._next_internal`.  If the dataset
691    # is deleted, this iterator crashes on `self.__next__(...)` call.
692    self._dataset = dataset
694    ds_variant = dataset._variant_tensor
695    self._element_spec = dataset.element_spec
696    self._flat_output_types = structure.get_flat_tensor_types(
697        self._element_spec)
698    self._flat_output_shapes = structure.get_flat_tensor_shapes(
699        self._element_spec)
700    with ops.colocate_with(ds_variant):
701      self._iterator_resource, self._deleter = (
702          gen_dataset_ops.anonymous_iterator_v2(
703              output_types=self._flat_output_types,
704              output_shapes=self._flat_output_shapes))
705      gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
706      # Delete the resource when this object is deleted
707      self._resource_deleter = IteratorResourceDeleter(
708          handle=self._iterator_resource,
709          deleter=self._deleter)
  def __iter__(self):
    """Returns this object itself, as required by the iterator protocol."""
    return self
  def next(self):  # For Python 2 compatibility
    """Python 2 spelling of `__next__`; simply delegates to it."""
    return self.__next__()
717  def _next_internal(self):
718    if not context.executing_eagerly():
719      # TODO(b/169442955): Investigate the need for this colocation constraint.
720      with ops.colocate_with(self._iterator_resource):
721        ret = gen_dataset_ops.iterator_get_next(
722            self._iterator_resource,
723            output_types=self._flat_output_types,
724            output_shapes=self._flat_output_shapes)
725      return structure.from_compatible_tensor_list(self._element_spec, ret)
727    # TODO(b/77291417): This runs in sync mode as iterators use an error status
728    # to communicate that there is no more data to iterate over.
729    with context.execution_mode(context.SYNC):
730      ret = gen_dataset_ops.iterator_get_next(
731          self._iterator_resource,
732          output_types=self._flat_output_types,
733          output_shapes=self._flat_output_shapes)
735      try:
736        # Fast path for the case `self._structure` is not a nested structure.
737        return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access
738      except AttributeError:
739        return structure.from_compatible_tensor_list(self._element_spec, ret)
  @property
  def _type_spec(self):
    """An `IteratorSpec` built from this iterator's `element_spec`."""
    return IteratorSpec(self.element_spec)
745  def __next__(self):
746    try:
747      return self._next_internal()
748    except errors.OutOfRangeError:
749      raise StopIteration
751  @property
752  @deprecation.deprecated(
753      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
754  def output_classes(self):
755    """Returns the class of each component of an element of this iterator.
757    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.
759    Returns:
760      A nested structure of Python `type` objects corresponding to each
761      component of an element of this dataset.
762    """
763    return nest.map_structure(
764        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
765        self._element_spec)
767  @property
768  @deprecation.deprecated(
769      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
770  def output_shapes(self):
771    """Returns the shape of each component of an element of this iterator.
773    Returns:
774      A nested structure of `tf.TensorShape` objects corresponding to each
775      component of an element of this dataset.
776    """
777    return nest.map_structure(
778        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
779        self._element_spec)
781  @property
782  @deprecation.deprecated(
783      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
784  def output_types(self):
785    """Returns the type of each component of an element of this iterator.
787    Returns:
788      A nested structure of `tf.DType` objects corresponding to each component
789      of an element of this dataset.
790    """
791    return nest.map_structure(
792        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
793        self._element_spec)
  @property
  def element_spec(self):
    """The type specification of an element of this iterator.

    Returns:
      The value stored in `self._element_spec`.
    """
    return self._element_spec
  def get_next(self):
    """Returns the next element by delegating to `self._next_internal()`.

    Unlike `__next__`, this method does not translate
    `errors.OutOfRangeError` into `StopIteration`; any exception raised by
    `_next_internal()` propagates unchanged.
    """
    return self._next_internal()
802  def get_next_as_optional(self):
803    # TODO(b/169442955): Investigate the need for this colocation constraint.
804    with ops.colocate_with(self._iterator_resource):
805      # pylint: disable=protected-access
806      return optional_ops._OptionalImpl(
807          gen_dataset_ops.iterator_get_next_as_optional(
808              self._iterator_resource,
809              output_types=structure.get_flat_tensor_types(self.element_spec),
810              output_shapes=structure.get_flat_tensor_shapes(
811                  self.element_spec)), self.element_spec)
813  def _gather_saveables_for_checkpoint(self):
815    def _saveable_factory(name):
816      """Returns a SaveableObject for serialization/deserialization."""
817      policy = None
818      if self._dataset:
819        policy = self._dataset.options().experimental_external_state_policy
820      if policy:
821        return _IteratorSaveable(
822            self._iterator_resource,
823            name,
824            external_state_policy=policy)
825      else:
826        return _IteratorSaveable(self._iterator_resource, name)
828    return {"ITERATOR": _saveable_factory}
@tf_export("data.IteratorSpec", v1=[])
class IteratorSpec(type_spec.TypeSpec):
  """Type specification for `tf.data.Iterator`.

  For instance, `tf.data.IteratorSpec` can be used to define a tf.function that
  takes `tf.data.Iterator` as an input argument:

  >>> @tf.function(input_signature=[tf.data.IteratorSpec(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
  ... def square(iterator):
  ...   x = iterator.get_next()
  ...   return x * x
  >>> dataset = tf.data.Dataset.from_tensors(5)
  >>> iterator = iter(dataset)
  >>> print(square(iterator))
  tf.Tensor(25, shape=(), dtype=int32)

  Attributes:
    element_spec: A nested structure of `TypeSpec` objects that represents the
      type specification of the iterator elements.
  """

  __slots__ = ["_element_spec"]

  def __init__(self, element_spec):
    """Stores `element_spec` as the specification of the iterator elements."""
    self._element_spec = element_spec

  @property
  def value_type(self):
    """The Python type this spec describes: `OwnedIterator`."""
    return OwnedIterator

  def _serialize(self):
    """Returns a one-element tuple holding the element spec."""
    return (self._element_spec,)

  @property
  def _component_specs(self):
    """Returns the specs of the two tensor components of an iterator.

    Both are scalar `TensorSpec`s: one of dtype `resource` and one of dtype
    `variant`, in the same order as the values returned by `_to_components`.
    """
    resource_spec = tensor_spec.TensorSpec([], dtypes.resource)
    deleter_spec = tensor_spec.TensorSpec([], dtypes.variant)
    return (resource_spec, deleter_spec)

  def _to_components(self, value):
    """Returns `value`'s iterator resource and deleter as a tuple."""
    # pylint: disable=protected-access
    resource = value._iterator_resource
    deleter = value._deleter
    # pylint: enable=protected-access
    return (resource, deleter)

  def _from_components(self, components):
    """Builds an `OwnedIterator` from `components` and this element spec."""
    return OwnedIterator(
        dataset=None, components=components, element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    """Returns an `IteratorSpec` built from `value.element_spec`."""
    return IteratorSpec(value.element_spec)
886# TODO(b/71645805): Expose trackable stateful objects from dataset.
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """SaveableObject for saving/restoring iterator state."""

  def __init__(
      self,
      iterator_resource,
      name,
      external_state_policy=distribute_options.ExternalStatePolicy.FAIL):
    """Builds the save spec for the serialized iterator state.

    Args:
      iterator_resource: The iterator resource to serialize. It is also passed
        to the base class constructor as its first argument.
      name: Name of this saveable. The saved tensor is named
        `name + "_STATE"`.
      external_state_policy: Enum member whose `.value` is forwarded to the
        `serialize_iterator` op. Defaults to
        `distribute_options.ExternalStatePolicy.FAIL`.
    """
    serialized_state = gen_dataset_ops.serialize_iterator(
        iterator_resource, external_state_policy=external_state_policy.value)
    state_spec = BaseSaverBuilder.SaveSpec(
        serialized_state,
        "",
        name + "_STATE",
        device=iterator_resource.device)
    super(_IteratorSaveable, self).__init__(
        iterator_resource, [state_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Creates the op that restores the iterator from its serialized state.

    Args:
      restored_tensors: Sequence whose first item is the serialized iterator
        state to load.
      restored_shapes: Unused by this implementation.

    Returns:
      The result of `deserialize_iterator`, created under a colocation
      constraint with `self.op`.
    """
    with ops.colocate_with(self.op):
      serialized_state = restored_tensors[0]
      return gen_dataset_ops.deserialize_iterator(self.op, serialized_state)
912    None, "Use `tf.data.Iterator.get_next_as_optional()` instead.")
def get_next_as_optional(iterator):
  """Returns a `tf.experimental.Optional` with the next element of the iterator.

  If the iterator has reached the end of the sequence, the returned
  `tf.experimental.Optional` will have no value.

  This function simply forwards to `iterator.get_next_as_optional()`.

  Args:
    iterator: A `tf.data.Iterator`.

  Returns:
    A `tf.experimental.Optional` object which either contains the next element
    of the iterator (if it exists) or no value.
  """
  return iterator.get_next_as_optional()