1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Datasets."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import functools
22import sys
23import threading
24import warnings
25import weakref
26
27import numpy as np
28import six
29from six.moves import queue as Queue  # pylint: disable=redefined-builtin
30
31from tensorflow.core.framework import dataset_options_pb2
32from tensorflow.core.framework import graph_pb2
33from tensorflow.python import tf2
34from tensorflow.python.data.experimental.ops import distribute_options
35from tensorflow.python.data.experimental.ops import optimization_options
36from tensorflow.python.data.experimental.ops import stats_options
37from tensorflow.python.data.experimental.ops import threading_options
38from tensorflow.python.data.ops import iterator_ops
39from tensorflow.python.data.util import convert
40from tensorflow.python.data.util import nest
41from tensorflow.python.data.util import options as options_lib
42from tensorflow.python.data.util import random_seed
43from tensorflow.python.data.util import structure
44from tensorflow.python.data.util import traverse
45from tensorflow.python.eager import context
46from tensorflow.python.eager import def_function
47from tensorflow.python.eager import function as eager_function
48from tensorflow.python.framework import auto_control_deps
49from tensorflow.python.framework import auto_control_deps_utils as acd_utils
50from tensorflow.python.framework import composite_tensor
51from tensorflow.python.framework import constant_op
52from tensorflow.python.framework import dtypes
53from tensorflow.python.framework import function
54from tensorflow.python.framework import ops
55from tensorflow.python.framework import random_seed as core_random_seed
56from tensorflow.python.framework import smart_cond
57from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
58from tensorflow.python.framework import tensor_shape
59from tensorflow.python.framework import tensor_spec
60from tensorflow.python.framework import tensor_util
61from tensorflow.python.framework import type_spec
62from tensorflow.python.ops import array_ops
63from tensorflow.python.ops import control_flow_ops
64from tensorflow.python.ops import gen_dataset_ops
65from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
66from tensorflow.python.ops import gen_io_ops
67from tensorflow.python.ops import math_ops
68from tensorflow.python.ops import script_ops
69from tensorflow.python.ops import string_ops
70from tensorflow.python.ops.ragged import ragged_tensor
71from tensorflow.python.training.tracking import base as tracking_base
72from tensorflow.python.training.tracking import tracking
73from tensorflow.python.util import deprecation
74from tensorflow.python.util import function_utils
75from tensorflow.python.util import lazy_loader
76from tensorflow.python.util import nest as tf_nest
77from tensorflow.python.util.compat import collections_abc
78from tensorflow.python.util.tf_export import tf_export
79
80# Loaded lazily due to a circular dependency (roughly
81# tf.function->wrap_function->dataset->autograph->tf.function).
82# TODO(b/133251390): Use a regular import.
83wrap_function = lazy_loader.LazyLoader(
84    "wrap_function", globals(),
85    "tensorflow.python.eager.wrap_function")
86# TODO(mdan): Create a public API for this.
87autograph_ctx = lazy_loader.LazyLoader(
88    "autograph_ctx", globals(),
89    "tensorflow.python.autograph.core.ag_ctx")
90autograph = lazy_loader.LazyLoader(
91    "autograph", globals(),
92    "tensorflow.python.autograph.impl.api")
93
94ops.NotDifferentiable("ReduceDataset")
95
96# A constant that can be used to enable auto-tuning.
97AUTOTUNE = -1
98tf_export("data.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
99# TODO(b/168128531): Deprecate and remove this symbol.
100tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
101
102# Constants representing infinite and unknown cardinalities.
103INFINITE = -1
104UNKNOWN = -2
105tf_export("data.INFINITE_CARDINALITY").export_constant(__name__, "INFINITE")
106tf_export("data.UNKNOWN_CARDINALITY").export_constant(__name__, "UNKNOWN")
107
108
109@tf_export("data.Dataset", v1=[])
110@six.add_metaclass(abc.ABCMeta)
111class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
112                composite_tensor.CompositeTensor):
113  """Represents a potentially large set of elements.
114
115  The `tf.data.Dataset` API supports writing descriptive and efficient input
116  pipelines. `Dataset` usage follows a common pattern:
117
118  1. Create a source dataset from your input data.
119  2. Apply dataset transformations to preprocess the data.
120  3. Iterate over the dataset and process the elements.
121
122  Iteration happens in a streaming fashion, so the full dataset does not need to
123  fit into memory.
124
125  Source Datasets:
126
127  The simplest way to create a dataset is to create it from a python `list`:
128
129  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
130  >>> for element in dataset:
131  ...   print(element)
132  tf.Tensor(1, shape=(), dtype=int32)
133  tf.Tensor(2, shape=(), dtype=int32)
134  tf.Tensor(3, shape=(), dtype=int32)
135
136  To process lines from files, use `tf.data.TextLineDataset`:
137
138  >>> dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
139
140  To process records written in the `TFRecord` format, use `TFRecordDataset`:
141
142  >>> dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
143
144  To create a dataset of all files matching a pattern, use
145  `tf.data.Dataset.list_files`:
146
147  >>> dataset = tf.data.Dataset.list_files("/path/*.txt")  # doctest: +SKIP
148
149  See `tf.data.FixedLengthRecordDataset` and `tf.data.Dataset.from_generator`
150  for more ways to create datasets.
151
152  Transformations:
153
154  Once you have a dataset, you can apply transformations to prepare the data for
155  your model:
156
157  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
158  >>> dataset = dataset.map(lambda x: x*2)
159  >>> list(dataset.as_numpy_iterator())
160  [2, 4, 6]
161
162  Common Terms:
163
164  **Element**: A single output from calling `next()` on a dataset iterator.
165    Elements may be nested structures containing multiple components. For
166    example, the element `(1, (3, "apple"))` has one tuple nested in another
167    tuple. The components are `1`, `3`, and `"apple"`.
168
169  **Component**: The leaf in the nested structure of an element.
170
171  Supported types:
172
173  Elements can be nested structures of tuples, named tuples, and dictionaries.
174  Note that Python lists are *not* treated as nested structures of components.
175  Instead, lists are converted to tensors and treated as components. For
176  example, the element `(1, [1, 2, 3])` has only two components; the tensor `1`
177  and the tensor `[1, 2, 3]`. Element components can be of any type
178  representable by `tf.TypeSpec`, including `tf.Tensor`, `tf.data.Dataset`,
179  `tf.sparse.SparseTensor`, `tf.RaggedTensor`, and `tf.TensorArray`.
180
181  >>> a = 1 # Integer element
182  >>> b = 2.0 # Float element
183  >>> c = (1, 2) # Tuple element with 2 components
184  >>> d = {"a": (2, 2), "b": 3} # Dict element with 3 components
185  >>> Point = collections.namedtuple("Point", ["x", "y"]) # doctest: +SKIP
186  >>> e = Point(1, 2) # Named tuple # doctest: +SKIP
187  >>> f = tf.data.Dataset.range(10) # Dataset element
188
189  """
190
191  def __init__(self, variant_tensor):
192    """Creates a DatasetV2 object.
193
194    This is a difference between DatasetV1 and DatasetV2. DatasetV1 does not
195    take anything in its constructor whereas in the DatasetV2, we expect
196    subclasses to create a variant_tensor and pass it in to the super() call.
197
198    Args:
199      variant_tensor: A DT_VARIANT tensor that represents the dataset.
200    """
201    self._variant_tensor_attr = variant_tensor
202    weak_self = weakref.proxy(self)
203    self._variant_tracker = self._track_trackable(
204        _VariantTracker(
205            self._variant_tensor,
206            # _trace_variant_creation only works when executing eagerly, so we
207            # don't want to run it immediately. We also want the _VariantTracker
208            # to have a weak reference to the Dataset to avoid creating
209            # reference cycles and making work for the garbage collector.
210            lambda: weak_self._trace_variant_creation()()),  # pylint: disable=unnecessary-lambda,protected-access
211        name="_variant_tracker")
212    self._graph_attr = ops.get_default_graph()
213
214    # Initialize the options for this dataset and its inputs.
215    self._options_attr = Options()
216    for input_dataset in self._inputs():
217      input_options = input_dataset.options()
218      if input_options is not None:
219        self._options_attr = self._options_attr.merge(input_options)
220
221  @property
222  def _variant_tensor(self):
223    return self._variant_tensor_attr
224
225  @_variant_tensor.setter
226  def _variant_tensor(self, _):
227    raise ValueError("The _variant_tensor property is read-only")
228
229  @deprecation.deprecated_args(None, "Use external_state_policy instead",
230                               "allow_stateful")
231  def _as_serialized_graph(
232      self,
233      allow_stateful=None,
234      strip_device_assignment=None,
235      external_state_policy=distribute_options.ExternalStatePolicy.WARN):
236    """Produces serialized graph representation of the dataset.
237
238    Args:
239      allow_stateful: If true, we allow stateful ops to be present in the graph
240        def. In that case, the state in these ops would be thrown away.
241      strip_device_assignment: If true, non-local (i.e. job and task) device
242        assignment is stripped from ops in the serialized graph.
243      external_state_policy: The ExternalStatePolicy enum that determines how we
244        handle input pipelines that depend on external state. By default, its
245        set to WARN.
246
247    Returns:
248      A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
249      serialized graph.
250    """
251    if external_state_policy:
252      policy = external_state_policy.value
253      return gen_dataset_ops.dataset_to_graph_v2(
254          self._variant_tensor,
255          external_state_policy=policy,
256          strip_device_assignment=strip_device_assignment)
257    if strip_device_assignment:
258      return gen_dataset_ops.dataset_to_graph(
259          self._variant_tensor,
260          allow_stateful=allow_stateful,
261          strip_device_assignment=strip_device_assignment)
262    return gen_dataset_ops.dataset_to_graph(
263        self._variant_tensor, allow_stateful=allow_stateful)
264
265  def _trace_variant_creation(self):
266    """Traces a function which outputs a variant `tf.Tensor` for this dataset.
267
268    Note that creating this function involves evaluating an op, and is currently
269    only supported when executing eagerly.
270
271    Returns:
272      A zero-argument `ConcreteFunction` which outputs a variant `tf.Tensor`.
273    """
274    variant = self._variant_tensor
275    if not isinstance(variant, ops.EagerTensor):
276      raise NotImplementedError(
277          "Can only export Datasets which were created executing eagerly. "
278          "Please file a feature request if this is important to you.")
279    with context.eager_mode(), ops.device("CPU"):
280      # pylint: disable=protected-access
281      graph_def = graph_pb2.GraphDef().FromString(
282          self._as_serialized_graph(external_state_policy=distribute_options
283                                    .ExternalStatePolicy.FAIL).numpy())
284    output_node_name = None
285    for node in graph_def.node:
286      if node.op == "_Retval":
287        if output_node_name is not None:
288          raise AssertionError(
289              "Found multiple return values from the dataset's graph, expected "
290              "only one.")
291        output_node_name, = node.input
292    if output_node_name is None:
293      raise AssertionError("Could not find the dataset's output node.")
294    # Add functions used in this Dataset to the function's graph, since they
295    # need to follow it around (and for example be added to a SavedModel which
296    # references the dataset).
297    variant_function = wrap_function.function_from_graph_def(
298        graph_def, inputs=[], outputs=output_node_name + ":0")
299    for used_function in self._functions():
300      used_function.function.add_to_graph(variant_function.graph)
301    return variant_function
302
303  @abc.abstractmethod
304  def _inputs(self):
305    """Returns a list of the input datasets of the dataset."""
306
307    raise NotImplementedError("Dataset._inputs")
308
309  @property
310  def _graph(self):
311    return self._graph_attr
312
313  @_graph.setter
314  def _graph(self, _):
315    raise ValueError("The _graph property is read-only")
316
317  def _has_captured_ref(self):
318    """Whether this dataset uses a function that captures ref variables.
319
320    Returns:
321      A boolean, which if true indicates that the dataset or one of its inputs
322      uses a function that captures ref variables.
323    """
324    if context.executing_eagerly():
325      # RefVariables are not supported in eager mode
326      return False
327
328    def is_tensor_or_parent_ref(tensor):
329      if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
330        return True
331      # If the captured tensor is an eager tensor, we cannot trace its inputs.
332      if isinstance(tensor, ops._EagerTensorBase):  # pylint: disable=protected-access
333        return False
334      return any(is_tensor_or_parent_ref(x) for x in tensor.op.inputs)
335
336    for fn in self._functions():
337      if any(is_tensor_or_parent_ref(t) for t in fn.function.captured_inputs):
338        return True
339
340    return any(
341        [input_dataset._has_captured_ref() for input_dataset in self._inputs()])  # pylint: disable=protected-access
342
343  # TODO(jsimsa): Change this to be the transitive closure of functions used
344  # by this dataset and its inputs.
345  def _functions(self):
346    """Returns a list of functions associated with this dataset.
347
348    Returns:
349      A list of `StructuredFunctionWrapper` objects.
350    """
351    return []
352
353  def options(self):
354    """Returns the options for this dataset and its inputs.
355
356    Returns:
357      A `tf.data.Options` object representing the dataset options.
358    """
359    return self._options_attr
360
361  def _apply_options(self):
362    """Apply options, such as optimization configuration, to the dataset."""
363
364    dataset = self
365    options = self.options()
366
367    # (1) Apply threading options
368    if options.experimental_threading is not None:
369      t_options = options.experimental_threading
370      if t_options.max_intra_op_parallelism is not None:
371        dataset = _MaxIntraOpParallelismDataset(
372            dataset, t_options.max_intra_op_parallelism)
373      if t_options.private_threadpool_size is not None:
374        dataset = _PrivateThreadPoolDataset(dataset,
375                                            t_options.private_threadpool_size)
376
377    # (2) Apply autotune options
378    autotune, algorithm, cpu_budget, ram_budget = options._autotune_settings()  # pylint: disable=protected-access
379    if autotune:
380      dataset = _ModelDataset(dataset, algorithm, cpu_budget, ram_budget)
381
382    # (3) Apply graph rewrite options
383    # pylint: disable=protected-access
384    graph_rewrites = options._graph_rewrites()
385    graph_rewrite_configs = options._graph_rewrite_configs(autotune)
386    # pylint: enable=protected-access
387    if self._has_captured_ref():
388      if graph_rewrites.enabled or graph_rewrites.default:
389        warnings.warn(
390            "tf.data graph rewrites are not compatible with tf.Variable. "
391            "The following rewrites will be disabled: %s. To enable "
392            "rewrites, use resource variables instead by calling "
393            "`tf.enable_resource_variables()` at the start of the program." %
394            ", ".join(graph_rewrites.enabled + graph_rewrites.default))
395    elif (graph_rewrites.enabled or graph_rewrites.default or
396          (options.experimental_optimization.apply_default_optimizations  # pylint: disable=g-bool-id-comparison
397           is not False)):
398      dataset = _OptimizeDataset(dataset, graph_rewrites.enabled,
399                                 graph_rewrites.disabled,
400                                 graph_rewrites.default, graph_rewrite_configs)
401
402    # (4) Apply stats aggregator options
403    if options.experimental_stats and options.experimental_stats.aggregator:  # pylint: disable=line-too-long
404      dataset = _SetStatsAggregatorDataset(  # pylint: disable=protected-access
405          dataset, options.experimental_stats.aggregator,
406          options.experimental_stats.prefix,
407          options.experimental_stats.counter_prefix)
408    return dataset
409
410  def __iter__(self):
411    """Creates an iterator for elements of this dataset.
412
413    The returned iterator implements the Python Iterator protocol.
414
415    Returns:
416      An `tf.data.Iterator` for the elements of this dataset.
417
418    Raises:
419      RuntimeError: If not inside of tf.function and not executing eagerly.
420    """
421    if context.executing_eagerly() or ops.inside_function():
422      with ops.colocate_with(self._variant_tensor):
423        return iterator_ops.OwnedIterator(self)
424    else:
425      raise RuntimeError("__iter__() is only supported inside of tf.function "
426                         "or when eager execution is enabled.")
427
428  def __bool__(self):
429    return True  # Required as __len__ is defined
430
431  __nonzero__ = __bool__  # Python 2 backward compatibility
432
433  def __len__(self):
434    """Returns the length of the dataset if it is known and finite.
435
436    This method requires that you are running in eager mode, and that the
437    length of the dataset is known and non-infinite. When the length may be
438    unknown or infinite, or if you are running in graph mode, use
439    `tf.data.Dataset.cardinality` instead.
440
441    Returns:
442      An integer representing the length of the dataset.
443
444    Raises:
445      RuntimeError: If the dataset length is unknown or infinite, or if eager
446        execution is not enabled.
447    """
448    if not context.executing_eagerly():
449      raise TypeError("__len__() is not supported while tracing functions. "
450                      "Use `tf.data.Dataset.cardinality` instead.")
451    length = self.cardinality()
452    if length.numpy() == INFINITE:
453      raise TypeError("dataset length is infinite.")
454    if length.numpy() == UNKNOWN:
455      raise TypeError("dataset length is unknown.")
456    return length
457
458  @abc.abstractproperty
459  def element_spec(self):
460    """The type specification of an element of this dataset.
461
462    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
463    >>> dataset.element_spec
464    TensorSpec(shape=(), dtype=tf.int32, name=None)
465
466    Returns:
467      A nested structure of `tf.TypeSpec` objects matching the structure of an
468      element of this dataset and specifying the type of individual components.
469    """
470    raise NotImplementedError("Dataset.element_spec")
471
472  def __repr__(self):
473    output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
474    output_shapes = str(output_shapes).replace("'", "")
475    output_types = nest.map_structure(repr, get_legacy_output_types(self))
476    output_types = str(output_types).replace("'", "")
477    return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes,
478                                            output_types))
479
480  def as_numpy_iterator(self):
481    """Returns an iterator which converts all elements of the dataset to numpy.
482
483    Use `as_numpy_iterator` to inspect the content of your dataset. To see
484    element shapes and types, print dataset elements directly instead of using
485    `as_numpy_iterator`.
486
487    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
488    >>> for element in dataset:
489    ...   print(element)
490    tf.Tensor(1, shape=(), dtype=int32)
491    tf.Tensor(2, shape=(), dtype=int32)
492    tf.Tensor(3, shape=(), dtype=int32)
493
494    This method requires that you are running in eager mode and the dataset's
495    element_spec contains only `TensorSpec` components.
496
497    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
498    >>> for element in dataset.as_numpy_iterator():
499    ...   print(element)
500    1
501    2
502    3
503
504    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
505    >>> print(list(dataset.as_numpy_iterator()))
506    [1, 2, 3]
507
508    `as_numpy_iterator()` will preserve the nested structure of dataset
509    elements.
510
511    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
512    ...                                               'b': [5, 6]})
513    >>> list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
514    ...                                       {'a': (2, 4), 'b': 6}]
515    True
516
517    Returns:
518      An iterable over the elements of the dataset, with their tensors converted
519      to numpy arrays.
520
521    Raises:
522      TypeError: if an element contains a non-`Tensor` value.
523      RuntimeError: if eager execution is not enabled.
524    """
525    if not context.executing_eagerly():
526      raise RuntimeError("as_numpy_iterator() is not supported while tracing "
527                         "functions")
528    for component_spec in nest.flatten(self.element_spec):
529      if not isinstance(
530          component_spec,
531          (tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec)):
532        raise TypeError(
533            "Dataset.as_numpy_iterator() does not support datasets containing "
534            + str(component_spec.value_type))
535
536    return _NumpyIterator(self)
537
538  @property
539  def _flat_shapes(self):
540    """Returns a list `tf.TensorShapes`s for the element tensor representation.
541
542    Returns:
543      A list `tf.TensorShapes`s for the element tensor representation.
544    """
545    return structure.get_flat_tensor_shapes(self.element_spec)
546
547  @property
548  def _flat_types(self):
549    """Returns a list `tf.DType`s for the element tensor representation.
550
551    Returns:
552      A list `tf.DType`s for the element tensor representation.
553    """
554    return structure.get_flat_tensor_types(self.element_spec)
555
556  @property
557  def _flat_structure(self):
558    """Helper for setting `output_shapes` and `output_types` attrs of an op.
559
560    Most dataset op constructors expect `output_shapes` and `output_types`
561    arguments that represent the flattened structure of an element. This helper
562    function generates these attrs as a keyword argument dictionary, allowing
563    `Dataset._variant_tensor` implementations to pass `**self._flat_structure`
564    to the op constructor.
565
566    Returns:
567      A dictionary of keyword arguments that can be passed to a dataset op
568      constructor.
569    """
570    return {
571        "output_shapes": self._flat_shapes,
572        "output_types": self._flat_types,
573    }
574
575  @property
576  def _type_spec(self):
577    return DatasetSpec(self.element_spec)
578
579  @staticmethod
580  def from_tensors(tensors):
581    """Creates a `Dataset` with a single element, comprising the given tensors.
582
583    `from_tensors` produces a dataset containing only a single element. To slice
584    the input tensor into multiple elements, use `from_tensor_slices` instead.
585
586    >>> dataset = tf.data.Dataset.from_tensors([1, 2, 3])
587    >>> list(dataset.as_numpy_iterator())
588    [array([1, 2, 3], dtype=int32)]
589    >>> dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
590    >>> list(dataset.as_numpy_iterator())
591    [(array([1, 2, 3], dtype=int32), b'A')]
592
593    >>> # You can use `from_tensors` to produce a dataset which repeats
594    >>> # the same example many times.
595    >>> example = tf.constant([1,2,3])
596    >>> dataset = tf.data.Dataset.from_tensors(example).repeat(2)
597    >>> list(dataset.as_numpy_iterator())
598    [array([1, 2, 3], dtype=int32), array([1, 2, 3], dtype=int32)]
599
600    Note that if `tensors` contains a NumPy array, and eager execution is not
601    enabled, the values will be embedded in the graph as one or more
602    `tf.constant` operations. For large datasets (> 1 GB), this can waste
603    memory and run into byte limits of graph serialization. If `tensors`
604    contains one or more large NumPy arrays, consider the alternative described
605    in [this
606    guide](https://tensorflow.org/guide/data#consuming_numpy_arrays).
607
608    Args:
609      tensors: A dataset element.
610
611    Returns:
612      Dataset: A `Dataset`.
613    """
614    return TensorDataset(tensors)
615
616  @staticmethod
617  def from_tensor_slices(tensors):
618    """Creates a `Dataset` whose elements are slices of the given tensors.
619
620    The given tensors are sliced along their first dimension. This operation
621    preserves the structure of the input tensors, removing the first dimension
622    of each tensor and using it as the dataset dimension. All input tensors
623    must have the same size in their first dimensions.
624
625    >>> # Slicing a 1D tensor produces scalar tensor elements.
626    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
627    >>> list(dataset.as_numpy_iterator())
628    [1, 2, 3]
629
630    >>> # Slicing a 2D tensor produces 1D tensor elements.
631    >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
632    >>> list(dataset.as_numpy_iterator())
633    [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
634
635    >>> # Slicing a tuple of 1D tensors produces tuple elements containing
636    >>> # scalar tensors.
637    >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
638    >>> list(dataset.as_numpy_iterator())
639    [(1, 3, 5), (2, 4, 6)]
640
641    >>> # Dictionary structure is also preserved.
642    >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
643    >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
644    ...                                       {'a': 2, 'b': 4}]
645    True
646
647    >>> # Two tensors can be combined into one Dataset object.
648    >>> features = tf.constant([[1, 3], [2, 1], [3, 3]]) # ==> 3x2 tensor
649    >>> labels = tf.constant(['A', 'B', 'A']) # ==> 3x1 tensor
650    >>> dataset = Dataset.from_tensor_slices((features, labels))
651    >>> # Both the features and the labels tensors can be converted
652    >>> # to a Dataset object separately and combined after.
653    >>> features_dataset = Dataset.from_tensor_slices(features)
654    >>> labels_dataset = Dataset.from_tensor_slices(labels)
655    >>> dataset = Dataset.zip((features_dataset, labels_dataset))
656    >>> # A batched feature and label set can be converted to a Dataset
657    >>> # in similar fashion.
658    >>> batched_features = tf.constant([[[1, 3], [2, 3]],
659    ...                                 [[2, 1], [1, 2]],
660    ...                                 [[3, 3], [3, 2]]], shape=(3, 2, 2))
661    >>> batched_labels = tf.constant([['A', 'A'],
662    ...                               ['B', 'B'],
663    ...                               ['A', 'B']], shape=(3, 2, 1))
664    >>> dataset = Dataset.from_tensor_slices((batched_features, batched_labels))
665    >>> for element in dataset.as_numpy_iterator():
666    ...   print(element)
667    (array([[1, 3],
668           [2, 3]], dtype=int32), array([[b'A'],
669           [b'A']], dtype=object))
670    (array([[2, 1],
671           [1, 2]], dtype=int32), array([[b'B'],
672           [b'B']], dtype=object))
673    (array([[3, 3],
674           [3, 2]], dtype=int32), array([[b'A'],
675           [b'B']], dtype=object))
676
677    Note that if `tensors` contains a NumPy array, and eager execution is not
678    enabled, the values will be embedded in the graph as one or more
679    `tf.constant` operations. For large datasets (> 1 GB), this can waste
680    memory and run into byte limits of graph serialization. If `tensors`
681    contains one or more large NumPy arrays, consider the alternative described
682    in [this guide](
683    https://tensorflow.org/guide/data#consuming_numpy_arrays).
684
685    Args:
686      tensors: A dataset element, with each component having the same size in
687        the first dimension.
688
689    Returns:
690      Dataset: A `Dataset`.
691    """
692    return TensorSliceDataset(tensors)
693
694  class _GeneratorState(object):
695    """Stores outstanding iterators created from a Python generator.
696
697    This class keeps track of potentially multiple iterators that may have
698    been created from a generator, e.g. in the case that the dataset is
699    repeated, or nested within a parallel computation.
700    """
701
702    def __init__(self, generator):
703      self._generator = generator
704      self._lock = threading.Lock()
705      self._next_id = 0  # GUARDED_BY(self._lock)
706      self._args = {}
707      self._iterators = {}
708
709    def get_next_id(self, *args):
710      with self._lock:
711        ret = self._next_id
712        self._next_id += 1
713      self._args[ret] = args
714      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
715      # casting in `py_func()` will create an array of `np.int32` on Windows,
716      # leading to a runtime error.
717      return np.array(ret, dtype=np.int64)
718
719    def get_iterator(self, iterator_id):
720      try:
721        return self._iterators[iterator_id]
722      except KeyError:
723        iterator = iter(self._generator(*self._args.pop(iterator_id)))
724        self._iterators[iterator_id] = iterator
725        return iterator
726
727    def iterator_completed(self, iterator_id):
728      del self._iterators[iterator_id]
729
730  @staticmethod
731  @deprecation.deprecated_args(None, "Use output_signature instead",
732                               "output_types", "output_shapes")
733  def from_generator(generator,
734                     output_types=None,
735                     output_shapes=None,
736                     args=None,
737                     output_signature=None):
738    """Creates a `Dataset` whose elements are generated by `generator`.
739
740    The `generator` argument must be a callable object that returns
741    an object that supports the `iter()` protocol (e.g. a generator function).
742
743    The elements generated by `generator` must be compatible with either the
744    given `output_signature` argument or with the given `output_types` and
745    (optionally) `output_shapes` arguments, whichever was specified.
746
747    The recommended way to call `from_generator` is to use the
748    `output_signature` argument. In this case the output will be assumed to
749    consist of objects with the classes, shapes and types defined by
750    `tf.TypeSpec` objects from `output_signature` argument:
751
752    >>> def gen():
753    ...   ragged_tensor = tf.ragged.constant([[1, 2], [3]])
754    ...   yield 42, ragged_tensor
755    >>>
756    >>> dataset = tf.data.Dataset.from_generator(
757    ...      gen,
758    ...      output_signature=(
759    ...          tf.TensorSpec(shape=(), dtype=tf.int32),
760    ...          tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
761    >>>
762    >>> list(dataset.take(1))
763    [(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
764    <tf.RaggedTensor [[1, 2], [3]]>)]
765
766    There is also a deprecated way to call `from_generator` by either with
767    `output_types` argument alone or together with `output_shapes` argument.
768    In this case the output of the function will be assumed to consist of
769    `tf.Tensor` objects with the types defined by `output_types` and with the
770    shapes which are either unknown or defined by `output_shapes`.
771
772    Note: The current implementation of `Dataset.from_generator()` uses
773    `tf.numpy_function` and inherits the same constraints. In particular, it
774    requires the dataset and iterator related operations to be placed
775    on a device in the same process as the Python program that called
776    `Dataset.from_generator()`. The body of `generator` will not be
777    serialized in a `GraphDef`, and you should not use this method if you
778    need to serialize your model and restore it in a different environment.
779
780    Note: If `generator` depends on mutable global variables or other external
781    state, be aware that the runtime may invoke `generator` multiple times
782    (in order to support repeating the `Dataset`) and at any time
783    between the call to `Dataset.from_generator()` and the production of the
784    first element from the generator. Mutating global variables or external
785    state can cause undefined behavior, and we recommend that you explicitly
786    cache any external state in `generator` before calling
787    `Dataset.from_generator()`.
788
789    Args:
790      generator: A callable object that returns an object that supports the
791        `iter()` protocol. If `args` is not specified, `generator` must take no
792        arguments; otherwise it must take as many arguments as there are values
793        in `args`.
794      output_types: (Optional.) A nested structure of `tf.DType` objects
795        corresponding to each component of an element yielded by `generator`.
796      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
797        corresponding to each component of an element yielded by `generator`.
798      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
799        and passed to `generator` as NumPy-array arguments.
800      output_signature: (Optional.) A nested structure of `tf.TypeSpec` objects
801        corresponding to each component of an element yielded by `generator`.
802
803    Returns:
804      Dataset: A `Dataset`.
805    """
806    if not callable(generator):
807      raise TypeError("`generator` must be callable.")
808
809    if output_signature is not None:
810      if output_types is not None:
811        raise TypeError("`output_types` can not be used together with "
812                        "`output_signature`")
813      if output_shapes is not None:
814        raise TypeError("`output_shapes` can not be used together with "
815                        "`output_signature`")
816      if not all(
817          isinstance(_, type_spec.TypeSpec)
818          for _ in nest.flatten(output_signature)):
819        raise TypeError("All the elements of `output_signature` must be "
820                        "`tf.TypeSpec` objects.")
821    else:
822      if output_types is None:
823        raise TypeError("Either `output_signature` or `output_types` must "
824                        "be specified")
825
826    if output_signature is None:
827      if output_shapes is None:
828        output_shapes = nest.map_structure(
829            lambda _: tensor_shape.TensorShape(None), output_types)
830      else:
831        output_shapes = nest.map_structure_up_to(output_types,
832                                                 tensor_shape.as_shape,
833                                                 output_shapes)
834      output_signature = nest.map_structure_up_to(output_types,
835                                                  tensor_spec.TensorSpec,
836                                                  output_shapes, output_types)
837    if all([
838        isinstance(x, tensor_spec.TensorSpec)
839        for x in nest.flatten(output_signature)
840    ]):
841      output_types = nest.pack_sequence_as(
842          output_signature, [x.dtype for x in nest.flatten(output_signature)])
843      output_shapes = nest.pack_sequence_as(
844          output_signature, [x.shape for x in nest.flatten(output_signature)])
845
846    if args is None:
847      args = ()
848    else:
849      args = tuple(ops.convert_n_to_tensor(args, name="args"))
850
851    generator_state = DatasetV2._GeneratorState(generator)
852
853    def get_iterator_id_fn(unused_dummy):
854      """Creates a unique `iterator_id` for each pass over the dataset.
855
856      The returned `iterator_id` disambiguates between multiple concurrently
857      existing iterators.
858
859      Args:
860        unused_dummy: Ignored value.
861
862      Returns:
863        A `tf.int64` tensor whose value uniquely identifies an iterator in
864        `generator_state`.
865      """
866      return script_ops.numpy_function(generator_state.get_next_id, args,
867                                       dtypes.int64)
868
869    def generator_next_fn(iterator_id_t):
870      """Generates the next element from iterator with ID `iterator_id_t`.
871
872      We map this function across an infinite repetition of the
873      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.
874
875      Args:
876        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies the
877          iterator in `generator_state` from which to generate an element.
878
879      Returns:
880        The next element to generate from the iterator.
881      """
882      if output_types and output_shapes:
883        flattened_types = [
884            dtypes.as_dtype(dt) for dt in nest.flatten(output_types)
885        ]
886        flattened_shapes = nest.flatten(output_shapes)
887
888        def generator_py_func(iterator_id):
889          """A `py_func` that will be called to invoke the iterator."""
890          # `next()` raises `StopIteration` when there are no more
891          # elements remaining to be generated.
892          values = next(generator_state.get_iterator(iterator_id))
893
894          # Use the same _convert function from the py_func() implementation to
895          # convert the returned values to arrays early, so that we can inspect
896          # their values.
897          try:
898            flattened_values = nest.flatten_up_to(output_types, values)
899          except (TypeError, ValueError):
900            six.reraise(
901                TypeError,
902                TypeError(
903                    "`generator` yielded an element that did not match the "
904                    "expected structure. The expected structure was %s, but "
905                    "the yielded element was %s." % (output_types, values)),
906                sys.exc_info()[2])
907          ret_arrays = []
908          for ret, dtype in zip(flattened_values, flattened_types):
909            try:
910              ret_arrays.append(
911                  script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
912                      ret,
913                      dtype=dtype.as_numpy_dtype))
914            except (TypeError, ValueError):
915              six.reraise(
916                  TypeError,
917                  TypeError(
918                      "`generator` yielded an element that could not be "
919                      "converted to the expected type. The expected type was "
920                      "%s, but the yielded element was %s." %
921                      (dtype.name, ret)),
922                  sys.exc_info()[2])
923
924          # Additional type and shape checking to ensure that the components of
925          # the generated element match the `output_types` and `output_shapes`
926          # arguments.
927          for (ret_array, expected_dtype,
928               expected_shape) in zip(ret_arrays, flattened_types,
929                                      flattened_shapes):
930            if ret_array.dtype != expected_dtype.as_numpy_dtype:
931              raise TypeError(
932                  "`generator` yielded an element of type %s where an element "
933                  "of type %s was expected." %
934                  (ret_array.dtype, expected_dtype.as_numpy_dtype))
935            if not expected_shape.is_compatible_with(ret_array.shape):
936              raise ValueError(
937                  "`generator` yielded an element of shape %s where an element "
938                  "of shape %s was expected." %
939                  (ret_array.shape, expected_shape))
940
941          return ret_arrays
942
943        flat_values = script_ops.numpy_function(generator_py_func,
944                                                [iterator_id_t],
945                                                flattened_types)
946
947        # The `py_func()` op drops the inferred shapes, so we add them back in
948        # here.
949        if output_shapes is not None:
950          for ret_t, shape in zip(flat_values, flattened_shapes):
951            ret_t.set_shape(shape)
952
953        return nest.pack_sequence_as(output_types, flat_values)
954      else:
955        flat_output_types = structure.get_flat_tensor_types(output_signature)
956
957        def generator_py_func(iterator_id):
958          """A `py_func` that will be called to invoke the iterator."""
959          # `next()` raises `StopIteration` when there are no more
960          # elements remaining to be generated.
961          values = next(generator_state.get_iterator(iterator_id.numpy()))
962
963          try:
964            values = structure.normalize_element(values, output_signature)
965          except (TypeError, ValueError):
966            six.reraise(
967                TypeError,
968                TypeError(
969                    "`generator` yielded an element that did not match the "
970                    "expected structure. The expected structure was %s, but "
971                    "the yielded element was %s." % (output_signature, values)),
972                sys.exc_info()[2])
973
974          values_spec = structure.type_spec_from_value(values)
975
976          if not structure.are_compatible(values_spec, output_signature):
977            raise TypeError(
978                "`generator` yielded an element of %s where an element "
979                "of %s was expected." % (values_spec, output_signature))
980
981          return structure.to_tensor_list(output_signature, values)
982
983        return script_ops._eager_py_func(  # pylint: disable=protected-access
984            generator_py_func,
985            inp=[iterator_id_t],
986            Tout=flat_output_types,
987            use_tape_cache=False)
988
989    def finalize_fn(iterator_id_t):
990      """Releases host-side state for the iterator with ID `iterator_id_t`."""
991
992      def finalize_py_func(iterator_id):
993        generator_state.iterator_completed(iterator_id)
994        # We return a dummy value so that the `finalize_fn` has a valid
995        # signature.
996        # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
997        # casting in `py_func()` will create an array of `np.int32` on Windows,
998        # leading to a runtime error.
999        return np.array(0, dtype=np.int64)
1000
1001      return script_ops.numpy_function(finalize_py_func, [iterator_id_t],
1002                                       dtypes.int64)
1003
1004    # This function associates each traversal of `generator` with a unique
1005    # iterator ID.
1006    def flat_map_fn(dummy_arg):
1007      # The `get_iterator_id_fn` gets a unique ID for the current instance of
1008      # of the generator.
1009      # The `generator_next_fn` gets the next element from the iterator with the
1010      # given ID, and raises StopIteration when that iterator contains no
1011      # more elements.
1012      return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
1013                               finalize_fn, output_signature)
1014
1015    # A single-element dataset that, each time it is evaluated, contains a
1016    # freshly-generated and unique (for the returned dataset) int64
1017    # ID that will be used to identify the appropriate Python state, which
1018    # is encapsulated in `generator_state`, and captured in
1019    # `get_iterator_id_map_fn`.
1020    dummy = 0
1021    id_dataset = Dataset.from_tensors(dummy)
1022
1023    # A dataset that contains all of the elements generated by a
1024    # single iterator created from `generator`, identified by the
1025    # iterator ID contained in `id_dataset`. Lifting the iteration
1026    # into a flat_map here enables multiple repetitions and/or nested
1027    # versions of the returned dataset to be created, because it forces
1028    # the generation of a new ID for each version.
1029    return id_dataset.flat_map(flat_map_fn)
1030
1031  @staticmethod
1032  def range(*args, **kwargs):
1033    """Creates a `Dataset` of a step-separated range of values.
1034
1035    >>> list(Dataset.range(5).as_numpy_iterator())
1036    [0, 1, 2, 3, 4]
1037    >>> list(Dataset.range(2, 5).as_numpy_iterator())
1038    [2, 3, 4]
1039    >>> list(Dataset.range(1, 5, 2).as_numpy_iterator())
1040    [1, 3]
1041    >>> list(Dataset.range(1, 5, -2).as_numpy_iterator())
1042    []
1043    >>> list(Dataset.range(5, 1).as_numpy_iterator())
1044    []
1045    >>> list(Dataset.range(5, 1, -2).as_numpy_iterator())
1046    [5, 3]
1047    >>> list(Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator())
1048    [2, 3, 4]
1049    >>> list(Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator())
1050    [1.0, 3.0]
1051
1052    Args:
1053      *args: follows the same semantics as python's xrange.
1054        len(args) == 1 -> start = 0, stop = args[0], step = 1.
1055        len(args) == 2 -> start = args[0], stop = args[1], step = 1.
1056        len(args) == 3 -> start = args[0], stop = args[1], step = args[2].
1057      **kwargs:
1058        - output_type: Its expected dtype. (Optional, default: `tf.int64`).
1059
1060    Returns:
1061      Dataset: A `RangeDataset`.
1062
1063    Raises:
1064      ValueError: if len(args) == 0.
1065    """
1066    return RangeDataset(*args, **kwargs)
1067
1068  @staticmethod
1069  def zip(datasets):
1070    """Creates a `Dataset` by zipping together the given datasets.
1071
1072    This method has similar semantics to the built-in `zip()` function
1073    in Python, with the main difference being that the `datasets`
1074    argument can be an arbitrary nested structure of `Dataset` objects.
1075
1076    >>> # The nested structure of the `datasets` argument determines the
1077    >>> # structure of elements in the resulting dataset.
1078    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
1079    >>> b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
1080    >>> ds = tf.data.Dataset.zip((a, b))
1081    >>> list(ds.as_numpy_iterator())
1082    [(1, 4), (2, 5), (3, 6)]
1083    >>> ds = tf.data.Dataset.zip((b, a))
1084    >>> list(ds.as_numpy_iterator())
1085    [(4, 1), (5, 2), (6, 3)]
1086    >>>
1087    >>> # The `datasets` argument may contain an arbitrary number of datasets.
1088    >>> c = tf.data.Dataset.range(7, 13).batch(2)  # ==> [ [7, 8],
1089    ...                                            #       [9, 10],
1090    ...                                            #       [11, 12] ]
1091    >>> ds = tf.data.Dataset.zip((a, b, c))
1092    >>> for element in ds.as_numpy_iterator():
1093    ...   print(element)
1094    (1, 4, array([7, 8]))
1095    (2, 5, array([ 9, 10]))
1096    (3, 6, array([11, 12]))
1097    >>>
1098    >>> # The number of elements in the resulting dataset is the same as
1099    >>> # the size of the smallest dataset in `datasets`.
1100    >>> d = tf.data.Dataset.range(13, 15)  # ==> [ 13, 14 ]
1101    >>> ds = tf.data.Dataset.zip((a, d))
1102    >>> list(ds.as_numpy_iterator())
1103    [(1, 13), (2, 14)]
1104
1105    Args:
1106      datasets: A nested structure of datasets.
1107
1108    Returns:
1109      Dataset: A `Dataset`.
1110    """
1111    return ZipDataset(datasets)
1112
1113  def concatenate(self, dataset):
1114    """Creates a `Dataset` by concatenating the given dataset with this dataset.
1115
1116    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
1117    >>> b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
1118    >>> ds = a.concatenate(b)
1119    >>> list(ds.as_numpy_iterator())
1120    [1, 2, 3, 4, 5, 6, 7]
1121    >>> # The input dataset and dataset to be concatenated should have the same
1122    >>> # nested structures and output types.
1123    >>> c = tf.data.Dataset.zip((a, b))
1124    >>> a.concatenate(c)
1125    Traceback (most recent call last):
1126    TypeError: Two datasets to concatenate have different types
1127    <dtype: 'int64'> and (tf.int64, tf.int64)
1128    >>> d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
1129    >>> a.concatenate(d)
1130    Traceback (most recent call last):
1131    TypeError: Two datasets to concatenate have different types
1132    <dtype: 'int64'> and <dtype: 'string'>
1133
1134    Args:
1135      dataset: `Dataset` to be concatenated.
1136
1137    Returns:
1138      Dataset: A `Dataset`.
1139    """
1140    return ConcatenateDataset(self, dataset)
1141
1142  def prefetch(self, buffer_size):
1143    """Creates a `Dataset` that prefetches elements from this dataset.
1144
1145    Most dataset input pipelines should end with a call to `prefetch`. This
1146    allows later elements to be prepared while the current element is being
1147    processed. This often improves latency and throughput, at the cost of
1148    using additional memory to store prefetched elements.
1149
1150    Note: Like other `Dataset` methods, prefetch operates on the
1151    elements of the input dataset. It has no concept of examples vs. batches.
1152    `examples.prefetch(2)` will prefetch two elements (2 examples),
1153    while `examples.batch(20).prefetch(2)` will prefetch 2 elements
1154    (2 batches, of 20 examples each).
1155
1156    >>> dataset = tf.data.Dataset.range(3)
1157    >>> dataset = dataset.prefetch(2)
1158    >>> list(dataset.as_numpy_iterator())
1159    [0, 1, 2]
1160
1161    Args:
1162      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
1163        number of elements that will be buffered when prefetching. If the value
1164        `tf.data.AUTOTUNE` is used, then the buffer size is dynamically tuned.
1165    Returns:
1166      Dataset: A `Dataset`.
1167    """
1168    return PrefetchDataset(self, buffer_size)
1169
1170  @staticmethod
1171  def list_files(file_pattern, shuffle=None, seed=None):
1172    """A dataset of all files matching one or more glob patterns.
1173
1174    The `file_pattern` argument should be a small number of glob patterns.
1175    If your filenames have already been globbed, use
1176    `Dataset.from_tensor_slices(filenames)` instead, as re-globbing every
1177    filename with `list_files` may result in poor performance with remote
1178    storage systems.
1179
1180    Note: The default behavior of this method is to return filenames in
1181    a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
1182    to get results in a deterministic order.
1183
1184    Example:
1185      If we had the following files on our filesystem:
1186
1187        - /path/to/dir/a.txt
1188        - /path/to/dir/b.py
1189        - /path/to/dir/c.py
1190
1191      If we pass "/path/to/dir/*.py" as the directory, the dataset
1192      would produce:
1193
1194        - /path/to/dir/b.py
1195        - /path/to/dir/c.py
1196
1197    Args:
1198      file_pattern: A string, a list of strings, or a `tf.Tensor` of string type
1199        (scalar or vector), representing the filename glob (i.e. shell wildcard)
1200        pattern(s) that will be matched.
1201      shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
1202        Defaults to `True`.
1203      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
1204        seed that will be used to create the distribution. See
1205        `tf.random.set_seed` for behavior.
1206
1207    Returns:
1208     Dataset: A `Dataset` of strings corresponding to file names.
1209    """
1210    with ops.name_scope("list_files"):
1211      if shuffle is None:
1212        shuffle = True
1213      file_pattern = ops.convert_to_tensor(
1214          file_pattern, dtype=dtypes.string, name="file_pattern")
1215      matching_files = gen_io_ops.matching_files(file_pattern)
1216
1217      # Raise an exception if `file_pattern` does not match any files.
1218      condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
1219                                   name="match_not_empty")
1220
1221      message = math_ops.add(
1222          "No files matched pattern: ",
1223          string_ops.reduce_join(file_pattern, separator=", "), name="message")
1224
1225      assert_not_empty = control_flow_ops.Assert(
1226          condition, [message], summarize=1, name="assert_not_empty")
1227      with ops.control_dependencies([assert_not_empty]):
1228        matching_files = array_ops.identity(matching_files)
1229
1230      dataset = Dataset.from_tensor_slices(matching_files)
1231      if shuffle:
1232        # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
1233        # list of files might be empty.
1234        buffer_size = math_ops.maximum(
1235            array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
1236        dataset = dataset.shuffle(buffer_size, seed=seed)
1237      return dataset
1238
1239  def repeat(self, count=None):
1240    """Repeats this dataset so each original value is seen `count` times.
1241
1242    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1243    >>> dataset = dataset.repeat(3)
1244    >>> list(dataset.as_numpy_iterator())
1245    [1, 2, 3, 1, 2, 3, 1, 2, 3]
1246
1247    Note: If this dataset is a function of global state (e.g. a random number
1248    generator), then different repetitions may produce different elements.
1249
1250    Args:
1251      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
1252        number of times the dataset should be repeated. The default behavior (if
1253        `count` is `None` or `-1`) is for the dataset be repeated indefinitely.
1254
1255    Returns:
1256      Dataset: A `Dataset`.
1257    """
1258    return RepeatDataset(self, count)
1259
1260  def enumerate(self, start=0):
1261    """Enumerates the elements of this dataset.
1262
1263    It is similar to python's `enumerate`.
1264
1265    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1266    >>> dataset = dataset.enumerate(start=5)
1267    >>> for element in dataset.as_numpy_iterator():
1268    ...   print(element)
1269    (5, 1)
1270    (6, 2)
1271    (7, 3)
1272
1273    >>> # The nested structure of the input dataset determines the structure of
1274    >>> # elements in the resulting dataset.
1275    >>> dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)])
1276    >>> dataset = dataset.enumerate()
1277    >>> for element in dataset.as_numpy_iterator():
1278    ...   print(element)
1279    (0, array([7, 8], dtype=int32))
1280    (1, array([ 9, 10], dtype=int32))
1281
1282    Args:
1283      start: A `tf.int64` scalar `tf.Tensor`, representing the start value for
1284        enumeration.
1285
1286    Returns:
1287      Dataset: A `Dataset`.
1288    """
1289
1290    max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
1291    return Dataset.zip((Dataset.range(start, max_value), self))
1292
1293  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
1294    """Randomly shuffles the elements of this dataset.
1295
1296    This dataset fills a buffer with `buffer_size` elements, then randomly
1297    samples elements from this buffer, replacing the selected elements with new
1298    elements. For perfect shuffling, a buffer size greater than or equal to the
1299    full size of the dataset is required.
1300
1301    For instance, if your dataset contains 10,000 elements but `buffer_size` is
1302    set to 1,000, then `shuffle` will initially select a random element from
1303    only the first 1,000 elements in the buffer. Once an element is selected,
1304    its space in the buffer is replaced by the next (i.e. 1,001-st) element,
1305    maintaining the 1,000 element buffer.
1306
1307    `reshuffle_each_iteration` controls whether the shuffle order should be
1308    different for each epoch. In TF 1.X, the idiomatic way to create epochs
1309    was through the `repeat` transformation:
1310
1311    >>> dataset = tf.data.Dataset.range(3)
1312    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
1313    >>> dataset = dataset.repeat(2)  # doctest: +SKIP
1314    [1, 0, 2, 1, 2, 0]
1315
1316    >>> dataset = tf.data.Dataset.range(3)
1317    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
1318    >>> dataset = dataset.repeat(2)  # doctest: +SKIP
1319    [1, 0, 2, 1, 0, 2]
1320
1321    In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it
1322    possible to also create epochs through Python iteration:
1323
1324    >>> dataset = tf.data.Dataset.range(3)
1325    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
1326    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1327    [1, 0, 2]
1328    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1329    [1, 2, 0]
1330
1331    >>> dataset = tf.data.Dataset.range(3)
1332    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
1333    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1334    [1, 0, 2]
1335    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1336    [1, 0, 2]
1337
1338    Args:
1339      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1340        elements from this dataset from which the new dataset will sample.
1341      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
1342        seed that will be used to create the distribution. See
1343        `tf.random.set_seed` for behavior.
1344      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
1345        that the dataset should be pseudorandomly reshuffled each time it is
1346        iterated over. (Defaults to `True`.)
1347
1348    Returns:
1349      Dataset: A `Dataset`.
1350    """
1351    return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)
1352
1353  def cache(self, filename=""):
1354    """Caches the elements in this dataset.
1355
1356    The first time the dataset is iterated over, its elements will be cached
1357    either in the specified file or in memory. Subsequent iterations will
1358    use the cached data.
1359
1360    Note: For the cache to be finalized, the input dataset must be iterated
1361    through in its entirety. Otherwise, subsequent iterations will not use
1362    cached data.
1363
1364    >>> dataset = tf.data.Dataset.range(5)
1365    >>> dataset = dataset.map(lambda x: x**2)
1366    >>> dataset = dataset.cache()
1367    >>> # The first time reading through the data will generate the data using
1368    >>> # `range` and `map`.
1369    >>> list(dataset.as_numpy_iterator())
1370    [0, 1, 4, 9, 16]
1371    >>> # Subsequent iterations read from the cache.
1372    >>> list(dataset.as_numpy_iterator())
1373    [0, 1, 4, 9, 16]
1374
1375    When caching to a file, the cached data will persist across runs. Even the
1376    first iteration through the data will read from the cache file. Changing
1377    the input pipeline before the call to `.cache()` will have no effect until
1378    the cache file is removed or the filename is changed.
1379
1380    >>> dataset = tf.data.Dataset.range(5)
1381    >>> dataset = dataset.cache("/path/to/file")  # doctest: +SKIP
1382    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1383    [0, 1, 2, 3, 4]
1384    >>> dataset = tf.data.Dataset.range(10)
1385    >>> dataset = dataset.cache("/path/to/file")  # Same file! # doctest: +SKIP
1386    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1387    [0, 1, 2, 3, 4]
1388
1389    Note: `cache` will produce exactly the same elements during each iteration
1390    through the dataset. If you wish to randomize the iteration order, make sure
1391    to call `shuffle` *after* calling `cache`.
1392
1393    Args:
1394      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
1395        directory on the filesystem to use for caching elements in this Dataset.
1396        If a filename is not provided, the dataset will be cached in memory.
1397
1398    Returns:
1399      Dataset: A `Dataset`.
1400    """
1401    return CacheDataset(self, filename)
1402
1403  def take(self, count):
1404    """Creates a `Dataset` with at most `count` elements from this dataset.
1405
1406    >>> dataset = tf.data.Dataset.range(10)
1407    >>> dataset = dataset.take(3)
1408    >>> list(dataset.as_numpy_iterator())
1409    [0, 1, 2]
1410
1411    Args:
1412      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
1413        elements of this dataset that should be taken to form the new dataset.
1414        If `count` is -1, or if `count` is greater than the size of this
1415        dataset, the new dataset will contain all elements of this dataset.
1416
1417    Returns:
1418      Dataset: A `Dataset`.
1419    """
1420    return TakeDataset(self, count)
1421
1422  def skip(self, count):
1423    """Creates a `Dataset` that skips `count` elements from this dataset.
1424
1425    >>> dataset = tf.data.Dataset.range(10)
1426    >>> dataset = dataset.skip(7)
1427    >>> list(dataset.as_numpy_iterator())
1428    [7, 8, 9]
1429
1430    Args:
1431      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
1432        elements of this dataset that should be skipped to form the new dataset.
1433        If `count` is greater than the size of this dataset, the new dataset
1434        will contain no elements.  If `count` is -1, skips the entire dataset.
1435
1436    Returns:
1437      Dataset: A `Dataset`.
1438    """
1439    return SkipDataset(self, count)
1440
1441  def shard(self, num_shards, index):
1442    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
1443
1444    `shard` is deterministic. The Dataset produced by `A.shard(n, i)` will
1445    contain all elements of A whose index mod n = i.
1446
1447    >>> A = tf.data.Dataset.range(10)
1448    >>> B = A.shard(num_shards=3, index=0)
1449    >>> list(B.as_numpy_iterator())
1450    [0, 3, 6, 9]
1451    >>> C = A.shard(num_shards=3, index=1)
1452    >>> list(C.as_numpy_iterator())
1453    [1, 4, 7]
1454    >>> D = A.shard(num_shards=3, index=2)
1455    >>> list(D.as_numpy_iterator())
1456    [2, 5, 8]
1457
1458    This dataset operator is very useful when running distributed training, as
1459    it allows each worker to read a unique subset.
1460
1461    When reading a single input file, you can shard elements as follows:
1462
1463    ```python
1464    d = tf.data.TFRecordDataset(input_file)
1465    d = d.shard(num_workers, worker_index)
1466    d = d.repeat(num_epochs)
1467    d = d.shuffle(shuffle_buffer_size)
1468    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
1469    ```
1470
1471    Important caveats:
1472
1473    - Be sure to shard before you use any randomizing operator (such as
1474      shuffle).
1475    - Generally it is best if the shard operator is used early in the dataset
1476      pipeline. For example, when reading from a set of TFRecord files, shard
1477      before converting the dataset to input samples. This avoids reading every
1478      file on every worker. The following is an example of an efficient
1479      sharding strategy within a complete pipeline:
1480
1481    ```python
1482    d = Dataset.list_files(pattern)
1483    d = d.shard(num_workers, worker_index)
1484    d = d.repeat(num_epochs)
1485    d = d.shuffle(shuffle_buffer_size)
1486    d = d.interleave(tf.data.TFRecordDataset,
1487                     cycle_length=num_readers, block_length=1)
1488    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
1489    ```
1490
1491    Args:
1492      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
1493        shards operating in parallel.
1494      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
1495
1496    Returns:
1497      Dataset: A `Dataset`.
1498
1499    Raises:
1500      InvalidArgumentError: if `num_shards` or `index` are illegal values.
1501
1502        Note: error checking is done on a best-effort basis, and errors aren't
1503        guaranteed to be caught upon dataset creation. (e.g. providing in a
1504        placeholder tensor bypasses the early checking, and will instead result
1505        in an error during a session.run call.)
1506    """
1507    return ShardDataset(self, num_shards, index)
1508
1509  def batch(self, batch_size, drop_remainder=False, num_parallel_calls=None):
1510    """Combines consecutive elements of this dataset into batches.
1511
1512    >>> dataset = tf.data.Dataset.range(8)
1513    >>> dataset = dataset.batch(3)
1514    >>> list(dataset.as_numpy_iterator())
1515    [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
1516
1517    >>> dataset = tf.data.Dataset.range(8)
1518    >>> dataset = dataset.batch(3, drop_remainder=True)
1519    >>> list(dataset.as_numpy_iterator())
1520    [array([0, 1, 2]), array([3, 4, 5])]
1521
1522    The components of the resulting element will have an additional outer
1523    dimension, which will be `batch_size` (or `N % batch_size` for the last
1524    element if `batch_size` does not divide the number of input elements `N`
1525    evenly and `drop_remainder` is `False`). If your program depends on the
1526    batches having the same outer dimension, you should set the `drop_remainder`
1527    argument to `True` to prevent the smaller batch from being produced.
1528
1529    Args:
1530      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1531        consecutive elements of this dataset to combine in a single batch.
1532      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1533        whether the last batch should be dropped in the case it has fewer than
1534        `batch_size` elements; the default behavior is not to drop the smaller
1535        batch.
1536      num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
1537        representing the number of batches to compute asynchronously in
1538        parallel.
1539        If not specified, batches will be computed sequentially. If the value
1540        `tf.data.AUTOTUNE` is used, then the number of parallel
1541        calls is set dynamically based on available resources.
1542
1543    Returns:
1544      Dataset: A `Dataset`.
1545    """
1546    if num_parallel_calls is None:
1547      return BatchDataset(self, batch_size, drop_remainder)
1548    else:
1549      return ParallelBatchDataset(self, batch_size, drop_remainder,
1550                                  num_parallel_calls)
1551
1552  def padded_batch(self,
1553                   batch_size,
1554                   padded_shapes=None,
1555                   padding_values=None,
1556                   drop_remainder=False):
1557    """Combines consecutive elements of this dataset into padded batches.
1558
1559    This transformation combines multiple consecutive elements of the input
1560    dataset into a single element.
1561
1562    Like `tf.data.Dataset.batch`, the components of the resulting element will
1563    have an additional outer dimension, which will be `batch_size` (or
1564    `N % batch_size` for the last element if `batch_size` does not divide the
1565    number of input elements `N` evenly and `drop_remainder` is `False`). If
1566    your program depends on the batches having the same outer dimension, you
1567    should set the `drop_remainder` argument to `True` to prevent the smaller
1568    batch from being produced.
1569
1570    Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
1571    different shapes, and this transformation will pad each component to the
1572    respective shape in `padded_shapes`. The `padded_shapes` argument
1573    determines the resulting shape for each dimension of each component in an
1574    output element:
1575
1576    * If the dimension is a constant, the component will be padded out to that
1577      length in that dimension.
1578    * If the dimension is unknown, the component will be padded out to the
1579      maximum length of all elements in that dimension.
1580
1581    >>> A = (tf.data.Dataset
1582    ...      .range(1, 5, output_type=tf.int32)
1583    ...      .map(lambda x: tf.fill([x], x)))
1584    >>> # Pad to the smallest per-batch size that fits all elements.
1585    >>> B = A.padded_batch(2)
1586    >>> for element in B.as_numpy_iterator():
1587    ...   print(element)
1588    [[1 0]
1589     [2 2]]
1590    [[3 3 3 0]
1591     [4 4 4 4]]
1592    >>> # Pad to a fixed size.
1593    >>> C = A.padded_batch(2, padded_shapes=5)
1594    >>> for element in C.as_numpy_iterator():
1595    ...   print(element)
1596    [[1 0 0 0 0]
1597     [2 2 0 0 0]]
1598    [[3 3 3 0 0]
1599     [4 4 4 4 0]]
1600    >>> # Pad with a custom value.
1601    >>> D = A.padded_batch(2, padded_shapes=5, padding_values=-1)
1602    >>> for element in D.as_numpy_iterator():
1603    ...   print(element)
1604    [[ 1 -1 -1 -1 -1]
1605     [ 2  2 -1 -1 -1]]
1606    [[ 3  3  3 -1 -1]
1607     [ 4  4  4  4 -1]]
1608    >>> # Components of nested elements can be padded independently.
1609    >>> elements = [([1, 2, 3], [10]),
1610    ...             ([4, 5], [11, 12])]
1611    >>> dataset = tf.data.Dataset.from_generator(
1612    ...     lambda: iter(elements), (tf.int32, tf.int32))
1613    >>> # Pad the first component of the tuple to length 4, and the second
1614    >>> # component to the smallest size that fits.
1615    >>> dataset = dataset.padded_batch(2,
1616    ...     padded_shapes=([4], [None]),
1617    ...     padding_values=(-1, 100))
1618    >>> list(dataset.as_numpy_iterator())
1619    [(array([[ 1,  2,  3, -1], [ 4,  5, -1, -1]], dtype=int32),
1620      array([[ 10, 100], [ 11,  12]], dtype=int32))]
1621    >>> # Pad with a single value and multiple components.
1622    >>> E = tf.data.Dataset.zip((A, A)).padded_batch(2, padding_values=-1)
1623    >>> for element in E.as_numpy_iterator():
1624    ...   print(element)
1625    (array([[ 1, -1],
1626           [ 2,  2]], dtype=int32), array([[ 1, -1],
1627           [ 2,  2]], dtype=int32))
1628    (array([[ 3,  3,  3, -1],
1629           [ 4,  4,  4,  4]], dtype=int32), array([[ 3,  3,  3, -1],
1630           [ 4,  4,  4,  4]], dtype=int32))
1631
1632    See also `tf.data.experimental.dense_to_sparse_batch`, which combines
1633    elements that may have different shapes into a `tf.sparse.SparseTensor`.
1634
1635    Args:
1636      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1637        consecutive elements of this dataset to combine in a single batch.
1638      padded_shapes: (Optional.) A nested structure of `tf.TensorShape` or
1639        `tf.int64` vector tensor-like objects representing the shape to which
1640        the respective component of each input element should be padded prior
1641        to batching. Any unknown dimensions will be padded to the maximum size
1642        of that dimension in each batch. If unset, all dimensions of all
1643        components are padded to the maximum size in the batch. `padded_shapes`
1644        must be set if any component has an unknown rank.
1645      padding_values: (Optional.) A nested structure of scalar-shaped
1646        `tf.Tensor`, representing the padding values to use for the respective
1647        components. None represents that the nested structure should be padded
1648        with default values.  Defaults are `0` for numeric types and the empty
1649        string for string types. The `padding_values` should have the
1650        same structure as the input dataset. If `padding_values` is a single
1651        element and the input dataset has multiple components, then the same
1652        `padding_values` will be used to pad every component of the dataset.
1653        If `padding_values` is a scalar, then its value will be broadcasted
1654        to match the shape of each component.
1655      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1656        whether the last batch should be dropped in the case it has fewer than
1657        `batch_size` elements; the default behavior is not to drop the smaller
1658        batch.
1659
1660    Returns:
1661      Dataset: A `Dataset`.
1662
1663    Raises:
1664      ValueError: If a component has an unknown rank, and  the `padded_shapes`
1665        argument is not set.
1666    """
1667    if padded_shapes is None:
1668      padded_shapes = get_legacy_output_shapes(self)
1669      # A `tf.TensorShape` is only false if its *rank* is unknown:
1670      # bool(tf.TensorShape(None)) is False
1671      if not all(nest.flatten(padded_shapes)):
1672        raise ValueError("You must set the `padded_shapes` argument to "
1673                         "`Dataset.padded_batch` if any component of its "
1674                         "input has an unknown rank")
1675    return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values,
1676                              drop_remainder)
1677
1678  def map(self, map_func, num_parallel_calls=None, deterministic=None):
1679    """Maps `map_func` across the elements of this dataset.
1680
1681    This transformation applies `map_func` to each element of this dataset, and
1682    returns a new dataset containing the transformed elements, in the same
1683    order as they appeared in the input. `map_func` can be used to change both
1684    the values and the structure of a dataset's elements. For example, adding 1
1685    to each element, or projecting a subset of element components.
1686
1687    >>> dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
1688    >>> dataset = dataset.map(lambda x: x + 1)
1689    >>> list(dataset.as_numpy_iterator())
1690    [2, 3, 4, 5, 6]
1691
1692    The input signature of `map_func` is determined by the structure of each
1693    element in this dataset.
1694
1695    >>> dataset = Dataset.range(5)
1696    >>> # `map_func` takes a single argument of type `tf.Tensor` with the same
1697    >>> # shape and dtype.
1698    >>> result = dataset.map(lambda x: x + 1)
1699
1700    >>> # Each element is a tuple containing two `tf.Tensor` objects.
1701    >>> elements = [(1, "foo"), (2, "bar"), (3, "baz")]
1702    >>> dataset = tf.data.Dataset.from_generator(
1703    ...     lambda: elements, (tf.int32, tf.string))
1704    >>> # `map_func` takes two arguments of type `tf.Tensor`. This function
1705    >>> # projects out just the first component.
1706    >>> result = dataset.map(lambda x_int, y_str: x_int)
1707    >>> list(result.as_numpy_iterator())
1708    [1, 2, 3]
1709
1710    >>> # Each element is a dictionary mapping strings to `tf.Tensor` objects.
1711    >>> elements =  ([{"a": 1, "b": "foo"},
1712    ...               {"a": 2, "b": "bar"},
1713    ...               {"a": 3, "b": "baz"}])
1714    >>> dataset = tf.data.Dataset.from_generator(
1715    ...     lambda: elements, {"a": tf.int32, "b": tf.string})
1716    >>> # `map_func` takes a single argument of type `dict` with the same keys
1717    >>> # as the elements.
1718    >>> result = dataset.map(lambda d: str(d["a"]) + d["b"])
1719
1720    The value or values returned by `map_func` determine the structure of each
1721    element in the returned dataset.
1722
1723    >>> dataset = tf.data.Dataset.range(3)
1724    >>> # `map_func` returns two `tf.Tensor` objects.
1725    >>> def g(x):
1726    ...   return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
1727    >>> result = dataset.map(g)
1728    >>> result.element_spec
1729    (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), \
1730dtype=tf.string, name=None))
1731    >>> # Python primitives, lists, and NumPy arrays are implicitly converted to
1732    >>> # `tf.Tensor`.
1733    >>> def h(x):
1734    ...   return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64)
1735    >>> result = dataset.map(h)
1736    >>> result.element_spec
1737    (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), \
1738dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, \
1739name=None))
1740    >>> # `map_func` can return nested structures.
1741    >>> def i(x):
1742    ...   return (37.0, [42, 16]), "foo"
1743    >>> result = dataset.map(i)
1744    >>> result.element_spec
1745    ((TensorSpec(shape=(), dtype=tf.float32, name=None),
1746      TensorSpec(shape=(2,), dtype=tf.int32, name=None)),
1747     TensorSpec(shape=(), dtype=tf.string, name=None))
1748
1749    `map_func` can accept as arguments and return any type of dataset element.
1750
1751    Note that irrespective of the context in which `map_func` is defined (eager
1752    vs. graph), tf.data traces the function and executes it as a graph. To use
1753    Python code inside of the function you have a few options:
1754
1755    1) Rely on AutoGraph to convert Python code into an equivalent graph
1756    computation. The downside of this approach is that AutoGraph can convert
1757    some but not all Python code.
1758
1759    2) Use `tf.py_function`, which allows you to write arbitrary Python code but
1760    will generally result in worse performance than 1). For example:
1761
1762    >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
1763    >>> # transform a string tensor to upper case string using a Python function
1764    >>> def upper_case_fn(t: tf.Tensor):
1765    ...   return t.numpy().decode('utf-8').upper()
1766    >>> d = d.map(lambda x: tf.py_function(func=upper_case_fn,
1767    ...           inp=[x], Tout=tf.string))
1768    >>> list(d.as_numpy_iterator())
1769    [b'HELLO', b'WORLD']
1770
1771    3) Use `tf.numpy_function`, which also allows you to write arbitrary
1772    Python code. Note that `tf.py_function` accepts `tf.Tensor` whereas
1773    `tf.numpy_function` accepts numpy arrays and returns only numpy arrays.
1774    For example:
1775
1776    >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
1777    >>> def upper_case_fn(t: np.ndarray):
1778    ...   return t.decode('utf-8').upper()
1779    >>> d = d.map(lambda x: tf.numpy_function(func=upper_case_fn,
1780    ...           inp=[x], Tout=tf.string))
1781    >>> list(d.as_numpy_iterator())
1782    [b'HELLO', b'WORLD']
1783
1784    Note that the use of `tf.numpy_function` and `tf.py_function`
1785    in general precludes the possibility of executing user-defined
1786    transformations in parallel (because of Python GIL).
1787
1788    Performance can often be improved by setting `num_parallel_calls` so that
1789    `map` will use multiple threads to process elements. If deterministic order
1790    isn't required, it can also improve performance to set
1791    `deterministic=False`.
1792
1793    >>> dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
1794    >>> dataset = dataset.map(lambda x: x + 1,
1795    ...     num_parallel_calls=tf.data.AUTOTUNE,
1796    ...     deterministic=False)
1797
1798    The order of elements yielded by this transformation is deterministic if
1799    `deterministic=True`. If `map_func` contains stateful operations and
1800    `num_parallel_calls > 1`, the order in which that state is accessed is
1801    undefined, so the values of output elements may not be deterministic
1802    regardless of the `deterministic` flag value.
1803
1804    Args:
1805      map_func: A function mapping a dataset element to another dataset element.
1806      num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
1807        representing the number elements to process asynchronously in parallel.
1808        If not specified, elements will be processed sequentially. If the value
1809        `tf.data.AUTOTUNE` is used, then the number of parallel
1810        calls is set dynamically based on available CPU.
1811      deterministic: (Optional.) When `num_parallel_calls` is specified, this
1812        boolean controls the order in which the transformation produces
1813        elements. If set to `False`, the transformation is allowed to yield
1814        elements out of order to trade determinism for performance. If not
1815        specified, the `tf.data.Options.experimental_deterministic` option
1816        (`True` by default) controls the behavior.
1817
1818    Returns:
1819      Dataset: A `Dataset`.
1820    """
1821    if num_parallel_calls is None:
1822      if deterministic is not None:
1823        warnings.warn("The `deterministic` argument has no effect unless the "
1824                      "`num_parallel_calls` argument is specified.")
1825      return MapDataset(self, map_func, preserve_cardinality=True)
1826    else:
1827      return ParallelMapDataset(
1828          self,
1829          map_func,
1830          num_parallel_calls,
1831          deterministic,
1832          preserve_cardinality=True)
1833
1834  def flat_map(self, map_func):
1835    """Maps `map_func` across this dataset and flattens the result.
1836
1837    Use `flat_map` if you want to make sure that the order of your dataset
1838    stays the same. For example, to flatten a dataset of batches into a
1839    dataset of their elements:
1840
1841    >>> dataset = tf.data.Dataset.from_tensor_slices(
1842    ...                [[1, 2, 3], [4, 5, 6], [7, 8, 9]])
1843    >>> dataset = dataset.flat_map(lambda x: Dataset.from_tensor_slices(x))
1844    >>> list(dataset.as_numpy_iterator())
1845    [1, 2, 3, 4, 5, 6, 7, 8, 9]
1846
1847    `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
1848    `flat_map` produces the same output as
1849    `tf.data.Dataset.interleave(cycle_length=1)`
1850
1851    Args:
1852      map_func: A function mapping a dataset element to a dataset.
1853
1854    Returns:
1855      Dataset: A `Dataset`.
1856    """
1857    return FlatMapDataset(self, map_func)
1858
1859  def interleave(self,
1860                 map_func,
1861                 cycle_length=None,
1862                 block_length=None,
1863                 num_parallel_calls=None,
1864                 deterministic=None):
1865    """Maps `map_func` across this dataset, and interleaves the results.
1866
1867    For example, you can use `Dataset.interleave()` to process many input files
1868    concurrently:
1869
1870    >>> # Preprocess 4 files concurrently, and interleave blocks of 16 records
1871    >>> # from each file.
1872    >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
1873    ...              "/var/data/file3.txt", "/var/data/file4.txt"]
1874    >>> dataset = tf.data.Dataset.from_tensor_slices(filenames)
1875    >>> def parse_fn(filename):
1876    ...   return tf.data.Dataset.range(10)
1877    >>> dataset = dataset.interleave(lambda x:
1878    ...     tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
1879    ...     cycle_length=4, block_length=16)
1880
1881    The `cycle_length` and `block_length` arguments control the order in which
1882    elements are produced. `cycle_length` controls the number of input elements
1883    that are processed concurrently. If you set `cycle_length` to 1, this
1884    transformation will handle one input element at a time, and will produce
1885    identical results to `tf.data.Dataset.flat_map`. In general,
1886    this transformation will apply `map_func` to `cycle_length` input elements,
1887    open iterators on the returned `Dataset` objects, and cycle through them
1888    producing `block_length` consecutive elements from each iterator, and
1889    consuming the next input element each time it reaches the end of an
1890    iterator.
1891
1892    For example:
1893
1894    >>> dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
1895    >>> # NOTE: New lines indicate "block" boundaries.
1896    >>> dataset = dataset.interleave(
1897    ...     lambda x: Dataset.from_tensors(x).repeat(6),
1898    ...     cycle_length=2, block_length=4)
1899    >>> list(dataset.as_numpy_iterator())
1900    [1, 1, 1, 1,
1901     2, 2, 2, 2,
1902     1, 1,
1903     2, 2,
1904     3, 3, 3, 3,
1905     4, 4, 4, 4,
1906     3, 3,
1907     4, 4,
1908     5, 5, 5, 5,
1909     5, 5]
1910
1911    Note: The order of elements yielded by this transformation is
1912    deterministic, as long as `map_func` is a pure function and
1913    `deterministic=True`. If `map_func` contains any stateful operations, the
1914    order in which that state is accessed is undefined.
1915
1916    Performance can often be improved by setting `num_parallel_calls` so that
1917    `interleave` will use multiple threads to fetch elements. If determinism
1918    isn't required, it can also improve performance to set
1919    `deterministic=False`.
1920
1921    >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
1922    ...              "/var/data/file3.txt", "/var/data/file4.txt"]
1923    >>> dataset = tf.data.Dataset.from_tensor_slices(filenames)
1924    >>> dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x),
1925    ...     cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE,
1926    ...     deterministic=False)
1927
1928    Args:
1929      map_func: A function mapping a dataset element to a dataset.
1930      cycle_length: (Optional.) The number of input elements that will be
1931        processed concurrently. If not set, the tf.data runtime decides what it
1932        should be based on available CPU. If `num_parallel_calls` is set to
1933        `tf.data.AUTOTUNE`, the `cycle_length` argument identifies
1934        the maximum degree of parallelism.
1935      block_length: (Optional.) The number of consecutive elements to produce
1936        from each input element before cycling to another input element. If not
1937        set, defaults to 1.
1938      num_parallel_calls: (Optional.) If specified, the implementation creates a
1939        threadpool, which is used to fetch inputs from cycle elements
1940        asynchronously and in parallel. The default behavior is to fetch inputs
1941        from cycle elements synchronously with no parallelism. If the value
1942        `tf.data.AUTOTUNE` is used, then the number of parallel
1943        calls is set dynamically based on available CPU.
1944      deterministic: (Optional.) When `num_parallel_calls` is specified, this
1945        boolean controls the order in which the transformation produces
1946        elements. If set to `False`, the transformation is allowed to yield
1947        elements out of order to trade determinism for performance. If not
1948        specified, the `tf.data.Options.experimental_deterministic` option
1949        (`True` by default) controls the behavior.
1950
1951    Returns:
1952      Dataset: A `Dataset`.
1953    """
1954    if block_length is None:
1955      block_length = 1
1956
1957    if cycle_length is None:
1958      cycle_length = AUTOTUNE
1959
1960    if num_parallel_calls is None:
1961      if deterministic is not None:
1962        warnings.warn("The `deterministic` argument has no effect unless the "
1963                      "`num_parallel_calls` argument is specified.")
1964      return InterleaveDataset(self, map_func, cycle_length, block_length)
1965    else:
1966      return ParallelInterleaveDataset(
1967          self,
1968          map_func,
1969          cycle_length,
1970          block_length,
1971          num_parallel_calls,
1972          deterministic=deterministic)
1973
1974  def filter(self, predicate):
1975    """Filters this dataset according to `predicate`.
1976
1977    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1978    >>> dataset = dataset.filter(lambda x: x < 3)
1979    >>> list(dataset.as_numpy_iterator())
1980    [1, 2]
1981    >>> # `tf.math.equal(x, y)` is required for equality comparison
1982    >>> def filter_fn(x):
1983    ...   return tf.math.equal(x, 1)
1984    >>> dataset = dataset.filter(filter_fn)
1985    >>> list(dataset.as_numpy_iterator())
1986    [1]
1987
1988    Args:
1989      predicate: A function mapping a dataset element to a boolean.
1990
1991    Returns:
1992      Dataset: The `Dataset` containing the elements of this dataset for which
1993          `predicate` is `True`.
1994    """
1995    return FilterDataset(self, predicate)
1996
1997  def apply(self, transformation_func):
1998    """Applies a transformation function to this dataset.
1999
2000    `apply` enables chaining of custom `Dataset` transformations, which are
2001    represented as functions that take one `Dataset` argument and return a
2002    transformed `Dataset`.
2003
2004    >>> dataset = tf.data.Dataset.range(100)
2005    >>> def dataset_fn(ds):
2006    ...   return ds.filter(lambda x: x < 5)
2007    >>> dataset = dataset.apply(dataset_fn)
2008    >>> list(dataset.as_numpy_iterator())
2009    [0, 1, 2, 3, 4]
2010
2011    Args:
2012      transformation_func: A function that takes one `Dataset` argument and
2013        returns a `Dataset`.
2014
2015    Returns:
2016      Dataset: The `Dataset` returned by applying `transformation_func` to this
2017          dataset.
2018    """
2019    dataset = transformation_func(self)
2020    if not isinstance(dataset, DatasetV2):
2021      raise TypeError(
2022          "`transformation_func` must return a Dataset. Got {}.".format(
2023              dataset))
2024    dataset._input_datasets = [self]  # pylint: disable=protected-access
2025    return dataset
2026
2027  def window(self, size, shift=None, stride=1, drop_remainder=False):
2028    """Combines (nests of) input elements into a dataset of (nests of) windows.
2029
2030    A "window" is a finite dataset of flat elements of size `size` (or possibly
2031    fewer if there are not enough input elements to fill the window and
2032    `drop_remainder` evaluates to `False`).
2033
2034    The `shift` argument determines the number of input elements by which the
2035    window moves on each iteration.  If windows and elements are both numbered
2036    starting at 0, the first element in window `k` will be element `k * shift`
2037    of the input dataset. In particular, the first element of the first window
2038    will always be the first element of the input dataset.
2039
2040    The `stride` argument determines the stride of the input elements, and the
2041    `shift` argument determines the shift of the window.
2042
2043    For example:
2044
2045    >>> dataset = tf.data.Dataset.range(7).window(2)
2046    >>> for window in dataset:
2047    ...   print(list(window.as_numpy_iterator()))
2048    [0, 1]
2049    [2, 3]
2050    [4, 5]
2051    [6]
2052    >>> dataset = tf.data.Dataset.range(7).window(3, 2, 1, True)
2053    >>> for window in dataset:
2054    ...   print(list(window.as_numpy_iterator()))
2055    [0, 1, 2]
2056    [2, 3, 4]
2057    [4, 5, 6]
2058    >>> dataset = tf.data.Dataset.range(7).window(3, 1, 2, True)
2059    >>> for window in dataset:
2060    ...   print(list(window.as_numpy_iterator()))
2061    [0, 2, 4]
2062    [1, 3, 5]
2063    [2, 4, 6]
2064
2065    Note that when the `window` transformation is applied to a dataset of
2066    nested elements, it produces a dataset of nested windows.
2067
2068    >>> nested = ([1, 2, 3, 4], [5, 6, 7, 8])
2069    >>> dataset = tf.data.Dataset.from_tensor_slices(nested).window(2)
2070    >>> for window in dataset:
2071    ...   def to_numpy(ds):
2072    ...     return list(ds.as_numpy_iterator())
2073    ...   print(tuple(to_numpy(component) for component in window))
2074    ([1, 2], [5, 6])
2075    ([3, 4], [7, 8])
2076
2077    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3, 4]})
2078    >>> dataset = dataset.window(2)
2079    >>> for window in dataset:
2080    ...   def to_numpy(ds):
2081    ...     return list(ds.as_numpy_iterator())
2082    ...   print({'a': to_numpy(window['a'])})
2083    {'a': [1, 2]}
2084    {'a': [3, 4]}
2085
2086    Args:
2087      size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
2088        of the input dataset to combine into a window. Must be positive.
2089      shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
2090        number of input elements by which the window moves in each iteration.
2091        Defaults to `size`. Must be positive.
2092      stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
2093        stride of the input elements in the sliding window. Must be positive.
2094        The default value of 1 means "retain every input element".
2095      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
2096        whether the last windows should be dropped if their size is smaller than
2097        `size`.
2098
2099    Returns:
2100      Dataset: A `Dataset` of (nests of) windows -- a finite datasets of flat
2101        elements created from the (nests of) input elements.
2102
2103    """
2104    if shift is None:
2105      shift = size
2106    return WindowDataset(self, size, shift, stride, drop_remainder)
2107
2108  def reduce(self, initial_state, reduce_func):
2109    """Reduces the input dataset to a single element.
2110
2111    The transformation calls `reduce_func` successively on every element of
2112    the input dataset until the dataset is exhausted, aggregating information in
2113    its internal state. The `initial_state` argument is used for the initial
2114    state and the final state is returned as the result.
2115
2116    >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy()
2117    5
2118    >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy()
2119    10
2120
2121    Args:
2122      initial_state: An element representing the initial state of the
2123        transformation.
2124      reduce_func: A function that maps `(old_state, input_element)` to
2125        `new_state`. It must take two arguments and return a new element
2126        The structure of `new_state` must match the structure of
2127        `initial_state`.
2128
2129    Returns:
2130      A dataset element corresponding to the final state of the transformation.
2131
2132    """
2133
2134    with ops.name_scope("initial_state"):
2135      initial_state = structure.normalize_element(initial_state)
2136    state_structure = structure.type_spec_from_value(initial_state)
2137
2138    # Iteratively rerun the reduce function until reaching a fixed point on
2139    # `state_structure`.
2140    need_to_rerun = True
2141    while need_to_rerun:
2142
2143      wrapped_func = StructuredFunctionWrapper(
2144          reduce_func,
2145          "reduce()",
2146          input_structure=(state_structure, self.element_spec),
2147          add_to_graph=False)
2148
2149      # Extract and validate class information from the returned values.
2150      output_classes = wrapped_func.output_classes
2151      state_classes = nest.map_structure(
2152          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
2153          state_structure)
2154      for new_state_class, state_class in zip(
2155          nest.flatten(output_classes), nest.flatten(state_classes)):
2156        if not issubclass(new_state_class, state_class):
2157          raise TypeError(
2158              "The element classes for the new state must match the initial "
2159              "state. Expected %s; got %s." %
2160              (state_classes, wrapped_func.output_classes))
2161
2162      # Extract and validate type information from the returned values.
2163      output_types = wrapped_func.output_types
2164      state_types = nest.map_structure(
2165          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
2166          state_structure)
2167      for new_state_type, state_type in zip(
2168          nest.flatten(output_types), nest.flatten(state_types)):
2169        if new_state_type != state_type:
2170          raise TypeError(
2171              "The element types for the new state must match the initial "
2172              "state. Expected %s; got %s." %
2173              (state_types, wrapped_func.output_types))
2174
2175      # Extract shape information from the returned values.
2176      output_shapes = wrapped_func.output_shapes
2177      state_shapes = nest.map_structure(
2178          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
2179          state_structure)
2180      flat_state_shapes = nest.flatten(state_shapes)
2181      flat_new_state_shapes = nest.flatten(output_shapes)
2182      weakened_state_shapes = [
2183          original.most_specific_compatible_shape(new)
2184          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
2185      ]
2186
2187      need_to_rerun = False
2188      for original_shape, weakened_shape in zip(flat_state_shapes,
2189                                                weakened_state_shapes):
2190        if original_shape.ndims is not None and (
2191            weakened_shape.ndims is None or
2192            original_shape.as_list() != weakened_shape.as_list()):
2193          need_to_rerun = True
2194          break
2195
2196      if need_to_rerun:
2197        # TODO(b/110122868): Support a "most specific compatible structure"
2198        # method for combining structures, to avoid using legacy structures
2199        # here.
2200        state_structure = structure.convert_legacy_structure(
2201            state_types,
2202            nest.pack_sequence_as(state_shapes, weakened_state_shapes),
2203            state_classes)
2204
2205    reduce_func = wrapped_func.function
2206    reduce_func.add_to_graph(ops.get_default_graph())
2207
2208    dataset = self._apply_options()
2209
2210    # pylint: disable=protected-access
2211    return structure.from_compatible_tensor_list(
2212        state_structure,
2213        gen_dataset_ops.reduce_dataset(
2214            dataset._variant_tensor,
2215            structure.to_tensor_list(state_structure, initial_state),
2216            reduce_func.captured_inputs,
2217            f=reduce_func,
2218            output_shapes=structure.get_flat_tensor_shapes(state_structure),
2219            output_types=structure.get_flat_tensor_types(state_structure)))
2220
2221  def unbatch(self):
2222    """Splits elements of a dataset into multiple elements.
2223
2224    For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
2225    where `B` may vary for each input element, then for each element in the
2226    dataset, the unbatched dataset will contain `B` consecutive elements
2227    of shape `[a0, a1, ...]`.
2228
2229    >>> elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ]
2230    >>> dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64)
2231    >>> dataset = dataset.unbatch()
2232    >>> list(dataset.as_numpy_iterator())
2233    [1, 2, 3, 1, 2, 1, 2, 3, 4]
2234
2235    Note: `unbatch` requires a data copy to slice up the batched tensor into
2236    smaller, unbatched tensors. When optimizing performance, try to avoid
2237    unnecessary usage of `unbatch`.
2238
2239    Returns:
2240      A `Dataset`.
2241    """
2242    normalized_dataset = normalize_to_dense(self)
2243    return _UnbatchDataset(normalized_dataset)
2244
2245  def with_options(self, options):
2246    """Returns a new `tf.data.Dataset` with the given options set.
2247
2248    The options are "global" in the sense they apply to the entire dataset.
2249    If options are set multiple times, they are merged as long as different
2250    options do not use different non-default values.
2251
2252    >>> ds = tf.data.Dataset.range(5)
2253    >>> ds = ds.interleave(lambda x: tf.data.Dataset.range(5),
2254    ...                    cycle_length=3,
2255    ...                    num_parallel_calls=3)
2256    >>> options = tf.data.Options()
2257    >>> # This will make the interleave order non-deterministic.
2258    >>> options.experimental_deterministic = False
2259    >>> ds = ds.with_options(options)
2260
2261    Args:
2262      options: A `tf.data.Options` that identifies the options the use.
2263
2264    Returns:
2265      Dataset: A `Dataset` with the given options.
2266
2267    Raises:
2268      ValueError: when an option is set more than once to a non-default value
2269    """
2270    return _OptionsDataset(self, options)
2271
2272  def cardinality(self):
2273    """Returns the cardinality of the dataset, if known.
2274
2275    `cardinality` may return `tf.data.INFINITE_CARDINALITY` if the dataset
2276    contains an infinite number of elements or `tf.data.UNKNOWN_CARDINALITY` if
2277    the analysis fails to determine the number of elements in the dataset
2278    (e.g. when the dataset source is a file).
2279
2280    >>> dataset = tf.data.Dataset.range(42)
2281    >>> print(dataset.cardinality().numpy())
2282    42
2283    >>> dataset = dataset.repeat()
2284    >>> cardinality = dataset.cardinality()
2285    >>> print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
2286    True
2287    >>> dataset = dataset.filter(lambda x: True)
2288    >>> cardinality = dataset.cardinality()
2289    >>> print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
2290    True
2291
2292    Returns:
2293      A scalar `tf.int64` `Tensor` representing the cardinality of the dataset.
2294      If the cardinality is infinite or unknown, `cardinality` returns the
2295      named constants `tf.data.INFINITE_CARDINALITY` and
2296      `tf.data.UNKNOWN_CARDINALITY` respectively.
2297    """
2298    return gen_dataset_ops.dataset_cardinality(self._variant_tensor)
2299
2300
2301@tf_export(v1=["data.Dataset"])
2302class DatasetV1(DatasetV2):
2303  """Represents a potentially large set of elements.
2304
2305  A `Dataset` can be used to represent an input pipeline as a
2306  collection of elements and a "logical plan" of transformations that act on
2307  those elements.
2308  """
2309
2310  def __init__(self):
2311    try:
2312      variant_tensor = self._as_variant_tensor()
2313    except AttributeError as e:
2314      if "_as_variant_tensor" in str(e):
2315        raise AttributeError("Please use _variant_tensor instead of "
2316                             "_as_variant_tensor() to obtain the variant "
2317                             "associated with a dataset")
2318      raise AttributeError("{}: A likely cause of this error is that the super "
2319                           "call for this dataset is not the last line of the "
2320                           "__init__ method. The base class causes the "
2321                           "_as_variant_tensor call in its constructor and "
2322                           "if that uses attributes defined in the __init__ "
2323                           "method, those attrs need to be defined before the "
2324                           "super call.".format(e))
2325    super(DatasetV1, self).__init__(variant_tensor)
2326
2327  @abc.abstractmethod
2328  def _as_variant_tensor(self):
2329    """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
2330
2331    Returns:
2332      A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
2333    """
2334    raise NotImplementedError("Dataset._as_variant_tensor")
2335
2336  @deprecation.deprecated(
2337      None, "This is a deprecated API that should only be used in TF 1 graph "
2338      "mode and legacy TF 2 graph mode available through `tf.compat.v1`. In "
2339      "all other situations -- namely, eager mode and inside `tf.function` -- "
2340      "you can consume dataset elements using `for elem in dataset: ...` or "
2341      "by explicitly creating iterator via `iterator = iter(dataset)` and "
2342      "fetching its elements via `values = next(iterator)`. Furthermore, "
2343      "this API is not available in TF 2. During the transition from TF 1 "
2344      "to TF 2 you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)` "
2345      "to create a TF 1 graph mode style iterator for a dataset created "
2346      "through TF 2 APIs. Note that this should be a transient state of your "
2347      "code base as there are in general no guarantees about the "
2348      "interoperability of TF 1 and TF 2 code.")
2349  def make_one_shot_iterator(self):
2350    """Creates an iterator for elements of this dataset.
2351
2352    Note: The returned iterator will be initialized automatically.
2353    A "one-shot" iterator does not currently support re-initialization. For
2354    that see `make_initializable_iterator`.
2355
2356    Example:
2357
2358    ```python
2359    # Building graph ...
2360    dataset = ...
2361    next_value = dataset.make_one_shot_iterator().get_next()
2362
2363    # ... from within a session ...
2364    try:
2365      while True:
2366        value = sess.run(next_value)
2367        ...
2368    except tf.errors.OutOfRangeError:
2369        pass
2370    ```
2371
2372    Returns:
2373      An `tf.data.Iterator` for elements of this dataset.
2374    """
2375    return self._make_one_shot_iterator()
2376
2377  def _make_one_shot_iterator(self):  # pylint: disable=missing-docstring
2378    if context.executing_eagerly():
2379      with ops.colocate_with(self._variant_tensor):
2380        return iterator_ops.OwnedIterator(self)
2381
2382    _ensure_same_dataset_graph(self)
2383    # Now that we create datasets at python object creation time, the capture
2384    # by value _make_dataset() function would try to capture these variant
2385    # tensor dataset inputs, which are marked as stateful ops and would throw
2386    # an error if we try and capture them. We therefore traverse the graph
2387    # to find all these ops and allowlist them so that the capturing
2388    # logic instead of throwing an error recreates these ops which is what was
2389    # happening before.
2390    all_ds_ops = traverse.obtain_all_variant_tensor_ops(self)
2391    graph_level_seed, op_level_seed = core_random_seed.get_seed(None)
2392
2393    # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
2394    # a 0-argument function.
2395    @function.Defun(capture_by_value=True, allowlisted_stateful_ops=all_ds_ops)
2396    def _make_dataset():
2397      """Factory function for a dataset."""
2398      # NOTE(mrry): `Defun` does not capture the graph-level seed from the
2399      # enclosing graph, so if a graph-level seed is present we set the local
2400      # graph seed based on a combination of the graph- and op-level seeds.
2401      if graph_level_seed is not None:
2402        assert op_level_seed is not None
2403        core_random_seed.set_random_seed(
2404            (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))
2405
2406      dataset = self._apply_options()
2407      return dataset._variant_tensor  # pylint: disable=protected-access
2408
2409    try:
2410      _make_dataset.add_to_graph(ops.get_default_graph())
2411    except ValueError as err:
2412      if "Cannot capture a stateful node" in str(err):
2413        raise ValueError(
2414            "Failed to create a one-shot iterator for a dataset. "
2415            "`Dataset.make_one_shot_iterator()` does not support datasets that "
2416            "capture stateful objects, such as a `Variable` or `LookupTable`. "
2417            "In these cases, use `Dataset.make_initializable_iterator()`. "
2418            "(Original error: %s)" % err)
2419      else:
2420        six.reraise(ValueError, err)
2421
2422    with ops.colocate_with(self._variant_tensor):
2423      # pylint: disable=protected-access
2424      return iterator_ops.Iterator(
2425          gen_dataset_ops.one_shot_iterator(
2426              dataset_factory=_make_dataset, **self._flat_structure), None,
2427          get_legacy_output_types(self), get_legacy_output_shapes(self),
2428          get_legacy_output_classes(self))
2429
2430  @deprecation.deprecated(
2431      None, "This is a deprecated API that should only be used in TF 1 graph "
2432      "mode and legacy TF 2 graph mode available through `tf.compat.v1`. "
2433      "In all other situations -- namely, eager mode and inside `tf.function` "
2434      "-- you can consume dataset elements using `for elem in dataset: ...` "
2435      "or by explicitly creating iterator via `iterator = iter(dataset)` "
2436      "and fetching its elements via `values = next(iterator)`. "
2437      "Furthermore, this API is not available in TF 2. During the transition "
2438      "from TF 1 to TF 2 you can use "
2439      "`tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF "
2440      "1 graph mode style iterator for a dataset created through TF 2 APIs. "
2441      "Note that this should be a transient state of your code base as there "
2442      "are in general no guarantees about the interoperability of TF 1 and TF "
2443      "2 code.")
2444  def make_initializable_iterator(self, shared_name=None):
2445    """Creates an iterator for elements of this dataset.
2446
2447    Note: The returned iterator will be in an uninitialized state,
2448    and you must run the `iterator.initializer` operation before using it:
2449
2450    ```python
2451    # Building graph ...
2452    dataset = ...
2453    iterator = dataset.make_initializable_iterator()
2454    next_value = iterator.get_next()  # This is a Tensor.
2455
2456    # ... from within a session ...
2457    sess.run(iterator.initializer)
2458    try:
2459      while True:
2460        value = sess.run(next_value)
2461        ...
2462    except tf.errors.OutOfRangeError:
2463        pass
2464    ```
2465
2466    Args:
2467      shared_name: (Optional.) If non-empty, the returned iterator will be
2468        shared under the given name across multiple sessions that share the same
2469        devices (e.g. when using a remote server).
2470
2471    Returns:
2472      A `tf.data.Iterator` for elements of this dataset.
2473
2474    Raises:
2475      RuntimeError: If eager execution is enabled.
2476    """
2477    return self._make_initializable_iterator(shared_name)
2478
2479  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=missing-docstring
2480    if context.executing_eagerly():
2481      raise RuntimeError(
2482          "dataset.make_initializable_iterator is not supported when eager "
2483          "execution is enabled. Use `for element in dataset` instead.")
2484    _ensure_same_dataset_graph(self)
2485    dataset = self._apply_options()
2486    if shared_name is None:
2487      shared_name = ""
2488
2489    with ops.colocate_with(self._variant_tensor):
2490      iterator_resource = gen_dataset_ops.iterator_v2(
2491          container="", shared_name=shared_name, **self._flat_structure)
2492
2493      initializer = gen_dataset_ops.make_iterator(
2494          dataset._variant_tensor,  # pylint: disable=protected-access
2495          iterator_resource)
2496
2497      # pylint: disable=protected-access
2498      return iterator_ops.Iterator(iterator_resource, initializer,
2499                                   get_legacy_output_types(dataset),
2500                                   get_legacy_output_shapes(dataset),
2501                                   get_legacy_output_classes(dataset))
2502
2503  @property
2504  @deprecation.deprecated(
2505      None, "Use `tf.compat.v1.data.get_output_classes(dataset)`.")
2506  def output_classes(self):
2507    """Returns the class of each component of an element of this dataset.
2508
2509    Returns:
2510      A nested structure of Python `type` objects corresponding to each
2511      component of an element of this dataset.
2512    """
2513    return nest.map_structure(
2514        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
2515        self.element_spec)
2516
2517  @property
2518  @deprecation.deprecated(
2519      None, "Use `tf.compat.v1.data.get_output_shapes(dataset)`.")
2520  def output_shapes(self):
2521    """Returns the shape of each component of an element of this dataset.
2522
2523    Returns:
2524      A nested structure of `tf.TensorShape` objects corresponding to each
2525      component of an element of this dataset.
2526    """
2527    return nest.map_structure(
2528        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
2529        self.element_spec)
2530
2531  @property
2532  @deprecation.deprecated(
2533      None, "Use `tf.compat.v1.data.get_output_types(dataset)`.")
2534  def output_types(self):
2535    """Returns the type of each component of an element of this dataset.
2536
2537    Returns:
2538      A nested structure of `tf.DType` objects corresponding to each component
2539      of an element of this dataset.
2540    """
2541    return nest.map_structure(
2542        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
2543        self.element_spec)
2544
2545  @property
2546  def element_spec(self):
2547    # TODO(b/110122868): Remove this override once all `Dataset` instances
2548    # implement `element_structure`.
2549    return structure.convert_legacy_structure(
2550        self.output_types, self.output_shapes, self.output_classes)
2551
2552  @staticmethod
2553  @functools.wraps(DatasetV2.from_tensors)
2554  def from_tensors(tensors):
2555    return DatasetV1Adapter(DatasetV2.from_tensors(tensors))
2556
2557  @staticmethod
2558  @functools.wraps(DatasetV2.from_tensor_slices)
2559  def from_tensor_slices(tensors):
2560    return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors))
2561
2562  @staticmethod
2563  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
2564  def from_sparse_tensor_slices(sparse_tensor):
2565    """Splits each rank-N `tf.sparse.SparseTensor` in this dataset row-wise.
2566
2567    Args:
2568      sparse_tensor: A `tf.sparse.SparseTensor`.
2569
2570    Returns:
2571      Dataset: A `Dataset` of rank-(N-1) sparse tensors.
2572    """
2573    return DatasetV1Adapter(SparseTensorSliceDataset(sparse_tensor))
2574
2575  @staticmethod
2576  @functools.wraps(DatasetV2.from_generator)
2577  def from_generator(generator,
2578                     output_types=None,
2579                     output_shapes=None,
2580                     args=None,
2581                     output_signature=None):
2582    return DatasetV1Adapter(
2583        DatasetV2.from_generator(generator, output_types, output_shapes, args,
2584                                 output_signature))
2585
2586  @staticmethod
2587  @functools.wraps(DatasetV2.range)
2588  def range(*args, **kwargs):
2589    return DatasetV1Adapter(DatasetV2.range(*args, **kwargs))
2590
2591  @staticmethod
2592  @functools.wraps(DatasetV2.zip)
2593  def zip(datasets):
2594    return DatasetV1Adapter(DatasetV2.zip(datasets))
2595
2596  @functools.wraps(DatasetV2.concatenate)
2597  def concatenate(self, dataset):
2598    return DatasetV1Adapter(super(DatasetV1, self).concatenate(dataset))
2599
2600  @functools.wraps(DatasetV2.prefetch)
2601  def prefetch(self, buffer_size):
2602    return DatasetV1Adapter(super(DatasetV1, self).prefetch(buffer_size))
2603
2604  @staticmethod
2605  @functools.wraps(DatasetV2.list_files)
2606  def list_files(file_pattern, shuffle=None, seed=None):
2607    return DatasetV1Adapter(DatasetV2.list_files(file_pattern, shuffle, seed))
2608
2609  @functools.wraps(DatasetV2.repeat)
2610  def repeat(self, count=None):
2611    return DatasetV1Adapter(super(DatasetV1, self).repeat(count))
2612
2613  @functools.wraps(DatasetV2.shuffle)
2614  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
2615    return DatasetV1Adapter(super(DatasetV1, self).shuffle(
2616        buffer_size, seed, reshuffle_each_iteration))
2617
2618  @functools.wraps(DatasetV2.cache)
2619  def cache(self, filename=""):
2620    return DatasetV1Adapter(super(DatasetV1, self).cache(filename))
2621
2622  @functools.wraps(DatasetV2.take)
2623  def take(self, count):
2624    return DatasetV1Adapter(super(DatasetV1, self).take(count))
2625
2626  @functools.wraps(DatasetV2.skip)
2627  def skip(self, count):
2628    return DatasetV1Adapter(super(DatasetV1, self).skip(count))
2629
2630  @functools.wraps(DatasetV2.shard)
2631  def shard(self, num_shards, index):
2632    return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index))
2633
2634  @functools.wraps(DatasetV2.batch)
2635  def batch(self, batch_size, drop_remainder=False, num_parallel_calls=None):
2636    return DatasetV1Adapter(
2637        super(DatasetV1, self).batch(batch_size, drop_remainder,
2638                                     num_parallel_calls))
2639
2640  @functools.wraps(DatasetV2.padded_batch)
2641  def padded_batch(self,
2642                   batch_size,
2643                   padded_shapes=None,
2644                   padding_values=None,
2645                   drop_remainder=False):
2646    return DatasetV1Adapter(
2647        super(DatasetV1, self).padded_batch(batch_size, padded_shapes,
2648                                            padding_values, drop_remainder))
2649
2650  @functools.wraps(DatasetV2.map)
2651  def map(self, map_func, num_parallel_calls=None, deterministic=None):
2652    if num_parallel_calls is None:
2653      return DatasetV1Adapter(
2654          MapDataset(self, map_func, preserve_cardinality=False))
2655    else:
2656      return DatasetV1Adapter(
2657          ParallelMapDataset(
2658              self,
2659              map_func,
2660              num_parallel_calls,
2661              deterministic,
2662              preserve_cardinality=False))
2663
2664  @deprecation.deprecated(None, "Use `tf.data.Dataset.map()")
2665  def map_with_legacy_function(self,
2666                               map_func,
2667                               num_parallel_calls=None,
2668                               deterministic=None):
2669    """Maps `map_func` across the elements of this dataset.
2670
2671    Note: This is an escape hatch for existing uses of `map` that do not work
2672    with V2 functions. New uses are strongly discouraged and existing uses
2673    should migrate to `map` as this method will be removed in V2.
2674
2675    Args:
2676      map_func: A function mapping a nested structure of tensors (having shapes
2677        and types defined by `self.output_shapes` and `self.output_types`) to
2678        another nested structure of tensors.
2679      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
2680        representing the number elements to process asynchronously in parallel.
2681        If not specified, elements will be processed sequentially. If the value
2682        `tf.data.AUTOTUNE` is used, then the number of parallel
2683        calls is set dynamically based on available CPU.
2684      deterministic: (Optional.) When `num_parallel_calls` is specified, this
2685        boolean controls the order in which the transformation produces
2686        elements. If set to `False`, the transformation is allowed to yield
2687        elements out of order to trade determinism for performance. If not
2688        specified, the `tf.data.Options.experimental_deterministic` option
2689        (`True` by default) controls the behavior.
2690
2691    Returns:
2692      Dataset: A `Dataset`.
2693    """
2694    if num_parallel_calls is None:
2695      if deterministic is not None:
2696        warnings.warn("The `deterministic` argument has no effect unless the "
2697                      "`num_parallel_calls` argument is specified.")
2698      return DatasetV1Adapter(
2699          MapDataset(
2700              self,
2701              map_func,
2702              preserve_cardinality=False,
2703              use_legacy_function=True))
2704    else:
2705      return DatasetV1Adapter(
2706          ParallelMapDataset(
2707              self,
2708              map_func,
2709              num_parallel_calls,
2710              deterministic,
2711              preserve_cardinality=False,
2712              use_legacy_function=True))
2713
2714  @functools.wraps(DatasetV2.flat_map)
2715  def flat_map(self, map_func):
2716    return DatasetV1Adapter(super(DatasetV1, self).flat_map(map_func))
2717
2718  @functools.wraps(DatasetV2.interleave)
2719  def interleave(self,
2720                 map_func,
2721                 cycle_length=None,
2722                 block_length=None,
2723                 num_parallel_calls=None,
2724                 deterministic=None):
2725    return DatasetV1Adapter(
2726        super(DatasetV1, self).interleave(map_func, cycle_length, block_length,
2727                                          num_parallel_calls, deterministic))
2728
2729  @functools.wraps(DatasetV2.filter)
2730  def filter(self, predicate):
2731    return DatasetV1Adapter(super(DatasetV1, self).filter(predicate))
2732
2733  @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()")
2734  def filter_with_legacy_function(self, predicate):
2735    """Filters this dataset according to `predicate`.
2736
2737    Note: This is an escape hatch for existing uses of `filter` that do not work
2738    with V2 functions. New uses are strongly discouraged and existing uses
2739    should migrate to `filter` as this method will be removed in V2.
2740
2741    Args:
2742      predicate: A function mapping a nested structure of tensors (having shapes
2743        and types defined by `self.output_shapes` and `self.output_types`) to a
2744        scalar `tf.bool` tensor.
2745
2746    Returns:
2747      Dataset: The `Dataset` containing the elements of this dataset for which
2748          `predicate` is `True`.
2749    """
2750    return FilterDataset(self, predicate, use_legacy_function=True)
2751
2752  @functools.wraps(DatasetV2.apply)
2753  def apply(self, transformation_func):
2754    return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func))
2755
2756  @functools.wraps(DatasetV2.window)
2757  def window(self, size, shift=None, stride=1, drop_remainder=False):
2758    return DatasetV1Adapter(super(DatasetV1, self).window(
2759        size, shift, stride, drop_remainder))
2760
2761  @functools.wraps(DatasetV2.unbatch)
2762  def unbatch(self):
2763    return DatasetV1Adapter(super(DatasetV1, self).unbatch())
2764
2765  @functools.wraps(DatasetV2.with_options)
2766  def with_options(self, options):
2767    return DatasetV1Adapter(super(DatasetV1, self).with_options(options))
2768
2769
2770if tf2.enabled():
2771  Dataset = DatasetV2
2772else:
2773  Dataset = DatasetV1
2774
2775
2776class DatasetV1Adapter(DatasetV1):
2777  """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API."""
2778
2779  def __init__(self, dataset):
2780    self._dataset = dataset
2781    super(DatasetV1Adapter, self).__init__()
2782
2783  def _as_variant_tensor(self):
2784    return self._dataset._variant_tensor  # pylint: disable=protected-access
2785
2786  def _has_captured_ref(self):
2787    return self._dataset._has_captured_ref()  # pylint: disable=protected-access
2788
2789  def _inputs(self):
2790    return self._dataset._inputs()  # pylint: disable=protected-access
2791
2792  def _functions(self):
2793    return self._dataset._functions()  # pylint: disable=protected-access
2794
2795  def options(self):
2796    return self._dataset.options()
2797
2798  @property
2799  def element_spec(self):
2800    return self._dataset.element_spec  # pylint: disable=protected-access
2801
2802  def __iter__(self):
2803    return iter(self._dataset)
2804
2805
2806def _ensure_same_dataset_graph(dataset):
2807  """Walks the dataset graph to ensure all datasets come from the same graph."""
2808  # pylint: disable=protected-access
2809  current_graph = ops.get_default_graph()
2810  bfs_q = Queue.Queue()
2811  bfs_q.put(dataset)
2812  visited = []
2813  while not bfs_q.empty():
2814    ds = bfs_q.get()
2815    visited.append(ds)
2816    ds_graph = ds._graph
2817    if current_graph != ds_graph:
2818      raise ValueError(
2819          "The graph (" + str(current_graph) + ") of the iterator is different "
2820          "from the graph (" + str(ds_graph) + ") the dataset: " +
2821          str(ds._variant_tensor) + " was  created in. If you are using the "
2822          "Estimator API, make sure that no part of the dataset returned by "
2823          "the `input_fn` function is defined outside the `input_fn` function. "
2824          "Please ensure that all datasets in the pipeline are created in the "
2825          "same graph as the iterator.")
2826    for input_ds in ds._inputs():
2827      if input_ds not in visited:
2828        bfs_q.put(input_ds)
2829
2830
2831@tf_export(v1=["data.make_one_shot_iterator"])
2832def make_one_shot_iterator(dataset):
2833  """Creates an iterator for elements of `dataset`.
2834
2835  Note: The returned iterator will be initialized automatically.
2836  A "one-shot" iterator does not support re-initialization.
2837
2838  Args:
2839    dataset: A `tf.data.Dataset`.
2840
2841  Returns:
2842    A `tf.data.Iterator` for elements of `dataset`.
2843  """
2844  try:
2845    # Call the defined `_make_one_shot_iterator()` if there is one, because some
2846    # datasets (e.g. for prefetching) override its behavior.
2847    return dataset._make_one_shot_iterator()  # pylint: disable=protected-access
2848  except AttributeError:
2849    return DatasetV1Adapter(dataset)._make_one_shot_iterator()  # pylint: disable=protected-access
2850
2851
2852@tf_export(v1=["data.make_initializable_iterator"])
2853def make_initializable_iterator(dataset, shared_name=None):
2854  """Creates an iterator for elements of `dataset`.
2855
2856  Note: The returned iterator will be in an uninitialized state,
2857  and you must run the `iterator.initializer` operation before using it:
2858
2859  ```python
2860  dataset = ...
2861  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
2862  # ...
2863  sess.run(iterator.initializer)
2864  ```
2865
2866  Args:
2867    dataset: A `tf.data.Dataset`.
2868    shared_name: (Optional.) If non-empty, the returned iterator will be shared
2869      under the given name across multiple sessions that share the same devices
2870      (e.g. when using a remote server).
2871
2872  Returns:
2873    A `tf.data.Iterator` for elements of `dataset`.
2874
2875  Raises:
2876    RuntimeError: If eager execution is enabled.
2877  """
2878  try:
2879    # Call the defined `_make_initializable_iterator()` if there is one, because
2880    # some datasets (e.g. for prefetching) override its behavior.
2881    return dataset._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
2882  except AttributeError:
2883    return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
2884
2885
2886@tf_export("data.experimental.get_structure")
2887def get_structure(dataset_or_iterator):
2888  """Returns the type signature for elements of the input dataset / iterator.
2889
2890  Args:
2891    dataset_or_iterator: A `tf.data.Dataset` or an `tf.data.Iterator`.
2892
2893  Returns:
2894    A nested structure of `tf.TypeSpec` objects matching the structure of an
2895    element of `dataset_or_iterator` and specifying the type of individual
2896    components.
2897
2898  Raises:
2899    TypeError: If input is not a `tf.data.Dataset` or an `tf.data.Iterator`
2900      object.
2901  """
2902  try:
2903    return dataset_or_iterator.element_spec  # pylint: disable=protected-access
2904  except AttributeError:
2905    raise TypeError("`dataset_or_iterator` must be a `tf.data.Dataset` or "
2906                    "tf.data.Iterator object, but got %s." %
2907                    type(dataset_or_iterator))
2908
2909
2910@tf_export(v1=["data.get_output_classes"])
2911def get_legacy_output_classes(dataset_or_iterator):
2912  """Returns the output classes for elements of the input dataset / iterator.
2913
2914  Args:
2915    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
2916
2917  Returns:
2918    A nested structure of Python `type` objects matching the structure of the
2919    dataset / iterator elements and specifying the class of the individual
2920    components.
2921  """
2922  return nest.map_structure(
2923      lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
2924      get_structure(dataset_or_iterator))
2925
2926
2927@tf_export(v1=["data.get_output_shapes"])
2928def get_legacy_output_shapes(dataset_or_iterator):
2929  """Returns the output shapes for elements of the input dataset / iterator.
2930
2931  Args:
2932    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
2933
2934  Returns:
2935    A nested structure of `tf.TensorShape` objects matching the structure of
2936    the dataset / iterator elements and specifying the shape of the individual
2937    components.
2938  """
2939  return nest.map_structure(
2940      lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
2941      get_structure(dataset_or_iterator))
2942
2943
2944@tf_export(v1=["data.get_output_types"])
2945def get_legacy_output_types(dataset_or_iterator):
2946  """Returns the output shapes for elements of the input dataset / iterator.
2947
2948  Args:
2949    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
2950
2951  Returns:
2952    A nested structure of `tf.DType` objects matching the structure of
2953    dataset / iterator elements and specifying the shape of the individual
2954    components.
2955  """
2956  return nest.map_structure(
2957      lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
2958      get_structure(dataset_or_iterator))
2959
2960
2961@tf_export("data.Options")
2962class Options(options_lib.OptionsBase):
2963  """Represents options for `tf.data.Dataset`.
2964
2965  A `tf.data.Options` object can be, for instance, used to control which static
2966  optimizations to apply to the input pipeline graph or whether to use
2967  performance modeling to dynamically tune the parallelism of operations such as
2968  `tf.data.Dataset.map` or `tf.data.Dataset.interleave`.
2969
2970  The options are set for the entire dataset and are carried over to datasets
2971  created through tf.data transformations.
2972
2973  The options can be set either by mutating the object returned by
2974  `tf.data.Dataset.options()` or by constructing an `Options` object and using
2975  the `tf.data.Dataset.with_options(options)` transformation, which returns a
2976  dataset with the options set.
2977
2978  >>> dataset = tf.data.Dataset.range(42)
2979  >>> dataset.options().experimental_deterministic = False
2980  >>> print(dataset.options().experimental_deterministic)
2981  False
2982
2983  >>> dataset = tf.data.Dataset.range(42)
2984  >>> options = tf.data.Options()
2985  >>> options.experimental_deterministic = False
2986  >>> dataset = dataset.with_options(options)
2987  >>> print(dataset.options().experimental_deterministic)
2988  False
2989
2990  Note: A known limitation of the `tf.data.Options` implementation is that the
2991  options are not preserved across tf.function boundaries. In particular, to
2992  set options for a dataset that is iterated within a tf.function, the options
2993  need to be set within the same tf.function.
2994  """
2995
2996  experimental_deterministic = options_lib.create_option(
2997      name="experimental_deterministic",
2998      ty=bool,
2999      docstring=
3000      "Whether the outputs need to be produced in deterministic order. If None,"
3001      " defaults to True.")
3002
3003  experimental_distribute = options_lib.create_option(
3004      name="experimental_distribute",
3005      ty=distribute_options.DistributeOptions,
3006      docstring=
3007      "The distribution strategy options associated with the dataset. See "
3008      "`tf.data.experimental.DistributeOptions` for more details.",
3009      default_factory=distribute_options.DistributeOptions)
3010
3011  experimental_optimization = options_lib.create_option(
3012      name="experimental_optimization",
3013      ty=optimization_options.OptimizationOptions,
3014      docstring=
3015      "The optimization options associated with the dataset. See "
3016      "`tf.data.experimental.OptimizationOptions` for more details.",
3017      default_factory=optimization_options.OptimizationOptions)
3018
3019  experimental_slack = options_lib.create_option(
3020      name="experimental_slack",
3021      ty=bool,
3022      docstring="Whether to introduce 'slack' in the last `prefetch` of the "
3023      "input pipeline, if it exists. This may reduce CPU contention with "
3024      "accelerator host-side activity at the start of a step. The slack "
3025      "frequency is determined by the number of devices attached to this "
3026      "input pipeline. If None, defaults to False.")
3027
3028  experimental_stats = options_lib.create_option(
3029      name="experimental_stats",
3030      ty=stats_options.StatsOptions,
3031      docstring=
3032      "The statistics options associated with the dataset. See "
3033      "`tf.data.experimental.StatsOptions` for more details.",
3034      default_factory=stats_options.StatsOptions)
3035
3036  experimental_threading = options_lib.create_option(
3037      name="experimental_threading",
3038      ty=threading_options.ThreadingOptions,
3039      docstring=
3040      "The threading options associated with the dataset. See "
3041      "`tf.data.experimental.ThreadingOptions` for more details.",
3042      default_factory=threading_options.ThreadingOptions)
3043
3044  experimental_external_state_policy = options_lib.create_option(
3045      name="experimental_external_state_policy",
3046      ty=distribute_options.ExternalStatePolicy,
3047      docstring="This option can be used to override the default policy for "
3048      "how to handle external state when serializing a dataset or "
3049      "checkpointing its iterator. There are three settings available - "
3050      "IGNORE: External state is ignored without a warning; WARN: External "
3051      "state is ignored and a warning is logged; FAIL: External state results "
3052      "in an error.")
3053
3054  def _to_proto(self):
3055    pb = dataset_options_pb2.Options()
3056    if self.experimental_deterministic is not None:
3057      pb.deterministic = self.experimental_deterministic
3058    pb.distribute_options.CopyFrom(self.experimental_distribute._to_proto())  # pylint: disable=protected-access
3059    if self.experimental_external_state_policy is not None:
3060      pb.external_state_policy = (
3061          distribute_options.ExternalStatePolicy._to_proto(  # pylint: disable=protected-access
3062              self.experimental_external_state_policy))
3063    pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto())  # pylint: disable=protected-access
3064    if self.experimental_slack is not None:
3065      pb.slack = self.experimental_slack
3066    pb.threading_options.CopyFrom(self.experimental_threading._to_proto())  # pylint: disable=protected-access
3067    return pb
3068
3069  def _from_proto(self, pb):
3070    if pb.WhichOneof("optional_deterministic") is not None:
3071      self.experimental_deterministic = pb.deterministic
3072    self.experimental_distribute._from_proto(pb.distribute_options)  # pylint: disable=protected-access
3073    if pb.WhichOneof("optional_external_state_policy") is not None:
3074      self.experimental_external_state_policy = (
3075          distribute_options.ExternalStatePolicy._from_proto(  # pylint: disable=protected-access
3076              pb.external_state_policy))
3077    self.experimental_optimization._from_proto(pb.optimization_options)  # pylint: disable=protected-access
3078    if pb.WhichOneof("optional_slack") is not None:
3079      self.experimental_slack = pb.slack
3080    self.experimental_threading._from_proto(pb.threading_options)  # pylint: disable=protected-access
3081
3082  def _graph_rewrites(self):
3083    """Produces lists of enabled, disabled, default static graph rewrites.
3084
3085    Returns:
3086      result: a namedtuple with three attributes. `result.enabled` is the list
3087        of user enabled graph rewrites. `result.disabled` is the list of user
3088        disabled graph rewrites. `result.default` is the list of graph
3089        rewrites that are enabled by default (the user has not explicitly
3090        enabled or disabled them).
3091    """
3092    if self.experimental_optimization is not None:
3093      result = self.experimental_optimization._graph_rewrites()  # pylint: disable=protected-access
3094    else:
3095      # Apply default options
3096      result = optimization_options.OptimizationOptions()._graph_rewrites()  # pylint: disable=protected-access
3097
3098    if self.experimental_deterministic is False:  # pylint: disable=g-bool-id-comparison
3099      result.enabled.append("make_sloppy")
3100    elif self.experimental_deterministic is True:  # pylint: disable=g-bool-id-comparison
3101      result.disabled.append("make_sloppy")
3102    if self.experimental_stats:
3103      if  self.experimental_stats.latency_all_edges is True:  # pylint: disable=g-bool-id-comparison
3104        result.enabled.append("latency_all_edges")
3105      elif self.experimental_stats.latency_all_edges is False:  # pylint: disable=g-bool-id-comparison
3106        result.disabled.append("latency_all_edges")
3107    if self.experimental_slack is True:  # pylint: disable=g-bool-id-comparison
3108      result.enabled.append("slack")
3109    elif self.experimental_slack is False:  # pylint: disable=g-bool-id-comparison
3110      result.disabled.append("slack")
3111
3112    graph_rewrites = options_lib.graph_rewrites()
3113    return graph_rewrites(enabled=list(set(result.enabled)),
3114                          disabled=list(set(result.disabled)),
3115                          default=list(set(result.default)))
3116
3117  def _graph_rewrite_configs(self, autotune):
3118    """Produces the list of configurations for enabled graph optimizations."""
3119    result = []
3120    if self.experimental_optimization:
3121      result.extend(
3122          self.experimental_optimization._graph_rewrite_configs(autotune))  # pylint: disable=protected-access
3123
3124    if self.experimental_slack:
3125      num_devices = self.experimental_distribute.num_devices
3126      if num_devices is None:
3127        num_devices = 1
3128      result.append("slack:slack_period:%d" % num_devices)
3129    return result
3130
3131  def _autotune_settings(self):
3132    if self.experimental_optimization is not None:
3133      return self.experimental_optimization._autotune_settings()  # pylint: disable=protected-access
3134
3135    # Return default autotune options
3136    return optimization_options.OptimizationOptions()._autotune_settings()  # pylint: disable=protected-access
3137
3138  def merge(self, options):
3139    """Merges itself with the given `tf.data.Options`.
3140
3141    If this object and the `options` to merge set an option differently, a
3142    warning is generated and this object's value is updated with the `options`
3143    object's value.
3144
3145    Args:
3146      options: a `tf.data.Options` to merge with
3147
3148    Returns:
3149      New `tf.data.Options` object which is the result of merging self with
3150      the input `tf.data.Options`.
3151    """
3152    return options_lib.merge_options(self, options)
3153
3154
3155class DatasetSource(DatasetV2):
3156  """Abstract class representing a dataset with no inputs."""
3157
3158  def _inputs(self):
3159    return []
3160
3161
3162class UnaryDataset(DatasetV2):
3163  """Abstract class representing a dataset with one input."""
3164
3165  def __init__(self, input_dataset, variant_tensor):
3166    self._input_dataset = input_dataset
3167    super(UnaryDataset, self).__init__(variant_tensor)
3168
3169  def _inputs(self):
3170    return [self._input_dataset]
3171
3172
3173class UnaryUnchangedStructureDataset(UnaryDataset):
3174  """Represents a unary dataset with the same input and output structure."""
3175
3176  def __init__(self, input_dataset, variant_tensor):
3177    self._input_dataset = input_dataset
3178    super(UnaryUnchangedStructureDataset, self).__init__(
3179        input_dataset, variant_tensor)
3180
3181  @property
3182  def element_spec(self):
3183    return self._input_dataset.element_spec
3184
3185
3186class TensorDataset(DatasetSource):
3187  """A `Dataset` with a single element."""
3188
3189  def __init__(self, element):
3190    """See `Dataset.from_tensors()` for details."""
3191    element = structure.normalize_element(element)
3192    self._structure = structure.type_spec_from_value(element)
3193    self._tensors = structure.to_tensor_list(self._structure, element)
3194
3195    variant_tensor = gen_dataset_ops.tensor_dataset(
3196        self._tensors,
3197        output_shapes=structure.get_flat_tensor_shapes(self._structure))
3198    super(TensorDataset, self).__init__(variant_tensor)
3199
3200  @property
3201  def element_spec(self):
3202    return self._structure
3203
3204
3205class TensorSliceDataset(DatasetSource):
3206  """A `Dataset` of slices from a dataset element."""
3207
3208  def __init__(self, element):
3209    """See `Dataset.from_tensor_slices()` for details."""
3210    element = structure.normalize_element(element)
3211    batched_spec = structure.type_spec_from_value(element)
3212    self._tensors = structure.to_batched_tensor_list(batched_spec, element)
3213    self._structure = nest.map_structure(
3214        lambda component_spec: component_spec._unbatch(), batched_spec)  # pylint: disable=protected-access
3215
3216    batch_dim = tensor_shape.Dimension(tensor_shape.dimension_value(
3217        self._tensors[0].get_shape()[0]))
3218    for t in self._tensors[1:]:
3219      batch_dim.assert_is_compatible_with(tensor_shape.Dimension(
3220          tensor_shape.dimension_value(t.get_shape()[0])))
3221
3222    variant_tensor = gen_dataset_ops.tensor_slice_dataset(
3223        self._tensors,
3224        output_shapes=structure.get_flat_tensor_shapes(self._structure))
3225    super(TensorSliceDataset, self).__init__(variant_tensor)
3226
3227  @property
3228  def element_spec(self):
3229    return self._structure
3230
3231
3232class SparseTensorSliceDataset(DatasetSource):
3233  """A `Dataset` that splits a rank-N `tf.sparse.SparseTensor` into its rows."""
3234
3235  def __init__(self, sparse_tensor):
3236    """See `Dataset.from_sparse_tensor_slices()` for details."""
3237    if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
3238      raise TypeError(
3239          "`sparse_tensor` must be a `tf.sparse.SparseTensor` object."
3240          "Was {}.".format(sparse_tensor))
3241    self._sparse_tensor = sparse_tensor
3242
3243    indices_shape = self._sparse_tensor.indices.get_shape()
3244    shape_shape = self._sparse_tensor.dense_shape.get_shape()
3245    rank = (indices_shape.dims[1] - 1).merge_with(shape_shape.dims[0] - 1)
3246    self._structure = (tensor_spec.TensorSpec([None, rank], dtypes.int64),
3247                       tensor_spec.TensorSpec([None],
3248                                              self._sparse_tensor.dtype),
3249                       tensor_spec.TensorSpec([rank], dtypes.int64))
3250
3251    variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset(
3252        self._sparse_tensor.indices, self._sparse_tensor.values,
3253        self._sparse_tensor.dense_shape)
3254    super(SparseTensorSliceDataset, self).__init__(variant_tensor)
3255
3256  @property
3257  def element_spec(self):
3258    return self._structure
3259
3260
3261class _VariantDataset(DatasetV2):
3262  """A Dataset wrapper around a `tf.variant`-typed function argument."""
3263
3264  def __init__(self, dataset_variant, structure):
3265    self._structure = structure
3266    super(_VariantDataset, self).__init__(dataset_variant)
3267
3268  def _inputs(self):
3269    return []
3270
3271  @property
3272  def element_spec(self):
3273    return self._structure
3274
3275
3276class _NestedVariant(composite_tensor.CompositeTensor):
3277
3278  def __init__(self, variant_tensor, element_spec, dataset_shape):
3279    self._variant_tensor = variant_tensor
3280    self._element_spec = element_spec
3281    self._dataset_shape = dataset_shape
3282
3283  @property
3284  def _type_spec(self):
3285    return DatasetSpec(self._element_spec, self._dataset_shape)
3286
3287
3288@tf_export("data.experimental.from_variant")
3289def from_variant(variant, structure):
3290  """Constructs a dataset from the given variant and structure.
3291
3292  Args:
3293    variant: A scalar `tf.variant` tensor representing a dataset.
3294    structure: A `tf.data.experimental.Structure` object representing the
3295      structure of each element in the dataset.
3296
3297  Returns:
3298    A `tf.data.Dataset` instance.
3299  """
3300  return _VariantDataset(variant, structure)  # pylint: disable=protected-access
3301
3302
3303@tf_export("data.experimental.to_variant")
3304def to_variant(dataset):
3305  """Returns a variant representing the given dataset.
3306
3307  Args:
3308    dataset: A `tf.data.Dataset`.
3309
3310  Returns:
3311    A scalar `tf.variant` tensor representing the given dataset.
3312  """
3313  return dataset._variant_tensor  # pylint: disable=protected-access
3314
3315
3316@tf_export(
3317    "data.DatasetSpec",
3318    v1=["data.DatasetSpec", "data.experimental.DatasetStructure"])
3319class DatasetSpec(type_spec.BatchableTypeSpec):
3320  """Type specification for `tf.data.Dataset`.
3321
3322  See `tf.TypeSpec` for more information about TensorFlow type specifications.
3323
3324  >>> dataset = tf.data.Dataset.range(3)
3325  >>> tf.data.DatasetSpec.from_value(dataset)
3326  DatasetSpec(TensorSpec(shape=(), dtype=tf.int64, name=None), TensorShape([]))
3327  """
3328
3329  __slots__ = ["_element_spec", "_dataset_shape"]
3330
3331  def __init__(self, element_spec, dataset_shape=()):
3332    self._element_spec = element_spec
3333    self._dataset_shape = tensor_shape.as_shape(dataset_shape)
3334
3335  @property
3336  def value_type(self):
3337    return Dataset
3338
3339  @property
3340  def element_spec(self):
3341    """The inner element spec."""
3342    return self._element_spec
3343
3344  def _serialize(self):
3345    return (self._element_spec, self._dataset_shape)
3346
3347  @property
3348  def _component_specs(self):
3349    return tensor_spec.TensorSpec(self._dataset_shape, dtypes.variant)
3350
3351  def _to_components(self, value):
3352    return value._variant_tensor  # pylint: disable=protected-access
3353
3354  def _from_components(self, components):
3355    # pylint: disable=protected-access
3356    if self._dataset_shape.ndims == 0:
3357      return _VariantDataset(components, self._element_spec)
3358    else:
3359      return _NestedVariant(components, self._element_spec, self._dataset_shape)
3360
3361  def _to_tensor_list(self, value):
3362    return [
3363        ops.convert_to_tensor(
3364            tf_nest.map_structure(lambda x: x._variant_tensor, value))  # pylint: disable=protected-access
3365    ]
3366
3367  @staticmethod
3368  def from_value(value):
3369    """Creates a `DatasetSpec` for the given `tf.data.Dataset` value."""
3370    return DatasetSpec(value.element_spec)  # pylint: disable=protected-access
3371
3372  def _batch(self, batch_size):
3373    return DatasetSpec(
3374        self._element_spec,
3375        tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape))
3376
3377  def _unbatch(self):
3378    if self._dataset_shape.ndims == 0:
3379      raise ValueError("Unbatching a dataset is only supported for rank >= 1")
3380    return DatasetSpec(self._element_spec, self._dataset_shape[1:])
3381
3382  def _to_batched_tensor_list(self, value):
3383    if self._dataset_shape.ndims == 0:
3384      raise ValueError("Unbatching a dataset is only supported for rank >= 1")
3385    return self._to_tensor_list(value)
3386
3387  def _to_legacy_output_types(self):
3388    return self
3389
3390  def _to_legacy_output_shapes(self):
3391    return self
3392
3393  def _to_legacy_output_classes(self):
3394    return self
3395
3396
3397class StructuredFunctionWrapper(object):
3398  """A function wrapper that supports structured arguments and return values."""
3399
3400  def __init__(self,
3401               func,
3402               transformation_name,
3403               dataset=None,
3404               input_classes=None,
3405               input_shapes=None,
3406               input_types=None,
3407               input_structure=None,
3408               add_to_graph=True,
3409               use_legacy_function=False,
3410               defun_kwargs=None):
3411    """Creates a new `StructuredFunctionWrapper` for the given function.
3412
3413    Args:
3414      func: A function from a nested structure to another nested structure.
3415      transformation_name: Human-readable name of the transformation in which
3416        this function is being instantiated, for error messages.
3417      dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
3418        dataset will be assumed as the structure for `func` arguments; otherwise
3419        `input_classes`, `input_shapes`, and `input_types` must be defined.
3420      input_classes: (Optional.) A nested structure of `type`. If given, this
3421        argument defines the Python types for `func` arguments.
3422      input_shapes: (Optional.) A nested structure of `tf.TensorShape`. If
3423        given, this argument defines the shapes and structure for `func`
3424        arguments.
3425      input_types: (Optional.) A nested structure of `tf.DType`. If given, this
3426        argument defines the element types and structure for `func` arguments.
3427      input_structure: (Optional.) A `Structure` object. If given, this argument
3428        defines the element types and structure for `func` arguments.
3429      add_to_graph: (Optional.) If `True`, the function will be added to the
3430        default graph, if it exists.
3431      use_legacy_function: (Optional.) A boolean that determines whether the
3432        function be created using `tensorflow.python.eager.function.defun`
3433        (default behavior) or `tensorflow.python.framework.function.Defun`
3434        (legacy behavior).
3435      defun_kwargs: (Optional.) A dictionary mapping string argument names to
3436        values. If supplied, will be passed to `function` as keyword arguments.
3437
3438    Raises:
3439      ValueError: If an invalid combination of `dataset`, `input_classes`,
3440        `input_shapes`, and `input_types` is passed.
3441    """
3442    # pylint: disable=protected-access
3443    if input_structure is None:
3444      if dataset is None:
3445        if input_classes is None or input_shapes is None or input_types is None:
3446          raise ValueError("Either `dataset`, `input_structure` or all of "
3447                           "`input_classes`, `input_shapes`, and `input_types` "
3448                           "must be specified.")
3449        self._input_structure = structure.convert_legacy_structure(
3450            input_types, input_shapes, input_classes)
3451      else:
3452        if not (input_classes is None and input_shapes is None and
3453                input_types is None):
3454          raise ValueError("Either `dataset`, `input_structure` or all of "
3455                           "`input_classes`, `input_shapes`, and `input_types` "
3456                           "must be specified.")
3457        self._input_structure = dataset.element_spec
3458    else:
3459      if not (dataset is None and input_classes is None and input_shapes is None
3460              and input_types is None):
3461        raise ValueError("Either `dataset`, `input_structure`, or all of "
3462                         "`input_classes`, `input_shapes`, and `input_types` "
3463                         "must be specified.")
3464      self._input_structure = input_structure
3465
3466    self._func = func
3467
3468    # There is no graph to add in eager mode.
3469    add_to_graph &= not context.executing_eagerly()
3470    # There are some lifetime issues when a legacy function is not added to a
3471    # out-living graph. It's already deprecated so de-prioritizing the fix.
3472    add_to_graph |= use_legacy_function
3473
3474    if defun_kwargs is None:
3475      defun_kwargs = {}
3476
3477    readable_transformation_name = transformation_name.replace(
3478        ".", "_")[:-2] if len(transformation_name) > 2 else ""
3479
3480    func_name = "_".join(
3481        [readable_transformation_name,
3482         function_utils.get_func_name(func)])
3483    # Sanitize function name to remove symbols that interfere with graph
3484    # construction.
3485    for symbol in ["<", ">", "\\", "'", " "]:
3486      func_name = func_name.replace(symbol, "")
3487
3488    ag_ctx = autograph_ctx.control_status_ctx()
3489
3490    def _warn_if_collections(transformation_name):
3491      """Prints a warning if the given graph uses common graph collections.
3492
3493      NOTE(mrry): Currently a warning is only generated for resources. Any
3494      variables created will be automatically hoisted out to the outermost scope
3495      using `init_scope()`. Some collections (such as for control-flow contexts)
3496      are benign and should not generate a warning.
3497
3498      Args:
3499        transformation_name: A human-readable name for the transformation.
3500      """
3501      warnings.warn("Creating resources inside a function passed to %s "
3502                    "is not supported. Create each resource outside the "
3503                    "function, and capture it inside the function to use it." %
3504                    transformation_name, stacklevel=5)
3505
3506    def _wrapper_helper(*args):
3507      """Wrapper for passing nested structures to and from tf.data functions."""
3508      nested_args = structure.from_compatible_tensor_list(
3509          self._input_structure, args)
3510      if not _should_unpack_args(nested_args):
3511        nested_args = (nested_args,)
3512
3513      ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
3514      # If `func` returns a list of tensors, `nest.flatten()` and
3515      # `ops.convert_to_tensor()` would conspire to attempt to stack
3516      # those tensors into a single tensor, because the customized
3517      # version of `nest.flatten()` does not recurse into lists. Since
3518      # it is more likely that the list arose from returning the
3519      # result of an operation (such as `tf.numpy_function()`) that returns a
3520      # list of not-necessarily-stackable tensors, we treat the
3521      # returned value is a `tuple` instead. A user wishing to pack
3522      # the return value into a single tensor can use an explicit
3523      # `tf.stack()` before returning.
3524      if isinstance(ret, list):
3525        ret = tuple(ret)
3526
3527      try:
3528        self._output_structure = structure.type_spec_from_value(ret)
3529      except (ValueError, TypeError):
3530        six.reraise(
3531            TypeError,
3532            TypeError("Unsupported return value from function passed to "
3533                      "%s: %s." % (transformation_name, ret)),
3534            sys.exc_info()[2])
3535      return ret
3536
3537    if use_legacy_function:
3538      func_name = func_name + "_" + str(ops.uid())
3539
3540      @function.Defun(
3541          *structure.get_flat_tensor_types(self._input_structure),
3542          func_name=func_name,
3543          **defun_kwargs)
3544      def wrapper_fn(*args):
3545        ret = _wrapper_helper(*args)
3546        # _warn_if_collections(transformation_name, ops.get_default_graph(), 0)
3547        return structure.to_tensor_list(self._output_structure, ret)
3548
3549      self._function = wrapper_fn
3550      resource_tracker = tracking.ResourceTracker()
3551      with tracking.resource_tracker_scope(resource_tracker):
3552        if add_to_graph:
3553          self._function.add_to_graph(ops.get_default_graph())
3554        else:
3555          # Use the private method that will execute `wrapper_fn` but delay
3556          # adding it to the graph in case (e.g.) we need to rerun the function.
3557          self._function._create_definition_if_needed()
3558      if resource_tracker.resources:
3559        _warn_if_collections(transformation_name)
3560
3561    else:
3562      if def_function.functions_run_eagerly():
3563        warnings.warn(
3564            "Even though the tf.config.experimental_run_functions_eagerly "
3565            "option is set, this option does not apply to tf.data functions. "
3566            "tf.data functions are still traced and executed as graphs.")
3567
3568      defun_kwargs.update({"func_name": func_name})
3569      defun_kwargs.update({"_tf_data_function": True})
3570
3571      # Note: _wrapper_helper will apply autograph based on context.
3572      @eager_function.defun_with_attributes(
3573          input_signature=structure.get_flat_tensor_specs(
3574              self._input_structure),
3575          autograph=False,
3576          attributes=defun_kwargs)
3577      def wrapper_fn(*args):  # pylint: disable=missing-docstring
3578        ret = _wrapper_helper(*args)
3579        ret = structure.to_tensor_list(self._output_structure, ret)
3580        return [ops.convert_to_tensor(t) for t in ret]
3581
3582      resource_tracker = tracking.ResourceTracker()
3583      with tracking.resource_tracker_scope(resource_tracker):
3584        # TODO(b/141462134): Switch to using garbage collection.
3585        self._function = wrapper_fn.get_concrete_function()
3586        if add_to_graph:
3587          self._function.add_to_graph(ops.get_default_graph())
3588
3589      if resource_tracker.resources:
3590        _warn_if_collections(transformation_name)
3591
3592      outer_graph_seed = ops.get_default_graph().seed
3593      if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
3594        if self._function.graph._seed_used:
3595          warnings.warn(
3596              "Seed %s from outer graph might be getting used by function %s, "
3597              "if the random op has not been provided any seed. Explicitly set "
3598              "the seed in the function if this is not the intended behavior."
3599              %(outer_graph_seed, func_name), stacklevel=4)
3600
3601  @property
3602  def output_structure(self):
3603    return self._output_structure
3604
3605  @property
3606  def output_classes(self):
3607    return nest.map_structure(
3608        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
3609        self._output_structure)
3610
3611  @property
3612  def output_shapes(self):
3613    return nest.map_structure(
3614        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
3615        self._output_structure)
3616
3617  @property
3618  def output_types(self):
3619    return nest.map_structure(
3620        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
3621        self._output_structure)
3622
3623  @property
3624  def function(self):
3625    return self._function
3626
3627
3628class _GeneratorDataset(DatasetSource):
3629  """A `Dataset` that generates elements by invoking a function."""
3630
3631  def __init__(self, init_args, init_func, next_func, finalize_func,
3632               output_signature):
3633    """Constructs a `_GeneratorDataset`.
3634
3635    Args:
3636      init_args: A nested structure representing the arguments to `init_func`.
3637      init_func: A TensorFlow function that will be called on `init_args` each
3638        time a C++ iterator over this dataset is constructed. Returns a nested
3639        structure representing the "state" of the dataset.
3640      next_func: A TensorFlow function that will be called on the result of
3641        `init_func` to produce each element, and that raises `OutOfRangeError`
3642        to terminate iteration.
3643      finalize_func: A TensorFlow function that will be called on the result of
3644        `init_func` immediately before a C++ iterator over this dataset is
3645        destroyed. The return value is ignored.
3646      output_signature: A nested structure of `tf.TypeSpec` objects describing
3647        the output of `next_func`.
3648    """
3649    self._init_args = init_args
3650
3651    self._init_structure = structure.type_spec_from_value(init_args)
3652
3653    self._init_func = StructuredFunctionWrapper(
3654        init_func,
3655        self._transformation_name(),
3656        input_structure=self._init_structure)
3657
3658    self._next_func = StructuredFunctionWrapper(
3659        next_func,
3660        self._transformation_name(),
3661        input_structure=self._init_func.output_structure)
3662
3663    self._finalize_func = StructuredFunctionWrapper(
3664        finalize_func,
3665        self._transformation_name(),
3666        input_structure=self._init_func.output_structure)
3667
3668    self._output_signature = output_signature
3669
3670    variant_tensor = gen_dataset_ops.generator_dataset(
3671        structure.to_tensor_list(self._init_structure, self._init_args) +
3672        self._init_func.function.captured_inputs,
3673        self._next_func.function.captured_inputs,
3674        self._finalize_func.function.captured_inputs,
3675        init_func=self._init_func.function,
3676        next_func=self._next_func.function,
3677        finalize_func=self._finalize_func.function,
3678        **self._flat_structure)
3679    super(_GeneratorDataset, self).__init__(variant_tensor)
3680
3681  @property
3682  def element_spec(self):
3683    return self._output_signature
3684
3685  def _transformation_name(self):
3686    return "Dataset.from_generator()"
3687
3688
3689class ZipDataset(DatasetV2):
3690  """A `Dataset` that zips its inputs together."""
3691
3692  def __init__(self, datasets):
3693    """See `Dataset.zip()` for details."""
3694    for ds in nest.flatten(datasets):
3695      if not isinstance(ds, DatasetV2):
3696        if isinstance(ds, list):
3697          message = ("The argument to `Dataset.zip()` must be a nested "
3698                     "structure of `Dataset` objects. Nested structures do not "
3699                     "support Python lists; please use a tuple instead.")
3700        else:
3701          message = ("The argument to `Dataset.zip()` must be a nested "
3702                     "structure of `Dataset` objects.")
3703        raise TypeError(message)
3704    self._datasets = datasets
3705    self._structure = nest.pack_sequence_as(
3706        self._datasets,
3707        [ds.element_spec for ds in nest.flatten(self._datasets)])
3708    variant_tensor = gen_dataset_ops.zip_dataset(
3709        [ds._variant_tensor for ds in nest.flatten(self._datasets)],
3710        **self._flat_structure)
3711    super(ZipDataset, self).__init__(variant_tensor)
3712
3713  def _inputs(self):
3714    return nest.flatten(self._datasets)
3715
3716  @property
3717  def element_spec(self):
3718    return self._structure
3719
3720
3721class ConcatenateDataset(DatasetV2):
3722  """A `Dataset` that concatenates its input with given dataset."""
3723
3724  def __init__(self, input_dataset, dataset_to_concatenate):
3725    """See `Dataset.concatenate()` for details."""
3726    self._input_dataset = input_dataset
3727    self._dataset_to_concatenate = dataset_to_concatenate
3728
3729    output_types = get_legacy_output_types(input_dataset)
3730    if output_types != get_legacy_output_types(dataset_to_concatenate):
3731      raise TypeError(
3732          "Two datasets to concatenate have different types %s and %s" %
3733          (output_types, get_legacy_output_types(dataset_to_concatenate)))
3734
3735    output_classes = get_legacy_output_classes(input_dataset)
3736    if output_classes != get_legacy_output_classes(dataset_to_concatenate):
3737      raise TypeError(
3738          "Two datasets to concatenate have different classes %s and %s" %
3739          (output_classes, get_legacy_output_classes(dataset_to_concatenate)))
3740
3741    input_shapes = get_legacy_output_shapes(self._input_dataset)
3742    output_shapes = nest.pack_sequence_as(input_shapes, [
3743        ts1.most_specific_compatible_shape(ts2)
3744        for (ts1, ts2) in zip(
3745            nest.flatten(input_shapes),
3746            nest.flatten(get_legacy_output_shapes(
3747                self._dataset_to_concatenate)))
3748    ])
3749
3750    self._structure = structure.convert_legacy_structure(
3751        output_types, output_shapes, output_classes)
3752
3753    self._input_datasets = [input_dataset, dataset_to_concatenate]
3754    # pylint: disable=protected-access
3755    variant_tensor = gen_dataset_ops.concatenate_dataset(
3756        input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor,
3757        **self._flat_structure)
3758    # pylint: enable=protected-access
3759    super(ConcatenateDataset, self).__init__(variant_tensor)
3760
3761  def _inputs(self):
3762    return self._input_datasets
3763
3764  @property
3765  def element_spec(self):
3766    return self._structure
3767
3768
3769class RepeatDataset(UnaryUnchangedStructureDataset):
3770  """A `Dataset` that repeats its input several times."""
3771
3772  def __init__(self, input_dataset, count):
3773    """See `Dataset.repeat()` for details."""
3774    self._input_dataset = input_dataset
3775    if count is None:
3776      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
3777    else:
3778      self._count = ops.convert_to_tensor(
3779          count, dtype=dtypes.int64, name="count")
3780    variant_tensor = gen_dataset_ops.repeat_dataset(
3781        input_dataset._variant_tensor,  # pylint: disable=protected-access
3782        count=self._count,
3783        **self._flat_structure)
3784    super(RepeatDataset, self).__init__(input_dataset, variant_tensor)
3785
3786
3787class RangeDataset(DatasetSource):
3788  """A `Dataset` of a step separated range of values."""
3789
3790  def __init__(self, *args, **kwargs):
3791    """See `Dataset.range()` for details."""
3792    self._parse_args(*args, **kwargs)
3793    self._structure = tensor_spec.TensorSpec([], self._output_type)
3794    variant_tensor = gen_dataset_ops.range_dataset(
3795        start=self._start,
3796        stop=self._stop,
3797        step=self._step,
3798        **self._flat_structure)
3799    super(RangeDataset, self).__init__(variant_tensor)
3800
3801  def _parse_args(self, *args, **kwargs):
3802    """Parse arguments according to the same rules as the `range()` builtin."""
3803    if len(args) == 1:
3804      self._start = self._build_tensor(0, "start")
3805      self._stop = self._build_tensor(args[0], "stop")
3806      self._step = self._build_tensor(1, "step")
3807    elif len(args) == 2:
3808      self._start = self._build_tensor(args[0], "start")
3809      self._stop = self._build_tensor(args[1], "stop")
3810      self._step = self._build_tensor(1, "step")
3811    elif len(args) == 3:
3812      self._start = self._build_tensor(args[0], "start")
3813      self._stop = self._build_tensor(args[1], "stop")
3814      self._step = self._build_tensor(args[2], "step")
3815    else:
3816      raise ValueError("Invalid arguments to RangeDataset: %s" % str(args))
3817    if "output_type" in kwargs:
3818      self._output_type = kwargs["output_type"]
3819    else:
3820      self._output_type = dtypes.int64
3821
3822  def _build_tensor(self, int64_value, name):
3823    return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
3824
3825  @property
3826  def element_spec(self):
3827    return self._structure
3828
3829
3830class CacheDataset(UnaryUnchangedStructureDataset):
3831  """A `Dataset` that caches elements of its input."""
3832
3833  def __init__(self, input_dataset, filename):
3834    """See `Dataset.cache()` for details."""
3835    self._input_dataset = input_dataset
3836    self._filename = ops.convert_to_tensor(
3837        filename, dtype=dtypes.string, name="filename")
3838    if tf2.enabled() and (context.executing_eagerly() or ops.inside_function()):
3839      variant_tensor = gen_dataset_ops.cache_dataset_v2(
3840          input_dataset._variant_tensor,  # pylint: disable=protected-access
3841          filename=self._filename,
3842          cache=gen_dataset_ops.dummy_memory_cache(),
3843          **self._flat_structure)
3844    else:
3845      variant_tensor = gen_dataset_ops.cache_dataset(
3846          input_dataset._variant_tensor,  # pylint: disable=protected-access
3847          filename=self._filename,
3848          **self._flat_structure)
3849    super(CacheDataset, self).__init__(input_dataset, variant_tensor)
3850
3851
3852class ShuffleDataset(UnaryUnchangedStructureDataset):
3853  """A `Dataset` that randomly shuffles the elements of its input."""
3854
3855  def __init__(self,
3856               input_dataset,
3857               buffer_size,
3858               seed=None,
3859               reshuffle_each_iteration=None):
3860    """Randomly shuffles the elements of this dataset.
3861
3862    Args:
3863      input_dataset: The input dataset.
3864      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
3865        elements from this dataset from which the new dataset will sample.
3866      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
3867        seed that will be used to create the distribution. See
3868        `tf.random.set_seed` for behavior.
3869      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
3870        that the dataset should be pseudorandomly reshuffled each time it is
3871        iterated over. (Defaults to `True`.)
3872
3873    Returns:
3874      A `Dataset`.
3875
3876    Raises:
3877      ValueError: if invalid arguments are provided.
3878    """
3879    self._input_dataset = input_dataset
3880    self._buffer_size = ops.convert_to_tensor(
3881        buffer_size, dtype=dtypes.int64, name="buffer_size")
3882    self._seed, self._seed2 = random_seed.get_seed(seed)
3883    if reshuffle_each_iteration is None:
3884      reshuffle_each_iteration = True
3885    self._reshuffle_each_iteration = reshuffle_each_iteration
3886
3887    if (tf2.enabled() and
3888        (context.executing_eagerly() or ops.inside_function())):
3889      variant_tensor = gen_dataset_ops.shuffle_dataset_v3(
3890          input_dataset._variant_tensor,  # pylint: disable=protected-access
3891          buffer_size=self._buffer_size,
3892          seed=self._seed,
3893          seed2=self._seed2,
3894          seed_generator=gen_dataset_ops.dummy_seed_generator(),
3895          reshuffle_each_iteration=self._reshuffle_each_iteration,
3896          **self._flat_structure)
3897    else:
3898      variant_tensor = gen_dataset_ops.shuffle_dataset(
3899          input_dataset._variant_tensor,  # pylint: disable=protected-access
3900          buffer_size=self._buffer_size,
3901          seed=self._seed,
3902          seed2=self._seed2,
3903          reshuffle_each_iteration=self._reshuffle_each_iteration,
3904          **self._flat_structure)
3905    super(ShuffleDataset, self).__init__(input_dataset, variant_tensor)
3906
3907
3908class TakeDataset(UnaryUnchangedStructureDataset):
3909  """A `Dataset` containing the first `count` elements from its input."""
3910
3911  def __init__(self, input_dataset, count):
3912    """See `Dataset.take()` for details."""
3913    self._input_dataset = input_dataset
3914    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
3915    variant_tensor = gen_dataset_ops.take_dataset(
3916        input_dataset._variant_tensor,  # pylint: disable=protected-access
3917        count=self._count,
3918        **self._flat_structure)
3919    super(TakeDataset, self).__init__(input_dataset, variant_tensor)
3920
3921
3922class SkipDataset(UnaryUnchangedStructureDataset):
3923  """A `Dataset` skipping the first `count` elements from its input."""
3924
3925  def __init__(self, input_dataset, count):
3926    """See `Dataset.skip()` for details."""
3927    self._input_dataset = input_dataset
3928    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
3929    variant_tensor = gen_dataset_ops.skip_dataset(
3930        input_dataset._variant_tensor,  # pylint: disable=protected-access
3931        count=self._count,
3932        **self._flat_structure)
3933    super(SkipDataset, self).__init__(input_dataset, variant_tensor)
3934
3935
3936class ShardDataset(UnaryUnchangedStructureDataset):
3937  """A `Dataset` for sharding its input."""
3938
3939  def __init__(self, input_dataset, num_shards, index):
3940    """See `Dataset.shard()` for details."""
3941    self._input_dataset = input_dataset
3942    self._num_shards = ops.convert_to_tensor(
3943        num_shards, dtype=dtypes.int64, name="num_shards")
3944    self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index")
3945    variant_tensor = gen_dataset_ops.shard_dataset(
3946        input_dataset._variant_tensor,  # pylint: disable=protected-access
3947        num_shards=self._num_shards,
3948        index=self._index,
3949        **self._flat_structure)
3950    super(ShardDataset, self).__init__(input_dataset, variant_tensor)
3951
3952
3953class BatchDataset(UnaryDataset):
3954  """A `Dataset` that batches contiguous elements from its input."""
3955
3956  def __init__(self, input_dataset, batch_size, drop_remainder):
3957    """See `Dataset.batch()` for details."""
3958    self._input_dataset = input_dataset
3959    self._batch_size = ops.convert_to_tensor(
3960        batch_size, dtype=dtypes.int64, name="batch_size")
3961    self._drop_remainder = ops.convert_to_tensor(
3962        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
3963
3964    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
3965    # pylint: disable=protected-access
3966    if constant_drop_remainder:
3967      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
3968      # or `False` (explicitly retaining the remainder).
3969      # pylint: disable=g-long-lambda
3970      constant_batch_size = tensor_util.constant_value(self._batch_size)
3971      self._structure = nest.map_structure(
3972          lambda component_spec: component_spec._batch(constant_batch_size),
3973          input_dataset.element_spec)
3974    else:
3975      self._structure = nest.map_structure(
3976          lambda component_spec: component_spec._batch(None),
3977          input_dataset.element_spec)
3978    variant_tensor = gen_dataset_ops.batch_dataset_v2(
3979        input_dataset._variant_tensor,
3980        batch_size=self._batch_size,
3981        drop_remainder=self._drop_remainder,
3982        **self._flat_structure)
3983    super(BatchDataset, self).__init__(input_dataset, variant_tensor)
3984
3985  @property
3986  def element_spec(self):
3987    return self._structure
3988
3989
3990class ParallelBatchDataset(UnaryDataset):
3991  """A `Dataset` that batches contiguous elements from its input in parallel."""
3992
3993  def __init__(self, input_dataset, batch_size, drop_remainder,
3994               num_parallel_calls):
3995    """See `Dataset.batch()` for details."""
3996    self._input_dataset = input_dataset
3997    self._batch_size = ops.convert_to_tensor(
3998        batch_size, dtype=dtypes.int64, name="batch_size")
3999    self._drop_remainder = ops.convert_to_tensor(
4000        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
4001    self._num_parallel_calls = ops.convert_to_tensor(
4002        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
4003
4004    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
4005    # pylint: disable=protected-access
4006    if constant_drop_remainder:
4007      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
4008      # or `False` (explicitly retaining the remainder).
4009      # pylint: disable=g-long-lambda
4010      constant_batch_size = tensor_util.constant_value(self._batch_size)
4011      self._structure = nest.map_structure(
4012          lambda component_spec: component_spec._batch(constant_batch_size),
4013          input_dataset.element_spec)
4014    else:
4015      self._structure = nest.map_structure(
4016          lambda component_spec: component_spec._batch(None),
4017          input_dataset.element_spec)
4018    variant_tensor = gen_dataset_ops.parallel_batch_dataset(
4019        input_dataset._variant_tensor,
4020        batch_size=self._batch_size,
4021        num_parallel_calls=self._num_parallel_calls,
4022        drop_remainder=self._drop_remainder,
4023        **self._flat_structure)
4024    super(ParallelBatchDataset, self).__init__(input_dataset, variant_tensor)
4025
4026  @property
4027  def element_spec(self):
4028    return self._structure
4029
4030
4031class _NumpyIterator(object):
4032  """Iterator over a dataset with elements converted to numpy."""
4033
4034  __slots__ = ["_iterator"]
4035
4036  def __init__(self, dataset):
4037    self._iterator = iter(dataset)
4038
4039  def __iter__(self):
4040    return self
4041
4042  def __next__(self):
4043
4044    def to_numpy(x):
4045      numpy = x._numpy()  # pylint: disable=protected-access
4046      if isinstance(numpy, np.ndarray):
4047        # `numpy` shares the same underlying buffer as the `x` Tensor.
4048        # Tensors are expected to be immutable, so we disable writes.
4049        numpy.setflags(write=False)
4050      return numpy
4051
4052    return nest.map_structure(to_numpy, next(self._iterator))
4053
4054  def next(self):
4055    return self.__next__()
4056
4057
4058class _VariantTracker(tracking.CapturableResource):
4059  """Allows export of functions capturing a Dataset in SavedModels.
4060
4061  When saving a SavedModel, `tf.saved_model.save` traverses the object
4062  graph. Since Datasets reference _VariantTracker objects, that traversal will
4063  find a _VariantTracker for each Dataset and so know how to save and restore
4064  functions which reference the Dataset's variant Tensor.
4065  """
4066
4067  def __init__(self, variant_tensor, resource_creator):
4068    """Record that `variant_tensor` is associated with `resource_creator`.
4069
4070    Args:
4071      variant_tensor: The variant-dtype Tensor associated with the Dataset. This
4072        Tensor will be a captured input to functions which use the Dataset, and
4073        is used by saving code to identify the corresponding _VariantTracker.
4074      resource_creator: A zero-argument function which creates a new
4075        variant-dtype Tensor. This function will be included in SavedModels and
4076        run to re-create the Dataset's variant Tensor on restore.
4077    """
4078    super(_VariantTracker, self).__init__(device="CPU")
4079    self._resource_handle = variant_tensor
4080    self._create_resource = resource_creator
4081
4082
4083def _is_padded_shape_compatible_with(padded_shape, input_component_shape):
4084  """Returns `True` if `input_component_shape` can be padded to `padded_shape`.
4085
4086  Args:
4087    padded_shape: A `tf.TensorShape`.
4088    input_component_shape: A `tf.TensorShape`.
4089
4090  Returns:
4091    `True` if `input_component_shape` can be padded to `padded_shape`, otherwise
4092    `False`.
4093  """
4094
4095  if padded_shape.dims is None or input_component_shape.dims is None:
4096    return True
4097  if len(padded_shape.dims) != len(input_component_shape.dims):
4098    return False
4099  for padded_dim, input_dim in zip(
4100      padded_shape.dims, input_component_shape.dims):
4101    if (padded_dim.value is not None and input_dim.value is not None
4102        and padded_dim.value < input_dim.value):
4103      return False
4104  return True
4105
4106
4107def _padded_shape_to_tensor(padded_shape, input_component_shape):
4108  """Converts `padded_shape` to a `tf.Tensor` representing that shape.
4109
4110  Args:
4111    padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python
4112      sequence, or a 1-D `tf.Tensor` of `tf.int64` elements.
4113    input_component_shape: A `tf.TensorShape`, with which `padded_shape` must
4114      be compatible.
4115
4116  Returns:
4117    A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`.
4118
4119  Raises:
4120    ValueError: If `padded_shape` is not a shape or not compatible with
4121      `input_component_shape`.
4122    TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor.
4123  """
4124  try:
4125    # Try to convert the `padded_shape` to a `tf.TensorShape`
4126    padded_shape_as_shape = tensor_shape.as_shape(padded_shape)
4127    # We will return the "canonical" tensor representation, which uses
4128    # `-1` in place of `None`.
4129    ret = ops.convert_to_tensor(
4130        [dim if dim is not None else -1
4131         for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64)
4132  except (TypeError, ValueError):
4133    # The argument was not trivially convertible to a
4134    # `tf.TensorShape`, so fall back on the conversion to tensor
4135    # machinery.
4136    ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
4137    if ret.shape.dims is not None and len(ret.shape.dims) != 1:
4138      six.reraise(ValueError, ValueError(
4139          "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
4140          "shape was %s." % (padded_shape, ret.shape)), sys.exc_info()[2])
4141    if ret.dtype != dtypes.int64:
4142      six.reraise(
4143          TypeError,
4144          TypeError(
4145              "Padded shape %s must be a 1-D tensor of tf.int64 values, but "
4146              "its element type was %s." % (padded_shape, ret.dtype.name)),
4147          sys.exc_info()[2])
4148    padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)
4149
4150  if not _is_padded_shape_compatible_with(padded_shape_as_shape,
4151                                          input_component_shape):
4152    raise ValueError("The padded shape %s is not compatible with the "
4153                     "corresponding input component shape %s."
4154                     % (padded_shape_as_shape, input_component_shape))
4155
4156  return ret
4157
4158
4159def _padding_value_to_tensor(value, output_type):
4160  """Converts the padding value to a tensor.
4161
4162  Args:
4163    value: The padding value.
4164    output_type: Its expected dtype.
4165
4166  Returns:
4167    A scalar `Tensor`.
4168
4169  Raises:
4170    ValueError: if the padding value is not a scalar.
4171    TypeError: if the padding value's type does not match `output_type`.
4172  """
4173  value = ops.convert_to_tensor(value, name="padding_value")
4174  if not value.shape.is_compatible_with(tensor_shape.TensorShape([])):
4175    raise ValueError("Padding value should be a scalar, but is not: %s" % value)
4176  if value.dtype != output_type:
4177    raise TypeError("Padding value tensor (%s) does not match output type: %s" %
4178                    (value, output_type))
4179  return value
4180
4181
4182def _padding_values_or_default(padding_values, input_dataset):
4183  """Returns padding values with None elements replaced with default values."""
4184
4185  def make_zero(t):
4186    if t.base_dtype == dtypes.string:
4187      return ""
4188    elif t.base_dtype == dtypes.variant:
4189      error_msg = ("Unable to create padding for field of type 'variant' "
4190                   "because t.base_type == dtypes.variant == "
4191                   "{}.".format(t.base_dtype))
4192      raise TypeError(error_msg)
4193    elif t.base_dtype == dtypes.bfloat16:
4194      # Special case `bfloat16` because it is not supported by NumPy.
4195      return constant_op.constant(0, dtype=dtypes.bfloat16)
4196    else:
4197      return np.zeros_like(t.as_numpy_dtype())
4198
4199  def value_or_default(value, default):
4200    return default if value is None else value
4201
4202  default_padding = nest.map_structure(
4203      make_zero,
4204      get_legacy_output_types(input_dataset))
4205  return nest.map_structure_up_to(padding_values, value_or_default,
4206                                  padding_values, default_padding)
4207
4208
4209class PaddedBatchDataset(UnaryDataset):
4210  """A `Dataset` that batches and pads contiguous elements from its input."""
4211
4212  def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
4213               drop_remainder):
4214    """See `Dataset.batch()` for details."""
4215    self._input_dataset = input_dataset
4216
4217    def check_types(component_spec):
4218      if not isinstance(component_spec, tensor_spec.TensorSpec):
4219        raise TypeError("Padded batching of components of type ",
4220                        type(component_spec), " is not supported.")
4221
4222    nest.map_structure(check_types, input_dataset.element_spec)
4223    self._input_dataset = input_dataset
4224    self._batch_size = ops.convert_to_tensor(
4225        batch_size, dtype=dtypes.int64, name="batch_size")
4226    padding_values = _padding_values_or_default(padding_values, input_dataset)
4227
4228    input_shapes = get_legacy_output_shapes(input_dataset)
4229    flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes)
4230
4231    flat_padded_shapes_as_tensors = []
4232
4233    for input_component_shape, padded_shape in zip(
4234        nest.flatten(input_shapes), flat_padded_shapes):
4235      flat_padded_shapes_as_tensors.append(
4236          _padded_shape_to_tensor(padded_shape, input_component_shape))
4237
4238    self._padded_shapes = nest.pack_sequence_as(input_shapes,
4239                                                flat_padded_shapes_as_tensors)
4240
4241    # If padding_values is a single element and input_shapes is a structure,
4242    # "broadcast" padding_values to the same structure as input_shapes.
4243    if nest.is_sequence(input_shapes) and not nest.is_sequence(padding_values):
4244      padding_values = nest.map_structure(lambda _: padding_values,
4245                                          input_shapes)
4246
4247    self._padding_values = nest.map_structure_up_to(
4248        input_shapes, _padding_value_to_tensor, padding_values,
4249        get_legacy_output_types(input_dataset))
4250    self._drop_remainder = ops.convert_to_tensor(
4251        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
4252
4253    def _padded_shape_to_batch_shape(s):
4254      return tensor_shape.TensorShape([
4255          tensor_util.constant_value(self._batch_size)
4256          if smart_cond.smart_constant_value(self._drop_remainder) else None
4257      ]).concatenate(tensor_util.constant_value_as_shape(s))
4258
4259    output_shapes = nest.map_structure(
4260        _padded_shape_to_batch_shape, self._padded_shapes)
4261    self._structure = structure.convert_legacy_structure(
4262        get_legacy_output_types(self._input_dataset), output_shapes,
4263        get_legacy_output_classes(self._input_dataset))
4264
4265    # pylint: disable=protected-access
4266    # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
4267    if smart_cond.smart_constant_value(self._drop_remainder) is False:
4268      variant_tensor = gen_dataset_ops.padded_batch_dataset(
4269          input_dataset._variant_tensor,  # pylint: disable=protected-access
4270          batch_size=self._batch_size,
4271          padded_shapes=[
4272              ops.convert_to_tensor(s, dtype=dtypes.int64)
4273              for s in nest.flatten(self._padded_shapes)
4274          ],
4275          padding_values=nest.flatten(self._padding_values),
4276          output_shapes=structure.get_flat_tensor_shapes(self._structure))
4277    else:
4278      variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
4279          input_dataset._variant_tensor,  # pylint: disable=protected-access
4280          batch_size=self._batch_size,
4281          padded_shapes=[
4282              ops.convert_to_tensor(s, dtype=dtypes.int64)
4283              for s in nest.flatten(self._padded_shapes)
4284          ],
4285          padding_values=nest.flatten(self._padding_values),
4286          drop_remainder=self._drop_remainder,
4287          output_shapes=structure.get_flat_tensor_shapes(self._structure))
4288    super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
4289
4290  @property
4291  def element_spec(self):
4292    return self._structure
4293
4294
4295def _should_unpack_args(args):
4296  """Returns `True` if `args` should be `*args` when passed to a callable."""
4297  return type(args) is tuple  # pylint: disable=unidiomatic-typecheck
4298
4299
4300class MapDataset(UnaryDataset):
4301  """A `Dataset` that maps a function over elements in its input."""
4302
4303  def __init__(self,
4304               input_dataset,
4305               map_func,
4306               use_inter_op_parallelism=True,
4307               preserve_cardinality=False,
4308               use_legacy_function=False):
4309    """See `Dataset.map()` for details."""
4310    self._input_dataset = input_dataset
4311    self._use_inter_op_parallelism = use_inter_op_parallelism
4312    self._preserve_cardinality = preserve_cardinality
4313    self._map_func = StructuredFunctionWrapper(
4314        map_func,
4315        self._transformation_name(),
4316        dataset=input_dataset,
4317        use_legacy_function=use_legacy_function)
4318    variant_tensor = gen_dataset_ops.map_dataset(
4319        input_dataset._variant_tensor,  # pylint: disable=protected-access
4320        self._map_func.function.captured_inputs,
4321        f=self._map_func.function,
4322        use_inter_op_parallelism=self._use_inter_op_parallelism,
4323        preserve_cardinality=self._preserve_cardinality,
4324        **self._flat_structure)
4325    super(MapDataset, self).__init__(input_dataset, variant_tensor)
4326
4327  def _functions(self):
4328    return [self._map_func]
4329
4330  @property
4331  def element_spec(self):
4332    return self._map_func.output_structure
4333
4334  def _transformation_name(self):
4335    return "Dataset.map()"
4336
4337
4338class ParallelMapDataset(UnaryDataset):
4339  """A `Dataset` that maps a function over elements in its input in parallel."""
4340
4341  def __init__(self,
4342               input_dataset,
4343               map_func,
4344               num_parallel_calls,
4345               deterministic,
4346               use_inter_op_parallelism=True,
4347               preserve_cardinality=False,
4348               use_legacy_function=False):
4349    """See `Dataset.map()` for details."""
4350    self._input_dataset = input_dataset
4351    self._use_inter_op_parallelism = use_inter_op_parallelism
4352    self._map_func = StructuredFunctionWrapper(
4353        map_func,
4354        self._transformation_name(),
4355        dataset=input_dataset,
4356        use_legacy_function=use_legacy_function)
4357    if deterministic is None:
4358      self._deterministic = "default"
4359    elif deterministic:
4360      self._deterministic = "true"
4361    else:
4362      self._deterministic = "false"
4363    self._preserve_cardinality = preserve_cardinality
4364    self._num_parallel_calls = ops.convert_to_tensor(
4365        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
4366    variant_tensor = gen_dataset_ops.parallel_map_dataset_v2(
4367        input_dataset._variant_tensor,  # pylint: disable=protected-access
4368        self._map_func.function.captured_inputs,
4369        f=self._map_func.function,
4370        num_parallel_calls=self._num_parallel_calls,
4371        deterministic=self._deterministic,
4372        use_inter_op_parallelism=self._use_inter_op_parallelism,
4373        preserve_cardinality=self._preserve_cardinality,
4374        **self._flat_structure)
4375    super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor)
4376
4377  def _functions(self):
4378    return [self._map_func]
4379
4380  @property
4381  def element_spec(self):
4382    return self._map_func.output_structure
4383
4384  def _transformation_name(self):
4385    return "Dataset.map()"
4386
4387
4388class FlatMapDataset(UnaryDataset):
4389  """A `Dataset` that maps a function over its input and flattens the result."""
4390
4391  def __init__(self, input_dataset, map_func):
4392    """See `Dataset.flat_map()` for details."""
4393    self._input_dataset = input_dataset
4394    self._map_func = StructuredFunctionWrapper(
4395        map_func, self._transformation_name(), dataset=input_dataset)
4396    if not isinstance(self._map_func.output_structure, DatasetSpec):
4397      raise TypeError(
4398          "`map_func` must return a `Dataset` object. Got {}".format(
4399              type(self._map_func.output_structure)))
4400    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
4401    variant_tensor = gen_dataset_ops.flat_map_dataset(
4402        input_dataset._variant_tensor,  # pylint: disable=protected-access
4403        self._map_func.function.captured_inputs,
4404        f=self._map_func.function,
4405        **self._flat_structure)
4406    super(FlatMapDataset, self).__init__(input_dataset, variant_tensor)
4407
4408  def _functions(self):
4409    return [self._map_func]
4410
4411  @property
4412  def element_spec(self):
4413    return self._structure
4414
4415  def _transformation_name(self):
4416    return "Dataset.flat_map()"
4417
4418
4419class InterleaveDataset(UnaryDataset):
4420  """A `Dataset` that interleaves the result of transformed inputs."""
4421
4422  def __init__(self, input_dataset, map_func, cycle_length, block_length):
4423    """See `Dataset.interleave()` for details."""
4424
4425    self._input_dataset = input_dataset
4426    self._map_func = StructuredFunctionWrapper(
4427        map_func, self._transformation_name(), dataset=input_dataset)
4428    if not isinstance(self._map_func.output_structure, DatasetSpec):
4429      raise TypeError(
4430          "`map_func` must return a `Dataset` object. Got {}".format(
4431              type(self._map_func.output_structure)))
4432    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
4433    self._cycle_length = ops.convert_to_tensor(
4434        cycle_length, dtype=dtypes.int64, name="cycle_length")
4435    self._block_length = ops.convert_to_tensor(
4436        block_length, dtype=dtypes.int64, name="block_length")
4437
4438    variant_tensor = gen_dataset_ops.interleave_dataset(
4439        input_dataset._variant_tensor,  # pylint: disable=protected-access
4440        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
4441        self._cycle_length,
4442        self._block_length,
4443        f=self._map_func.function,
4444        **self._flat_structure)
4445    super(InterleaveDataset, self).__init__(input_dataset, variant_tensor)
4446
4447  def _functions(self):
4448    return [self._map_func]
4449
4450  @property
4451  def element_spec(self):
4452    return self._structure
4453
4454  def _transformation_name(self):
4455    return "Dataset.interleave()"
4456
4457
4458class ParallelInterleaveDataset(UnaryDataset):
4459  """A `Dataset` that maps a function over its input and interleaves the result."""
4460
4461  def __init__(self,
4462               input_dataset,
4463               map_func,
4464               cycle_length,
4465               block_length,
4466               num_parallel_calls,
4467               buffer_output_elements=AUTOTUNE,
4468               prefetch_input_elements=AUTOTUNE,
4469               deterministic=None):
4470    """See `Dataset.interleave()` for details."""
4471    self._input_dataset = input_dataset
4472    self._map_func = StructuredFunctionWrapper(
4473        map_func, self._transformation_name(), dataset=input_dataset)
4474    if not isinstance(self._map_func.output_structure, DatasetSpec):
4475      raise TypeError(
4476          "`map_func` must return a `Dataset` object. Got {}".format(
4477              type(self._map_func.output_structure)))
4478    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
4479    self._cycle_length = ops.convert_to_tensor(
4480        cycle_length, dtype=dtypes.int64, name="cycle_length")
4481    self._block_length = ops.convert_to_tensor(
4482        block_length, dtype=dtypes.int64, name="block_length")
4483    self._buffer_output_elements = ops.convert_to_tensor(
4484        buffer_output_elements,
4485        dtype=dtypes.int64,
4486        name="buffer_output_elements")
4487    self._prefetch_input_elements = ops.convert_to_tensor(
4488        prefetch_input_elements,
4489        dtype=dtypes.int64,
4490        name="prefetch_input_elements")
4491
4492    self._num_parallel_calls = ops.convert_to_tensor(
4493        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
4494    if deterministic is None:
4495      deterministic_string = "default"
4496    elif deterministic:
4497      deterministic_string = "true"
4498    else:
4499      deterministic_string = "false"
4500
4501    variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
4502        input_dataset._variant_tensor,  # pylint: disable=protected-access
4503        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
4504        self._cycle_length,
4505        self._block_length,
4506        self._buffer_output_elements,
4507        self._prefetch_input_elements,
4508        self._num_parallel_calls,
4509        f=self._map_func.function,
4510        deterministic=deterministic_string,
4511        **self._flat_structure)
4512    super(ParallelInterleaveDataset, self).__init__(input_dataset,
4513                                                    variant_tensor)
4514
4515  def _functions(self):
4516    return [self._map_func]
4517
4518  @property
4519  def element_spec(self):
4520    return self._structure
4521
4522  def _transformation_name(self):
4523    return "Dataset.interleave()"
4524
4525
4526class FilterDataset(UnaryUnchangedStructureDataset):
4527  """A `Dataset` that filters its input according to a predicate function."""
4528
4529  def __init__(self, input_dataset, predicate, use_legacy_function=False):
4530    """See `Dataset.filter()` for details."""
4531    self._input_dataset = input_dataset
4532    wrapped_func = StructuredFunctionWrapper(
4533        predicate,
4534        self._transformation_name(),
4535        dataset=input_dataset,
4536        use_legacy_function=use_legacy_function)
4537    if not wrapped_func.output_structure.is_compatible_with(
4538        tensor_spec.TensorSpec([], dtypes.bool)):
4539      error_msg = ("`predicate` return type must be convertible to a scalar "
4540                   "boolean tensor. Was {}.").format(
4541                       wrapped_func.output_structure)
4542      raise ValueError(error_msg)
4543    self._predicate = wrapped_func
4544    variant_tensor = gen_dataset_ops.filter_dataset(
4545        input_dataset._variant_tensor,  # pylint: disable=protected-access
4546        other_arguments=self._predicate.function.captured_inputs,
4547        predicate=self._predicate.function,
4548        **self._flat_structure)
4549    super(FilterDataset, self).__init__(input_dataset, variant_tensor)
4550
4551  def _functions(self):
4552    return [self._predicate]
4553
4554  def _transformation_name(self):
4555    return "Dataset.filter()"
4556
4557
4558class PrefetchDataset(UnaryUnchangedStructureDataset):
4559  """A `Dataset` that asynchronously prefetches its input."""
4560
4561  def __init__(self, input_dataset, buffer_size, slack_period=None):
4562    """See `Dataset.prefetch()` for details.
4563
4564    Args:
4565      input_dataset: The input dataset.
4566      buffer_size: See `Dataset.prefetch()` for details.
4567      slack_period: (Optional.) An integer. If non-zero, determines the number
4568        of GetNext calls before injecting slack into the execution. This may
4569        reduce CPU contention at the start of a step. Note that a tensorflow
4570        user should not have to set this manually; enable this behavior
4571        automatically via `tf.data.Options.experimental_slack` instead. Defaults
4572        to None.
4573    """
4574    self._input_dataset = input_dataset
4575    if buffer_size is None:
4576      buffer_size = AUTOTUNE
4577    self._buffer_size = ops.convert_to_tensor(
4578        buffer_size, dtype=dtypes.int64, name="buffer_size")
4579    # pylint: disable=protected-access
4580    # We colocate the prefetch dataset with its input as this collocation only
4581    # happens automatically in graph mode.
4582    with ops.colocate_with(input_dataset._variant_tensor):
4583      variant_tensor = gen_dataset_ops.prefetch_dataset(
4584          input_dataset._variant_tensor,
4585          buffer_size=self._buffer_size,
4586          slack_period=slack_period,
4587          **self._flat_structure)
4588    super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)
4589
4590
4591class WindowDataset(UnaryDataset):
4592  """A dataset that creates window datasets from the input elements."""
4593
4594  def __init__(self, input_dataset, size, shift, stride, drop_remainder):
4595    """See `window_dataset()` for more details."""
4596    self._input_dataset = input_dataset
4597    self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
4598    self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
4599    self._stride = ops.convert_to_tensor(
4600        stride, dtype=dtypes.int64, name="stride")
4601    self._drop_remainder = ops.convert_to_tensor(
4602        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
4603    self._structure = nest.pack_sequence_as(
4604        get_legacy_output_classes(input_dataset), [
4605            DatasetSpec(  # pylint: disable=g-complex-comprehension
4606                structure.convert_legacy_structure(
4607                    output_type, output_shape, output_class))
4608            for output_class, output_shape, output_type in zip(
4609                nest.flatten(get_legacy_output_classes(input_dataset)),
4610                nest.flatten(get_legacy_output_shapes(input_dataset)),
4611                nest.flatten(get_legacy_output_types(input_dataset)))
4612        ])
4613    variant_tensor = gen_dataset_ops.window_dataset(
4614        input_dataset._variant_tensor,  # pylint: disable=protected-access
4615        self._size,
4616        self._shift,
4617        self._stride,
4618        self._drop_remainder,
4619        **self._flat_structure)
4620    super(WindowDataset, self).__init__(input_dataset, variant_tensor)
4621
4622  @property
4623  def element_spec(self):
4624    return self._structure
4625
4626
4627class _OptionsDataset(UnaryUnchangedStructureDataset):
4628  """An identity `Dataset` that stores options."""
4629
4630  def __init__(self, input_dataset, options):
4631    self._input_dataset = input_dataset
4632    variant_tensor = input_dataset._variant_tensor  # pylint: disable=protected-access
4633    super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
4634
4635    if self._options_attr:
4636      self._options_attr = self._options_attr.merge(options)
4637    else:
4638      self._options_attr = options
4639
4640  def options(self):
4641    return self._options_attr
4642
4643
4644class _ModelDataset(UnaryUnchangedStructureDataset):
4645  """A `Dataset` that acts as an identity, and models performance."""
4646
4647  def __init__(self, input_dataset, algorithm, cpu_budget, ram_budget):
4648    self._input_dataset = input_dataset
4649    variant_tensor = gen_dataset_ops.model_dataset(
4650        input_dataset._variant_tensor,  # pylint: disable=protected-access
4651        algorithm=algorithm.value,
4652        cpu_budget=cpu_budget,
4653        ram_budget=ram_budget,
4654        **self._flat_structure)
4655    super(_ModelDataset, self).__init__(input_dataset, variant_tensor)
4656
4657
4658class _OptimizeDataset(UnaryUnchangedStructureDataset):
4659  """A `Dataset` that acts as an identity, and applies optimizations."""
4660
4661  def __init__(self,
4662               input_dataset,
4663               optimizations_enabled,
4664               optimizations_disabled,
4665               optimizations_default,
4666               optimization_configs=None):
4667    self._input_dataset = input_dataset
4668    if optimization_configs is None:
4669      optimization_configs = []
4670
4671    # We sort the options here before embedding as constant tensors to ensure
4672    # that serialization to NodeDef is determinstic.
4673    if optimizations_enabled:
4674      optimizations_enabled.sort()
4675    if optimizations_disabled:
4676      optimizations_disabled.sort()
4677    if optimizations_default:
4678      optimizations_default.sort()
4679
4680    self._optimizations_enabled = convert.optional_param_to_tensor(
4681        argument_name="optimizations_enabled",
4682        argument_value=optimizations_enabled,
4683        argument_default=[],
4684        argument_dtype=dtypes.string)
4685    self._optimizations_disabled = convert.optional_param_to_tensor(
4686        argument_name="optimizations_disabled",
4687        argument_value=optimizations_disabled,
4688        argument_default=[],
4689        argument_dtype=dtypes.string)
4690    self._optimizations_default = convert.optional_param_to_tensor(
4691        argument_name="optimizations_default",
4692        argument_value=optimizations_default,
4693        argument_default=[],
4694        argument_dtype=dtypes.string)
4695
4696    variant_tensor = gen_dataset_ops.optimize_dataset_v2(
4697        input_dataset._variant_tensor,  # pylint: disable=protected-access
4698        self._optimizations_enabled,
4699        self._optimizations_disabled,
4700        self._optimizations_default,
4701        optimization_configs=optimization_configs,
4702        **self._flat_structure)
4703
4704    super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor)
4705
4706
4707class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
4708  """A `Dataset` that acts as an identity, and sets a stats aggregator."""
4709
4710  def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
4711    self._input_dataset = input_dataset
4712    self._stats_aggregator = aggregator
4713    self._prefix = prefix
4714    self._counter_prefix = counter_prefix
4715    variant_tensor = ged_ops.set_stats_aggregator_dataset(
4716        input_dataset._variant_tensor,  # pylint: disable=protected-access
4717        self._stats_aggregator._resource,  # pylint: disable=protected-access
4718        self._prefix,
4719        self._counter_prefix,
4720        **self._flat_structure)
4721    super(_SetStatsAggregatorDataset, self).__init__(input_dataset,
4722                                                     variant_tensor)
4723
4724
4725class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
4726  """A `Dataset` that acts as an identity, overriding intra-op parallelism."""
4727
4728  def __init__(self, input_dataset, max_intra_op_parallelism):
4729    self._input_dataset = input_dataset
4730    self._max_intra_op_parallelism = ops.convert_to_tensor(
4731        max_intra_op_parallelism,
4732        dtype=dtypes.int64,
4733        name="max_intra_op_parallelism")
4734    variant_tensor = ged_ops.max_intra_op_parallelism_dataset(
4735        input_dataset._variant_tensor,  # pylint: disable=protected-access
4736        self._max_intra_op_parallelism,
4737        **self._flat_structure)
4738    super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset,
4739                                                        variant_tensor)
4740
4741
4742class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
4743  """A `Dataset` that acts as an identity, setting a private threadpool."""
4744
4745  def __init__(self, input_dataset, num_threads):
4746    self._input_dataset = input_dataset
4747    self._num_threads = ops.convert_to_tensor(
4748        num_threads, dtype=dtypes.int64, name="num_threads")
4749    variant_tensor = ged_ops.private_thread_pool_dataset(
4750        input_dataset._variant_tensor,  # pylint: disable=protected-access
4751        self._num_threads,
4752        **self._flat_structure)
4753    super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
4754                                                    variant_tensor)
4755
4756
4757def normalize_to_dense(dataset):
4758  """Normalizes non-tensor components in a dataset to dense representations.
4759
4760  This is necessary for dataset transformations that slice along the batch
4761  dimension and are oblivious to non-tensors, e.g. `unbatch`, `rebatch`.
4762
4763  Args:
4764    dataset: Dataset to normalize.
4765
4766  Returns:
4767    A dataset whose sparse and ragged tensors have been normalized to their
4768    dense representations.
4769  """
4770
4771  # NOTE(mrry): This leads to a somewhat inefficient re-encoding step for all
4772  # non-tensor components.
4773  #
4774  # TODO(mrry): Consider optimizing this if it turns out to be a bottleneck.
4775  if _should_unpack_args(dataset.element_spec):
4776    def normalize(*args):
4777      return structure.to_batched_tensor_list(dataset.element_spec, tuple(args))
4778  else:
4779    def normalize(arg):
4780      return structure.to_batched_tensor_list(dataset.element_spec, arg)
4781
4782  normalized_dataset = dataset.map(normalize)
4783
4784  # NOTE(mrry): Our `map()` has lost information about the structure of
4785  # non-tensor components, so re-apply the structure of the original dataset.
4786  return _RestructuredDataset(normalized_dataset, dataset.element_spec)
4787
4788
4789class _RestructuredDataset(UnaryDataset):
4790  """An internal helper for changing the structure and shape of a dataset."""
4791
4792  def __init__(self, dataset, structure):
4793    self._input_dataset = dataset
4794    self._structure = structure
4795
4796    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
4797    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
4798
4799  @property
4800  def element_spec(self):
4801    return self._structure
4802
4803
4804class _UnbatchDataset(UnaryDataset):
4805  """A dataset that splits the elements of its input into multiple elements."""
4806
4807  def __init__(self, input_dataset):
4808    """See `unbatch()` for more details."""
4809    flat_shapes = input_dataset._flat_shapes  # pylint: disable=protected-access
4810    if any(s.ndims == 0 for s in flat_shapes):
4811      raise ValueError("Cannot unbatch an input with scalar components.")
4812    known_batch_dim = tensor_shape.Dimension(None)
4813    for s in flat_shapes:
4814      try:
4815        known_batch_dim = known_batch_dim.merge_with(s[0])
4816      except ValueError:
4817        raise ValueError("Cannot unbatch an input whose components have "
4818                         "different batch sizes.")
4819    self._input_dataset = input_dataset
4820    self._structure = nest.map_structure(
4821        lambda component_spec: component_spec._unbatch(),  # pylint: disable=protected-access
4822        get_structure(input_dataset))
4823    variant_tensor = ged_ops.unbatch_dataset(
4824        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
4825        **self._flat_structure)
4826    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
4827
4828  @property
4829  def element_spec(self):
4830    return self._structure
4831
4832
4833def _collect_resource_inputs(op):
4834  """Collects resource inputs for the given ops (and its variant inputs)."""
4835
4836  def _process(op_queue, seen_ops):
4837    """Processes the next element of the op queue.
4838
4839    Args:
4840      op_queue: Queue of Dataset operations to process.
4841      seen_ops: Already processed set of Operations.
4842
4843    Returns:
4844      A 2-tuple containing sets of resource handles. The first tuple entry
4845      contains read-only handles and the second entry contains read-write
4846      handles.
4847    """
4848
4849    reads = []
4850    writes = []
4851    op = op_queue.pop()
4852    if op in seen_ops:
4853      return reads, writes
4854    seen_ops.add(op)
4855    # TODO(b/150139257): All resource inputs are in writes right now since we
4856    # have not updated the functional ops to set the special attribute that ACD
4857    # uses to figure out which of the op's inputs are read-only.
4858    reads, writes = acd_utils.get_read_write_resource_inputs(op)
4859    # Conservatively assume that any variant inputs are datasets.
4860    op_queue.extend(t.op for t in op.inputs if t.dtype == dtypes.variant)
4861    return reads, writes
4862
4863  op_queue = [op]
4864  seen_ops = set()
4865  all_reads = []
4866  all_writes = []
4867  while op_queue:
4868    reads, writes = _process(op_queue, seen_ops)
4869    all_reads.extend(reads)
4870    all_writes.extend(writes)
4871
4872  return all_reads, all_writes
4873
4874
4875@auto_control_deps.register_acd_resource_resolver
4876def _resource_resolver(op, resource_reads, resource_writes):
4877  """Updates resource inputs for tf.data ops with indirect dependencies."""
4878
4879  updated = False
4880  if op.type in [
4881      "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset"
4882  ]:
4883    reads, writes = _collect_resource_inputs(op)
4884    for inp in reads:
4885      if inp not in resource_reads:
4886        updated = True
4887        resource_reads.add(inp)
4888    for inp in writes:
4889      if inp not in resource_writes:
4890        updated = True
4891        resource_writes.add(inp)
4892
4893  if op.type in [
4894      "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional"
4895  ]:
4896    iterator_resource = op.inputs[0]
4897    make_iterator_ops = [
4898        op for op in iterator_resource.consumers() if op.type == "MakeIterator"
4899    ]
4900
4901    if len(make_iterator_ops) == 1:
4902      reads, writes = _collect_resource_inputs(make_iterator_ops[0])
4903      for inp in reads:
4904        if inp not in resource_reads:
4905          updated = True
4906          resource_reads.add(inp)
4907      for inp in writes:
4908        if inp not in resource_writes:
4909          updated = True
4910          resource_writes.add(inp)
4911
4912  return updated
4913