1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Iterators."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import threading
21import warnings
22
23from tensorflow.python.compat import compat
24from tensorflow.python.data.ops import optional_ops
25from tensorflow.python.data.util import nest
26from tensorflow.python.data.util import structure as structure_lib
27from tensorflow.python.eager import context
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import errors
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.ops import gen_dataset_ops
33from tensorflow.python.ops import resource_variable_ops
34from tensorflow.python.training.saver import BaseSaverBuilder
35from tensorflow.python.training.tracking import base as trackable
36from tensorflow.python.util.tf_export import tf_export
37
38
# NOTE(mrry): Calling `Iterator.get_next()` more than once is legitimate, for
# example to hand different elements to several devices within a single step.
# A common mistake, however, is to call it in every iteration of a training
# loop: each call adds new ops to the graph, and running each of those ops
# allocates resources (threads among them), so the program slows down over time
# and eventually exhausts its resources. `Iterator.get_next()` therefore warns
# once the number of calls made on a single iterator exceeds this threshold.
GET_NEXT_CALL_WARNING_THRESHOLD = 32

# Text of the warning emitted once the threshold above has been exceeded.
GET_NEXT_CALL_WARNING_MESSAGE = (
    "An unusually high number of `Iterator.get_next()` calls was detected. "
    "This often indicates that `Iterator.get_next()` is being called inside a "
    "training loop, which will cause gradual slowdown and eventual resource "
    "exhaustion. "
    "If this is the case, restructure your code to call "
    "`next_element = iterator.get_next()` once outside the loop, and use "
    "`next_element` as the input to some computation that is invoked inside "
    "the loop.")

# Graph collection key under which `Iterator` registers its iterator resource.
GLOBAL_ITERATORS = "iterators"
61
62
def _device_stack_is_empty():
  """Returns True when the default graph's device-function stack is empty.

  The iterator factories below use this to decide whether to pin the
  iterator resource to "/cpu:0" themselves, or to defer to an enclosing
  device scope set up by the caller.
  """
  # pylint: disable=protected-access
  active_device_functions = (
      ops.get_default_graph()._device_functions_outer_to_inner)
  # pylint: enable=protected-access
  return not active_device_functions
68
69
@tf_export(v1=["data.Iterator"])
class Iterator(trackable.Trackable):
  """Represents the state of iterating through a `Dataset`."""

  def __init__(self, iterator_resource, initializer, output_types,
               output_shapes, output_classes):
    """Creates a new iterator from the given iterator resource.

    Note: Most users will not call this initializer directly, and will
    instead use `Dataset.make_initializable_iterator()` or
    `Dataset.make_one_shot_iterator()`.

    Args:
      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
        iterator.
      initializer: A `tf.Operation` that should be run to initialize this
        iterator, or `None` if the iterator has no initializer.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this iterator.
      output_shapes: A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this iterator.
      output_classes: A nested structure of Python `type` objects corresponding
        to each component of an element of this iterator.

    Raises:
      ValueError: If any of `output_types`, `output_shapes` or
        `output_classes` is `None`.
    """
    self._iterator_resource = iterator_resource
    self._initializer = initializer

    if (output_types is None or output_shapes is None
        or output_classes is None):
      raise ValueError("If `structure` is not specified, all of "
                       "`output_types`, `output_shapes`, and `output_classes`"
                       " must be specified.")
    self._structure = structure_lib.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    # Created once so that `string_handle()` without a name can reuse it.
    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
        self._iterator_resource)
    # Incremented by `get_next()` to detect the call-inside-a-loop pitfall.
    self._get_next_call_count = 0
    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)

  @staticmethod
  def _normalize_legacy_structure(output_types, output_shapes, output_classes):
    """Fills in defaults for a legacy element structure and validates it.

    Shared by `from_structure()` and `from_string_handle()`.

    Args:
      output_types: A nested structure of objects convertible to `tf.DType`.
      output_shapes: A nested structure of objects convertible to
        `tf.TensorShape`, or `None` to leave every component's shape
        unconstrained.
      output_classes: A nested structure of Python `type` objects, or `None`
        to treat every component as a `tf.Tensor`.

    Returns:
      A tuple `(output_types, output_shapes, output_classes, structure)`. The
      first three items are the normalized legacy structures. `structure` is
      the equivalent `Structure` object.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure_lib.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    return output_types, output_shapes, output_classes, output_structure

  @staticmethod
  def _create_on_default_cpu(create_fn):
    """Calls `create_fn()`, placing it on "/cpu:0" unless a device is in scope.

    If the caller has already entered a device scope (the default graph's
    device-function stack is non-empty), that placement is respected.

    Args:
      create_fn: A zero-argument callable that creates and returns an op.

    Returns:
      Whatever `create_fn()` returns.
    """
    if _device_stack_is_empty():
      with ops.device("/cpu:0"):
        return create_fn()
    return create_fn()

  @staticmethod
  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    output_types, output_shapes, output_classes, output_structure = (
        Iterator._normalize_legacy_structure(
            output_types, output_shapes, output_classes))
    if shared_name is None:
      shared_name = ""
    # pylint: disable=protected-access
    if compat.forward_compatible(2018, 8, 3):
      iterator_resource = Iterator._create_on_default_cpu(
          lambda: gen_dataset_ops.iterator_v2(
              container="",
              shared_name=shared_name,
              output_types=output_structure._flat_types,
              output_shapes=output_structure._flat_shapes))
    else:
      iterator_resource = gen_dataset_ops.iterator(
          container="",
          shared_name=shared_name,
          output_types=output_structure._flat_types,
          output_shapes=output_structure._flat_shapes)
    # pylint: enable=protected-access
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @staticmethod
  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a `tf.Session.run` call.
    In that case, `string_handle` would be a `tf.placeholder`, and you would
    feed it with the value of `tf.data.Iterator.string_handle` in each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    output_types, output_shapes, output_classes, output_structure = (
        Iterator._normalize_legacy_structure(
            output_types, output_shapes, output_classes))
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    # pylint: disable=protected-access
    if compat.forward_compatible(2018, 8, 3):
      iterator_resource = Iterator._create_on_default_cpu(
          lambda: gen_dataset_ops.iterator_from_string_handle_v2(
              string_handle,
              output_types=output_structure._flat_types,
              output_shapes=output_structure._flat_shapes))
    else:
      iterator_resource = gen_dataset_ops.iterator_from_string_handle(
          string_handle,
          output_types=output_structure._flat_types,
          output_shapes=output_structure._flat_shapes)
    # pylint: enable=protected-access
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @property
  def initializer(self):
    """A `tf.Operation` that should be run to initialize this iterator.

    Returns:
      A `tf.Operation` that should be run to initialize this iterator

    Raises:
      ValueError: If this iterator initializes itself automatically.
    """
    if self._initializer is not None:
      return self._initializer
    else:
      # TODO(mrry): Consider whether one-shot iterators should have
      # initializers that simply reset their state to the beginning.
      raise ValueError("Iterator does not have an initializer.")

  def make_initializer(self, dataset, name=None):
    """Returns a `tf.Operation` that initializes this iterator on `dataset`.

    Args:
      dataset: A `Dataset` with compatible structure to this iterator.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the given
      `dataset`.

    Raises:
      TypeError: If `dataset` and this iterator do not have a compatible
        element structure.
    """
    with ops.name_scope(name, "make_initializer") as name:
      # pylint: disable=protected-access
      # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due
      # to that creating a circular dependency.
      dataset_output_types = (
          dataset._element_structure._to_legacy_output_types())
      dataset_output_shapes = (
          dataset._element_structure._to_legacy_output_shapes())
      dataset_output_classes = (
          dataset._element_structure._to_legacy_output_classes())
      # pylint: enable=protected-access

      nest.assert_same_structure(self.output_types, dataset_output_types)
      nest.assert_same_structure(self.output_shapes, dataset_output_shapes)
      # Classes and dtypes must match exactly; shapes only need to be
      # compatible.
      for iterator_class, dataset_class in zip(
          nest.flatten(self.output_classes),
          nest.flatten(dataset_output_classes)):
        if iterator_class is not dataset_class:
          raise TypeError(
              "Expected output classes %r but got dataset with output class %r."
              % (self.output_classes, dataset_output_classes))
      for iterator_dtype, dataset_dtype in zip(
          nest.flatten(self.output_types), nest.flatten(dataset_output_types)):
        if iterator_dtype != dataset_dtype:
          raise TypeError(
              "Expected output types %r but got dataset with output types %r." %
              (self.output_types, dataset_output_types))
      for iterator_shape, dataset_shape in zip(
          nest.flatten(self.output_shapes), nest.flatten(
              dataset_output_shapes)):
        if not iterator_shape.is_compatible_with(dataset_shape):
          raise TypeError("Expected output shapes compatible with %r but got "
                          "dataset with output shapes %r." %
                          (self.output_shapes, dataset_output_shapes))
    # Create the initializer op colocated with the iterator resource it fills.
    with ops.colocate_with(self._iterator_resource):
      return gen_dataset_ops.make_iterator(
          dataset._variant_tensor, self._iterator_resource, name=name)  # pylint: disable=protected-access

  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s representing the next element.

    In graph mode, you should typically call this method *once* and use its
    result as the input to another computation. A typical loop will then call
    `tf.Session.run` on the result of that computation. The loop will terminate
    when the `Iterator.get_next()` operation raises
    `tf.errors.OutOfRangeError`. The following skeleton shows how to use
    this method when building a training loop:

    ```python
    dataset = ...  # A `tf.data.Dataset` object.
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Build a TensorFlow graph that does something with each element.
    loss = model_function(next_element)
    optimizer = ...  # A `tf.train.Optimizer` object.
    train_op = optimizer.minimize(loss)

    with tf.Session() as sess:
      try:
        while True:
          sess.run(train_op)
      except tf.errors.OutOfRangeError:
        pass
    ```

    NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g.
    when you are distributing different elements to multiple devices in a single
    step. However, a common pitfall arises when users call `Iterator.get_next()`
    in each iteration of their training loop. `Iterator.get_next()` adds ops to
    the graph, and executing each op allocates resources (including threads); as
    a consequence, invoking it in every iteration of a training loop causes
    slowdown and eventual resource exhaustion. To guard against this outcome, we
    log a warning when the number of uses crosses a fixed threshold of
    suspiciousness.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    # pylint: disable=protected-access
    flat_ret = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=self._structure._flat_types,
        output_shapes=self._structure._flat_shapes, name=name)
    return self._structure._from_tensor_list(flat_ret)

  def string_handle(self, name=None):
    """Returns a string-valued `tf.Tensor` that represents this iterator.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.string`.
    """
    if name is None:
      # Reuse the handle tensor built in `__init__` rather than adding a new op.
      return self._string_handle
    else:
      return gen_dataset_ops.iterator_to_string_handle(
          self._iterator_resource, name=name)

  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this iterator.
    """
    return self._structure._to_legacy_output_classes()  # pylint: disable=protected-access

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this iterator.
    """
    return self._structure._to_legacy_output_shapes()  # pylint: disable=protected-access

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this iterator.
    """
    return self._structure._to_legacy_output_types()  # pylint: disable=protected-access

  @property
  def _element_structure(self):
    """The structure of an element of this iterator.

    Returns:
      A `Structure` object representing the structure of the components of this
        iterator.
    """
    return self._structure

  def _gather_saveables_for_checkpoint(self):

    def _saveable_factory(name):
      return _IteratorSaveable(self._iterator_resource, name)

    return {"ITERATOR": _saveable_factory}
491
492
_uid_counter = 0
_uid_lock = threading.Lock()


def _generate_shared_name(prefix):
  """Returns `prefix` followed by a unique integer suffix.

  The suffix comes from the module-level `_uid_counter`. The counter is read
  and advanced while holding `_uid_lock`, so two callers never receive the
  same suffix.
  """
  global _uid_counter
  with _uid_lock:
    unique_id = _uid_counter
    _uid_counter = unique_id + 1
  return "%s%d" % (prefix, unique_id)
503
504
class EagerIterator(trackable.Trackable):
  """An iterator producing tf.Tensor objects from a tf.data.Dataset."""

  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in EagerIterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

    if not context.executing_eagerly():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.data.Dataset.make_initializable_iterator or "
          "tf.data.Dataset.make_one_shot_iterator for graph construction".
          format(type(self)))
    # Remember the creating device; `_next_internal` places outputs there.
    self._device = context.context().device_name
    with ops.device("/cpu:0"):
      # pylint: disable=protected-access
      dataset = dataset._apply_options()
      ds_variant = dataset._variant_tensor
      self._structure = dataset._element_structure
      self._flat_output_types = self._structure._flat_types
      self._flat_output_shapes = self._structure._flat_shapes
      with ops.colocate_with(ds_variant):
        self._iterator_resource = gen_dataset_ops.anonymous_iterator(
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
        gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
        # Delete the resource when this object is deleted
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._iterator_resource, handle_device=self._device)
      # pylint: enable=protected-access

  def __iter__(self):
    return self

  def __next__(self):  # For Python 3 compatibility
    return self.next()

  def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
    """
    # This runs in sync mode as iterators use an error status to communicate
    # that there is no more data to iterate over.
    # TODO(b/77291417): Fix
    with context.execution_mode(context.SYNC):
      with ops.device(self._device):
        # TODO(ashankar): Consider removing this ops.device() contextmanager
        # and instead mimic ops placement in graphs: Operations on resource
        # handles execute on the same device as where the resource is placed.
        # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
        # because in eager mode this code will run synchronously on the calling
        # thread. Therefore we do not need to make a defensive context switch
        # to a background thread, and can achieve a small constant performance
        # boost by invoking the iterator synchronously.
        ret = gen_dataset_ops.iterator_get_next_sync(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)

      return self._structure._from_compatible_tensor_list(ret)  # pylint: disable=protected-access

  def next(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Raises:
      StopIteration: If the end of the dataset has been reached.
    """
    try:
      return self._next_internal()
    except errors.OutOfRangeError:
      # Translate the TensorFlow end-of-data error into the Python iterator
      # protocol so that `for` loops terminate normally.
      raise StopIteration

  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this iterator.
    """
    return self._structure._to_legacy_output_classes()  # pylint: disable=protected-access

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this iterator.
    """
    return self._structure._to_legacy_output_shapes()  # pylint: disable=protected-access

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this iterator.
    """
    return self._structure._to_legacy_output_types()  # pylint: disable=protected-access

  @property
  def _element_structure(self):
    """The structure of an element of this iterator.

    Returns:
      A `Structure` object representing the structure of the components of this
        iterator.
    """
    return self._structure

  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Args:
      name: (Optional.) A name for the created operation. Currently unused.

    Returns:
      A nested structure of `tf.Tensor` objects.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
    """
    del name
    return self._next_internal()

  def _gather_saveables_for_checkpoint(self):

    def _saveable_factory(name):
      return _IteratorSaveable(self._iterator_resource, name)

    return {"ITERATOR": _saveable_factory}
652
653
# TODO(b/71645805): Expose trackable stateful objects from dataset
# attributes (potential).
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """SaveableObject that checkpoints the state of an iterator resource.

  Saving captures the tensor produced by `serialize_iterator`, described by a
  single `SaveSpec` built from `name + "_STATE"`. Restoring feeds the saved
  tensor back through `deserialize_iterator`.
  """

  def __init__(self, iterator_resource, name):
    state_tensor = gen_dataset_ops.serialize_iterator(iterator_resource)
    state_spec = BaseSaverBuilder.SaveSpec(state_tensor, "", name + "_STATE")
    super(_IteratorSaveable, self).__init__(
        iterator_resource, [state_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    # Only the single serialized-state tensor is consumed; shapes are unused.
    del restored_shapes
    serialized_state = restored_tensors[0]
    # Colocate the restore op with `self.op`, the iterator resource that
    # `deserialize_iterator` writes into.
    with ops.colocate_with(self.op):
      return gen_dataset_ops.deserialize_iterator(self.op, serialized_state)
669
670
def get_next_as_optional(iterator):
  """Wraps the next element of `iterator` in an `Optional`.

  When `iterator` has already reached the end of its sequence, the returned
  `Optional` holds no value.

  Args:
    iterator: A `tf.data.Iterator` object.

  Returns:
    An `Optional` holding the iterator's next value if there is one, and no
    value otherwise.
  """
  # pylint: disable=protected-access
  element_structure = iterator._element_structure
  optional_variant = gen_dataset_ops.iterator_get_next_as_optional(
      iterator._iterator_resource,
      output_types=element_structure._flat_types,
      output_shapes=element_structure._flat_shapes)
  return optional_ops._OptionalImpl(optional_variant, element_structure)
691