# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adapter module that converts input data objects into a `tf.data.Dataset`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib
import functools
import itertools
import math
import random

import numpy as np
import six

from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import dataset_creator
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

keras_data_adapter_gauge = monitoring.BoolGauge(
    "/tensorflow/api/keras/data_adapters", "keras data adapter usage", "method")

try:
  from scipy import sparse as scipy_sparse  # pylint: disable=g-import-not-at-top
except ImportError:
  scipy_sparse = None
try:
  import pandas as pd  # pylint: disable=g-import-not-at-top
except ImportError:
  pd = None


@six.add_metaclass(abc.ABCMeta)
class DataAdapter(object):
  """Base class for input data adapters.

  In TF 2.0, tf.data is the preferred API for users to feed in data. In order
  to simplify the training code path, all input data objects will be
  converted to `tf.data.Dataset` if possible.

  Note that since this class mainly targets TF 2.0, it makes a number of
  assumptions under the hood, e.g. an eager context by default, distribution
  strategy awareness, etc. Some legacy feature support might be dropped, e.g.
  the Iterator from the v1 dataset API.

  Sample usage of this class:

  ```
  x = tf.data.Dataset.range(100)
  adapter_cls = [NumpyArrayDataAdapter, ..., DatasetAdapter]
  applicable_adapters = [cls for cls in adapter_cls if cls.can_handle(x)]
  if len(applicable_adapters) != 1:
    raise ValueError("Expect only one adapter class to handle the input")

  dataset = applicable_adapters[0](x).get_dataset()
  for data in dataset:
    # training
  ```
  """

  @staticmethod
  def can_handle(x, y=None):
    """Whether the current DataAdapter could handle the input x and y.

    Structure-wise, x and y can be a single object, a list of objects if there
    are multiple inputs/outputs, or a dictionary of objects when the
    inputs/outputs are named.

    Args:
      x: input features.
      y: target labels. Note that y could be None in the case of prediction.

    Returns:
      boolean
    """
    raise NotImplementedError

  @abc.abstractmethod
  def __init__(self, x, y=None, **kwargs):
    """Create a DataAdapter based on data inputs.

    The caller must make sure to call `can_handle()` before invoking this
    method. Providing an unsupported data type will result in unexpected
    behavior.

    Args:
      x: input features.
      y: target labels. Note that y could be None in the case of prediction.
      **kwargs: Other keyword arguments for DataAdapter during the construction
        of the `tf.data.Dataset`. For example:
        - Numpy data might have `sample_weights` which will be used for
          weighting the loss function during training.
        - Numpy data might need a `batch_size` parameter when constructing
          the dataset and iterator.
        - Certain inputs might need to be distribution-strategy aware. When
          `distribution_strategy` is passed, the created dataset needs to
          respect the strategy.
        A DataAdapter might choose to ignore any keyword argument it doesn't
        use, or raise an exception if a required argument is not provided.
    """
    if not self.can_handle(x, y):
      raise ValueError("{} Cannot handle input {}, {}".format(
          self.__class__, x, y))

  @abc.abstractmethod
  def get_dataset(self):
    """Get a dataset instance for the current DataAdapter.

    Note that the returned dataset does not repeat across epochs, so the
    caller might need to create a new iterator for the same dataset at the
    beginning of each epoch. This behavior might change in the future.

    Returns:
      A `tf.data.Dataset`. The caller might use the dataset in different
      contexts, e.g. `iter(dataset)` in eager mode to get the values directly,
      or, in graph mode, provide the iterator tensor to the Keras model
      function.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def get_size(self):
    """Return the size (number of batches) for the dataset created.

    For certain types of input data, the number of batches is known, e.g. for
    Numpy data, the size is the same as (number_of_elements / batch_size).
    Whereas for a dataset or a Python generator, the size is unknown since it
    may or may not have an end state.

    Returns:
      int, the number of batches for the dataset, or None if it is unknown.
      The caller can use this to control the training loop, show a progress
      bar, or handle an unexpected StopIteration error.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def batch_size(self):
    """Return the batch size of the dataset created.

    For certain types of input data, the batch size is known, and even
    required, e.g. for a Numpy array. Whereas for a dataset, the batch size is
    unknown unless we take a peek.

    Returns:
      int, the batch size of the dataset, or None if it is unknown.
    """
    raise NotImplementedError

  def representative_batch_size(self):
    """Return a representative size for batches in the dataset.

    This is not guaranteed to be the batch size for all batches in the
    dataset. It just needs to be a rough approximation for batch sizes in
    the dataset.

    Returns:
      int, a representative size for batches found in the dataset,
      or None if it is unknown.
    """
    return self.batch_size()

  @abc.abstractmethod
  def has_partial_batch(self):
    """Whether the dataset has a partial batch at the end."""
    raise NotImplementedError

  @abc.abstractmethod
  def partial_batch_size(self):
    """The size of the final partial batch for the dataset.

    Will return None if has_partial_batch is False or batch_size is None.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def should_recreate_iterator(self):
    """Returns whether a new iterator should be created every epoch."""
    raise NotImplementedError

  def get_samples(self):
    """Returns the number of samples in the data, or `None`."""
    if not self.get_size() or not self.batch_size():
      return None
    total_sample = self.get_size() * self.batch_size()
    if self.has_partial_batch():
      total_sample -= (self.batch_size() - self.partial_batch_size())
    return total_sample

  def on_epoch_end(self):
    """A hook called after each epoch."""
    pass


class TensorLikeDataAdapter(DataAdapter):
  """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""

  @staticmethod
  def can_handle(x, y=None):
    # TODO(kaftan): Check performance implications of using a flatten
    #  here for other types of inputs.
    flat_inputs = nest.flatten(x)
    if y is not None:
      flat_inputs += nest.flatten(y)

    tensor_types = (ops.Tensor, np.ndarray)
    if pd:
      tensor_types = (ops.Tensor, np.ndarray, pd.Series, pd.DataFrame)

    def _is_tensor(v):
      if isinstance(v, tensor_types):
        return True
      return False

    return all(_is_tensor(v) for v in flat_inputs)

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               sample_weight_modes=None,
               batch_size=None,
               epochs=1,
               steps=None,
               shuffle=False,
               **kwargs):
    super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs)
    x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
    sample_weight_modes = broadcast_sample_weight_modes(
        sample_weights, sample_weight_modes)

    # If sample_weights are not specified for an output use 1.0 as weights.
    (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
        y, sample_weights, sample_weight_modes, check_all_flat=True)

    inputs = pack_x_y_sample_weight(x, y, sample_weights)

    num_samples = set(int(i.shape[0]) for i in nest.flatten(inputs)).pop()
    _check_data_cardinality(inputs)

    # If batch_size is not passed but steps is, calculate from the input data.
    # Default to 32 for backwards compat.
    if not batch_size:
      batch_size = int(math.ceil(num_samples / steps)) if steps else 32

    self._size = int(math.ceil(num_samples / batch_size))
    self._batch_size = batch_size

    num_full_batches = int(num_samples // batch_size)
    self._partial_batch_size = num_samples % batch_size

    if isinstance(shuffle, str):
      shuffle = shuffle.lower()

    self._shuffle = shuffle
    # Vectorized version of shuffle.
    # This is a performance improvement over using `from_tensor_slices`.
    # The indices of the data are shuffled and batched, and these indices
    # are then zipped with the data and used to extract a batch of the data
    # at each step. The performance improvements here come from:
    # 1. vectorized batch using gather
    # 2. parallelized map
    # 3. pipelined permutation generation
    # 4. optimized permutation batching
    # 5. disabled static optimizations

    indices_dataset = dataset_ops.DatasetV2.range(1)
    if shuffle != "batch":
      indices_dataset = indices_dataset.repeat(epochs)

    def permutation(_):
      # It turns out to be more performant to make a new set of indices rather
      # than reusing the same range Tensor. (presumably because of buffer
      # forwarding.)
      indices = math_ops.range(num_samples, dtype=dtypes.int64)
      if shuffle and shuffle != "batch":
        indices = random_ops.random_shuffle(indices)
      return indices

    # We prefetch a single element. Computing large permutations can take quite
    # a while so we don't want to wait for prefetching over an epoch boundary to
    # trigger the next permutation. On the other hand, too many simultaneous
    # shuffles can contend on a hardware level and degrade all performance.
    indices_dataset = indices_dataset.map(permutation).prefetch(1)

    def slice_batch_indices(indices):
      """Convert a Tensor of indices into a dataset of batched indices.

      This step can be accomplished in several ways. The most natural is to
      slice the Tensor in a Dataset map. (With a condition on the upper index
      to handle the partial batch.) However, it turns out that coercing the
      Tensor into a shape which is divisible by the batch size (and handling
      the last partial batch separately) allows for a much more favorable
      memory access pattern and improved performance.

      Args:
        indices: Tensor which determines the data order for an entire epoch.

      Returns:
        A Dataset of batched indices.
      """
      num_in_full_batch = num_full_batches * batch_size
      first_k_indices = array_ops.slice(indices, [0], [num_in_full_batch])
      first_k_indices = array_ops.reshape(
          first_k_indices, [num_full_batches, batch_size])

      flat_dataset = dataset_ops.DatasetV2.from_tensor_slices(first_k_indices)
      if self._partial_batch_size:
        index_remainder = dataset_ops.DatasetV2.from_tensors(array_ops.slice(
            indices, [num_in_full_batch], [self._partial_batch_size]))
        flat_dataset = flat_dataset.concatenate(index_remainder)

      if shuffle == "batch":
        # 1024 is a magic constant that has not been properly evaluated
        flat_dataset = flat_dataset.shuffle(1024).repeat(epochs)
      return flat_dataset

    indices_dataset = indices_dataset.flat_map(slice_batch_indices)

    dataset = self.slice_inputs(indices_dataset, inputs)

    if shuffle == "batch":
      def shuffle_batch(*batch):
        return nest.map_structure(random_ops.random_shuffle, batch)
      dataset = dataset.map(shuffle_batch)

    self._dataset = dataset

  def slice_inputs(self, indices_dataset, inputs):
    """Slice inputs into a Dataset of batches.

    Given a Dataset of batch indices and the unsliced inputs,
    this step slices the inputs in a parallelized fashion
    and produces a dataset of input batches.

    Args:
      indices_dataset: A Dataset of batched indices
      inputs: A python data structure that contains the inputs, targets,
        and possibly sample weights.

    Returns:
      A Dataset of input batches matching the batch indices.
    """
    dataset = dataset_ops.DatasetV2.zip((
        indices_dataset,
        dataset_ops.DatasetV2.from_tensors(inputs).repeat()
    ))

    def grab_batch(i, data):
      return nest.map_structure(lambda d: array_ops.gather(d, i, axis=0), data)

    dataset = dataset.map(
        grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE)

    # Default optimizations are disabled to avoid the overhead of (unnecessary)
    # input pipeline graph serialization and deserialization.
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    if self._shuffle:
      # See b/141490660 for more details.
      options.experimental_external_state_policy = (
          distribute_options.ExternalStatePolicy.IGNORE)
    dataset = dataset.with_options(options)
    return dataset

  def get_dataset(self):
    return self._dataset

  def get_size(self):
    return self._size

  def batch_size(self):
    return self._batch_size

  def has_partial_batch(self):
    return self._partial_batch_size > 0

  def partial_batch_size(self):
    return self._partial_batch_size or None

  def should_recreate_iterator(self):
    # An infinite dataset is always created here.
    return False
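

# The following is a minimal, illustrative sketch of the vectorized-gather
# batching pattern used by `TensorLikeDataAdapter` above, written against the
# public `tf.data` API. It is not used by this module; `features` and
# `batch_size` are placeholder names.
def _example_vectorized_gather_batching(features, batch_size):
  """Illustrative only: batch rows by gathering with batched index tensors."""
  import tensorflow as tf  # Assumed available; deferred to avoid import cycles.
  num_samples = int(features.shape[0])
  num_full_batches = num_samples // batch_size
  # Shuffle all row indices once, then reshape the first
  # `num_full_batches * batch_size` of them into `[num_batches, batch_size]`.
  indices = tf.random.shuffle(tf.range(num_samples, dtype=tf.int64))
  batched_indices = tf.reshape(
      indices[:num_full_batches * batch_size], [num_full_batches, batch_size])
  index_ds = tf.data.Dataset.from_tensor_slices(batched_indices)
  data_ds = tf.data.Dataset.from_tensors(features).repeat()
  # Each step gathers one batch of rows, mirroring `slice_inputs` above.
  return tf.data.Dataset.zip((index_ds, data_ds)).map(
      lambda i, d: tf.gather(d, i, axis=0))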


class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
  """Adapter that handles array-like data without forcing it into memory.

  This adapter handles array-like datasets that may be too big to fully
  fit into memory.

  Specifically, this adapter handles any Python class which implements:
  `__getitem__`, `__len__`, `shape`, and `dtype` with the same meanings
  as Numpy, but it ignores any case where all the inputs are Tensors or Numpy
  arrays (because that case is handled by the base TensorLikeDataAdapter).

  It ignores scipy sparse matrices and Composite Tensors because those are
  handled by the CompositeTensorDataAdapter.

  It also does not handle lists/tuples of scalars, because those are handled
  by the ListsOfScalarsDataAdapter.
  """

  @staticmethod
  def can_handle(x, y=None):
    flat_inputs = nest.flatten(x)
    if y is not None:
      flat_inputs += nest.flatten(y)

    def _is_array_like(v):
      """Return True if v is a Tensor, array, or is array-like."""
      return (
          hasattr(v, "__getitem__") and
          hasattr(v, "shape") and
          hasattr(v, "dtype") and
          hasattr(v, "__len__")
      )

    if (not TensorLikeDataAdapter.can_handle(x, y) and
        not CompositeTensorDataAdapter.can_handle(x, y)):
      return all(_is_array_like(v) for v in flat_inputs)
    else:
      return False

  def __init__(self, *args, **kwargs):
    logging.warn(
        "Keras is training/fitting/evaluating on array-like data. Keras may "
        "not be optimized for this format, so if your input data format is "
        "supported by TensorFlow I/O (https://github.com/tensorflow/io) we "
        "recommend using that to load a Dataset instead.")

    super(GenericArrayLikeDataAdapter, self).__init__(*args, **kwargs)

  def slice_inputs(self, indices_dataset, inputs):
    """Slice inputs into a Dataset of batches.

    Given a Dataset of batch indices and the unsliced inputs,
    this step slices the inputs in a parallelized fashion
    and produces a dataset of input batches.

    Args:
      indices_dataset: A Dataset of batched indices
      inputs: A python data structure that contains the inputs, targets,
        and possibly sample weights.

    Returns:
      A Dataset of input batches matching the batch indices.
    """
    flat_inputs = nest.flatten(inputs)
    def dynamic_shape_like(t):
      shape = list(t.shape)
      shape[0] = None
      return tuple(shape)

    flat_dtypes = [inp.dtype for inp in flat_inputs]
    contiguous = True
    if self._shuffle and self._shuffle != "batch":
      contiguous = False

    def grab_batch(indices):
      """Grab a batch of data from the inputs."""
      # This uses a py_function to avoid converting the array-like
      # into a Tensor before slicing it, because converting the array-like
      # to a Tensor may force it into memory.
      def py_method(ind):
        def slice_array(data):
          return training_utils.slice_arrays(data, ind.numpy(),
                                             contiguous=contiguous)
        return [slice_array(inp) for inp in flat_inputs]

      flat_out = script_ops.eager_py_func(py_method, [indices], flat_dtypes)
      for v, original_inp in zip(flat_out, flat_inputs):
        v.set_shape(dynamic_shape_like(original_inp))
      return nest.pack_sequence_as(inputs, flat_out)

    dataset = indices_dataset.map(
        grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE)

    return dataset
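

# A minimal, illustrative sketch of the array-like protocol accepted by
# `GenericArrayLikeDataAdapter`: any object exposing `__getitem__`, `__len__`,
# `shape` and `dtype` with NumPy semantics. `_ExampleLazyArray` is hypothetical
# and is not used elsewhere in this module.
class _ExampleLazyArray(object):
  """Illustrative only: array-like wrapper that slices rows on demand."""

  def __init__(self, values):
    self._values = np.asarray(values)
    self.shape = self._values.shape
    self.dtype = self._values.dtype

  def __len__(self):
    return len(self._values)

  def __getitem__(self, key):
    # A real implementation could read `key` lazily, e.g. from disk.
    return self._values[key]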


class DatasetCreatorAdapter(DataAdapter):
  """Adapter that handles dataset functions."""

  def __init__(self, *args, **kwargs):
    super(DatasetCreatorAdapter, self).__init__(*args, **kwargs)

  @staticmethod
  def can_handle(x, y=None):
    if isinstance(x, dataset_creator.DatasetCreator):
      assert y is None
      return True

  def should_recreate_iterator(self):
    # We expect users to shuffle the dataset in their `dataset_fn` supplied to
    # `DatasetCreator`. Since that is a buffered shuffle, we intend to not reset
    # the dataset so the batches that are not shuffled can still be pulled.
    return False

  def get_size(self):
    raise NotImplementedError()

  def get_dataset(self):
    raise NotImplementedError()

  def batch_size(self):
    raise NotImplementedError()

  def has_partial_batch(self):
    raise NotImplementedError()

  def partial_batch_size(self):
    raise NotImplementedError()


class CompositeTensorDataAdapter(DataAdapter):
  """Adapter that handles composite tensors."""

  @staticmethod
  def can_handle(x, y=None):
    flat_inputs = nest.flatten(x)
    if y is not None:
      flat_inputs += nest.flatten(y)

    def _is_composite(v):
      # Dataset/iterator inherits from CompositeTensor but should be handled
      # by DatasetAdapter and GeneratorAdapter.
      if (tf_utils.is_extension_type(v) and
          not isinstance(v, (dataset_ops.DatasetV2,
                             iterator_ops.IteratorBase))):
        return True
      # Support Scipy sparse tensors if scipy is installed
      if scipy_sparse is not None and scipy_sparse.issparse(v):
        return True
      return False

    def _is_tensor_or_composite(v):
      if isinstance(v, (ops.Tensor, np.ndarray)):
        return True
      return _is_composite(v)

    return (any(_is_composite(v) for v in flat_inputs) and
            all(_is_tensor_or_composite(v) for v in flat_inputs))

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               sample_weight_modes=None,
               batch_size=None,
               steps=None,
               shuffle=False,
               **kwargs):
    super(CompositeTensorDataAdapter, self).__init__(x, y, **kwargs)
    x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
    sample_weight_modes = broadcast_sample_weight_modes(
        sample_weights, sample_weight_modes)

    # If sample_weights are not specified for an output use 1.0 as weights.
    (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
        y, sample_weights, sample_weight_modes, check_all_flat=True)

    inputs = pack_x_y_sample_weight(x, y, sample_weights)

    dataset = dataset_ops.DatasetV2.from_tensor_slices(inputs)
    num_samples = int(nest.flatten(x)[0].shape[0])
    if shuffle:
      dataset = dataset.shuffle(num_samples)

    # If batch_size is not passed but steps is, calculate from the input data.
    # Default to 32 for backwards compat.
    if not batch_size:
      batch_size = int(math.ceil(num_samples / steps)) if steps else 32

    dataset = dataset.batch(batch_size)
    self._size = int(math.ceil(num_samples / batch_size))
    self._batch_size = batch_size
    self._has_partial_batch = (self._size != (num_samples // batch_size))

    self._partial_batch_size = None
    if self._has_partial_batch:
      self._partial_batch_size = (
          num_samples - (self._size - 1) * self._batch_size)

    self._dataset = dataset

  def get_dataset(self):
    return self._dataset

  def get_size(self):
    return self._size

  def batch_size(self):
    return self._batch_size

  def has_partial_batch(self):
    return self._has_partial_batch

  def partial_batch_size(self):
    return self._partial_batch_size

  def should_recreate_iterator(self):
    return True


class ListsOfScalarsDataAdapter(DataAdapter):
  """Adapter that handles lists of scalars and lists of lists of scalars."""

  @staticmethod
  def can_handle(x, y=None):
    handles_x = ListsOfScalarsDataAdapter._is_list_of_scalars(x)
    handles_y = True
    if y is not None:
      handles_y = ListsOfScalarsDataAdapter._is_list_of_scalars(y)
    return handles_x and handles_y

  @staticmethod
  def _is_list_of_scalars(inp):
    if isinstance(inp, (float, int, str, bytes, bytearray)):
      return True
    if isinstance(inp, (list, tuple)) and inp:
      return ListsOfScalarsDataAdapter._is_list_of_scalars(inp[0])
    return False

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               sample_weight_modes=None,
               batch_size=None,
               shuffle=False,
               **kwargs):
    super(ListsOfScalarsDataAdapter, self).__init__(x, y, **kwargs)
    x = np.asarray(x)
    if y is not None:
      y = np.asarray(y)
    if sample_weights is not None:
      sample_weights = np.asarray(sample_weights)
    sample_weight_modes = broadcast_sample_weight_modes(
        sample_weights, sample_weight_modes)

    self._internal_adapter = TensorLikeDataAdapter(
        x,
        y=y,
        sample_weights=sample_weights,
        sample_weight_modes=sample_weight_modes,
        batch_size=batch_size,
        shuffle=shuffle,
        **kwargs)

  def get_dataset(self):
    return self._internal_adapter.get_dataset()

  def get_size(self):
    return self._internal_adapter.get_size()

  def batch_size(self):
    return self._internal_adapter.batch_size()

  def has_partial_batch(self):
    return self._internal_adapter.has_partial_batch()

  def partial_batch_size(self):
    return self._internal_adapter.partial_batch_size()

  def should_recreate_iterator(self):
    return True


class DatasetAdapter(DataAdapter):
  """Adapter that handles `tf.data.Dataset`."""

  @staticmethod
  def can_handle(x, y=None):
    return (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) or
            _is_distributed_dataset(x))

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               steps=None,
               **kwargs):
    super(DatasetAdapter, self).__init__(x, y, **kwargs)
    # Note that the dataset instance is immutable, so it's fine to reuse the
    # user-provided dataset.
    self._dataset = x

    # The user-provided steps.
    self._user_steps = steps

    self._validate_args(y, sample_weights, steps)

  def get_dataset(self):
    return self._dataset

  def get_size(self):
    return  # Inferred in `DataHandler`.

  def batch_size(self):
    return None

  def has_partial_batch(self):
    return False

  def partial_batch_size(self):
    return None

  def should_recreate_iterator(self):
    # Since DistributedDatasets have no cardinality, the user must provide
    # all steps that need to be run, calling `.repeat()` as needed.
    if _is_distributed_dataset(self._dataset):
      return False

    # If user doesn't supply `steps`, or if they supply `steps` that
    # exactly equals the size of the `Dataset`, create a new iterator
    # each epoch.
    return (self._user_steps is None or
            cardinality.cardinality(self._dataset).numpy() == self._user_steps)

  def _validate_args(self, y, sample_weights, steps):
    """Validates `__init__` arguments."""
    # Arguments that shouldn't be passed.
    if not is_none_or_empty(y):
      raise ValueError("`y` argument is not supported when using "
                       "dataset as input.")
    if not is_none_or_empty(sample_weights):
      raise ValueError("`sample_weight` argument is not supported when using "
                       "dataset as input.")

    if steps is None:
      if _is_distributed_dataset(self._dataset):
        raise ValueError("When providing a distributed dataset, you must "
                         "specify the number of steps to run.")

      size = cardinality.cardinality(self._dataset).numpy()
      if size == cardinality.INFINITE and steps is None:
        raise ValueError(
            "When providing an infinite dataset, you must specify "
            "the number of steps to run (if you did not intend to "
            "create an infinite dataset, make sure to not call "
            "`repeat()` on the dataset).")


class GeneratorDataAdapter(DataAdapter):
  """Adapter that handles python generators and iterators."""

  @staticmethod
  def can_handle(x, y=None):
    return ((hasattr(x, "__next__") or hasattr(x, "next"))
            and hasattr(x, "__iter__")
            and not isinstance(x, data_utils.Sequence))

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               workers=1,
               use_multiprocessing=False,
               max_queue_size=10,
               model=None,
               **kwargs):
    # Generators should never shuffle as exhausting the generator in order to
    # shuffle the batches is inefficient.
    kwargs.pop("shuffle", None)

    if not is_none_or_empty(y):
      raise ValueError("`y` argument is not supported when using "
                       "python generator as input.")
    if not is_none_or_empty(sample_weights):
      raise ValueError("`sample_weight` argument is not supported when using "
                       "python generator as input.")

    super(GeneratorDataAdapter, self).__init__(x, y, **kwargs)

    # Since we have to know the dtypes the python generator produces when we
    # build the dataset, we have to look at a batch to infer the structure.
    peek, x = self._peek_and_restore(x)
    peek = self._standardize_batch(peek)
    peek = _process_tensorlike(peek)

    # Need to build the Model on concrete input shapes.
    if model is not None and not model.built:
      concrete_x, _, _ = unpack_x_y_sample_weight(peek)
      model.distribute_strategy.run(
          lambda x: model(x, training=False), args=(concrete_x,))

    self._first_batch_size = int(nest.flatten(peek)[0].shape[0])

    def _get_dynamic_shape(t):
      shape = t.shape
      # Unknown number of dimensions, `as_list` cannot be called.
      if shape.rank is None:
        return shape
      return tensor_shape.TensorShape([None for _ in shape.as_list()])

    output_shapes = nest.map_structure(_get_dynamic_shape, peek)
    output_types = nest.map_structure(lambda t: t.dtype, peek)

    # Note that the dataset API takes a callable that creates a generator
    # object, rather than the generator itself, which is why we define a
    # function here.
    generator_fn = self._handle_multiprocessing(x, workers, use_multiprocessing,
                                                max_queue_size)

    def wrapped_generator():
      for data in generator_fn():
        yield self._standardize_batch(data)

    dataset = dataset_ops.DatasetV2.from_generator(
        wrapped_generator, output_types, output_shapes=output_shapes)

    if workers == 1 and not use_multiprocessing:
      dataset = dataset.prefetch(1)

    self._dataset = dataset

  def _standardize_batch(self, data):
    """Standardizes a batch output by a generator."""
    # Removes `None`s.
    x, y, sample_weight = unpack_x_y_sample_weight(data)
    data = pack_x_y_sample_weight(x, y, sample_weight)

    data = nest.list_to_tuple(data)

    def _convert_dtype(t):
      if (isinstance(t, np.ndarray) and issubclass(t.dtype.type, np.floating)):
        return np.array(t, dtype=backend.floatx())
      return t

    data = nest.map_structure(_convert_dtype, data)
    return data

  @staticmethod
  def _peek_and_restore(x):
    peek = next(x)
    return peek, itertools.chain([peek], x)

  def _handle_multiprocessing(self, x, workers, use_multiprocessing,
                              max_queue_size):
    """Create a callable, possibly including an Enqueuer."""
    if workers > 1 or (workers > 0 and use_multiprocessing):
      def generator_fn():
        enqueuer = data_utils.GeneratorEnqueuer(
            x, use_multiprocessing=use_multiprocessing)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        return enqueuer.get()
    else:
      generator_fn = lambda: x
    return generator_fn

  def get_dataset(self):
    return self._dataset

  def get_size(self):
    return None

  def batch_size(self):
    return None

  def representative_batch_size(self):
    return self._first_batch_size

  def has_partial_batch(self):
    return False

  def partial_batch_size(self):
    return

  def should_recreate_iterator(self):
    return False
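

# A minimal, illustrative sketch of a Python generator that
# `GeneratorDataAdapter` can wrap via `Dataset.from_generator`. The function is
# hypothetical and not used elsewhere in this module.
def _example_batch_generator(batch_size=4):
  """Illustrative only: yields `(x, y)` NumPy batches forever."""
  while True:
    x = np.random.uniform(size=(batch_size, 3)).astype(backend.floatx())
    y = np.random.randint(0, 2, size=(batch_size, 1))
    yield x, y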


class KerasSequenceAdapter(GeneratorDataAdapter):
  """Adapter that handles `keras.utils.Sequence`."""

  @staticmethod
  def can_handle(x, y=None):
    return isinstance(x, data_utils.Sequence)

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               shuffle=False,
               workers=1,
               use_multiprocessing=False,
               max_queue_size=10,
               model=None,
               **kwargs):
    if not is_none_or_empty(y):
      raise ValueError("`y` argument is not supported when using "
                       "`keras.utils.Sequence` as input.")
    if not is_none_or_empty(sample_weights):
      raise ValueError("`sample_weight` argument is not supported when using "
                       "`keras.utils.Sequence` as input.")

    self._size = len(x)
    self._shuffle_sequence = shuffle
    self._keras_sequence = x
    self._enqueuer = None
    super(KerasSequenceAdapter, self).__init__(
        x,
        shuffle=False,  # Shuffle is handled in `_handle_multiprocessing`.
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
        model=model,
        **kwargs)

  @staticmethod
  def _peek_and_restore(x):
    return x[0], x

  def _handle_multiprocessing(self, x, workers, use_multiprocessing,
                              max_queue_size):
    if workers > 1 or (workers > 0 and use_multiprocessing):
      def generator_fn():
        self._enqueuer = data_utils.OrderedEnqueuer(
            x, use_multiprocessing=use_multiprocessing,
            shuffle=self._shuffle_sequence)
        self._enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        return self._enqueuer.get()
    else:
      def generator_fn():
        order = range(len(x))
        if self._shuffle_sequence:
          # Match the shuffle convention in OrderedEnqueuer.
          order = list(order)
          random.shuffle(order)

        for i in order:
          yield x[i]

    return generator_fn

  def get_size(self):
    return self._size

  def should_recreate_iterator(self):
    return True

  def on_epoch_end(self):
    if self._enqueuer:
      self._enqueuer.stop()
    self._keras_sequence.on_epoch_end()
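

# A minimal, illustrative sketch of a `keras.utils.Sequence` that
# `KerasSequenceAdapter` can consume. `_ExampleSequence` is hypothetical and
# not used elsewhere in this module.
class _ExampleSequence(data_utils.Sequence):
  """Illustrative only: serves `(x, y)` NumPy batches by batch index."""

  def __init__(self, x, y, batch_size):
    self._x, self._y, self._batch_size = x, y, batch_size

  def __len__(self):
    # Number of batches, including a trailing partial batch.
    return int(math.ceil(len(self._x) / self._batch_size))

  def __getitem__(self, idx):
    start = idx * self._batch_size
    end = start + self._batch_size
    return self._x[start:end], self._y[start:end]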


ALL_ADAPTER_CLS = [
    ListsOfScalarsDataAdapter, TensorLikeDataAdapter,
    GenericArrayLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
    KerasSequenceAdapter, CompositeTensorDataAdapter, DatasetCreatorAdapter
]


def select_data_adapter(x, y):
  """Selects a data adapter that can handle a given x and y."""
  adapter_cls = [cls for cls in ALL_ADAPTER_CLS if cls.can_handle(x, y)]
  if not adapter_cls:
    # TODO(scottzhu): This should be a less implementation-specific error.
    raise ValueError(
        "Failed to find data adapter that can handle "
        "input: {}, {}".format(
            _type_name(x), _type_name(y)))
  elif len(adapter_cls) > 1:
    raise RuntimeError(
        "Data adapters should be mutually exclusive for "
        "handling inputs. Found multiple adapters {} to handle "
        "input: {}, {}".format(
            adapter_cls, _type_name(x), _type_name(y)))
  # Instrument the data adapter usage before returning it.
  keras_data_adapter_gauge.get_cell(adapter_cls[0].__name__).set(True)
  return adapter_cls[0]


def _type_name(x):
  """Generates a description of the type of an object."""
  if isinstance(x, dict):
    key_types = set(_type_name(key) for key in x.keys())
    val_types = set(_type_name(val) for val in x.values())
    return "({} containing {} keys and {} values)".format(
        type(x), key_types, val_types)
  if isinstance(x, (list, tuple)):
    types = set(_type_name(val) for val in x)
    return "({} containing values of types {})".format(
        type(x), types)
  return str(type(x))


def _process_tensorlike(inputs):
  """Process tensor-like inputs.

  This function:

  (1) Converts `Numpy` arrays to `Tensor`s.
  (2) Converts `Scipy` sparse matrices to `SparseTensor`s.
  (3) Converts `list`s to `tuple`s (for `tf.data` support).

  Args:
    inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.

  Returns:
    Structure of `Tensor`s or tensor-like.
  """

  def _convert_numpy_and_scipy(x):
    if isinstance(x, np.ndarray):
      dtype = None
      if issubclass(x.dtype.type, np.floating):
        dtype = backend.floatx()
      return ops.convert_to_tensor_v2_with_dispatch(x, dtype=dtype)
    elif scipy_sparse and scipy_sparse.issparse(x):
      return _scipy_sparse_to_sparse_tensor(x)
    return x

  inputs = nest.map_structure(_convert_numpy_and_scipy, inputs)
  return nest.list_to_tuple(inputs)


def is_none_or_empty(inputs):
  # Util method to check if the input is None or an empty list.
  # A plain Python `not` check would raise an error like the one below if the
  # input is a numpy array:
  # "The truth value of an array with more than one element is ambiguous.
  # Use a.any() or a.all()"
  return inputs is None or not nest.flatten(inputs)


def broadcast_sample_weight_modes(target_structure, sample_weight_modes):
  """Match sample_weight_modes structure with output structure."""
  if target_structure is None or not nest.flatten(target_structure):
    return sample_weight_modes

  if isinstance(sample_weight_modes, str):
    if isinstance(target_structure, dict):
      return {key: sample_weight_modes for key in target_structure.keys()}
    return [sample_weight_modes for _ in target_structure]

  if sample_weight_modes:
    try:
      nest.assert_same_structure(
          training_utils.list_to_tuple(target_structure),
          training_utils.list_to_tuple(sample_weight_modes))
    except (ValueError, TypeError):
      target_str = str(nest.map_structure(lambda _: "...", target_structure))
      mode_str = str(nest.map_structure(lambda _: "...", sample_weight_modes))

      # Attempt to coerce sample_weight_modes to the target structure. This
      # implicitly depends on the fact that Model flattens outputs for its
      # internal representation.
      try:
        sample_weight_modes = nest.pack_sequence_as(
            target_structure, nest.flatten(sample_weight_modes))
        logging.warning(
            "sample_weight modes were coerced from\n  {}\n    to  \n  {}"
            .format(target_str, mode_str))
      except (ValueError, TypeError):
        raise ValueError(
            "Unable to match target structure and sample_weight_modes "
            "structure:\n  {}\n    to  \n  {}".format(target_str, mode_str))

  return sample_weight_modes
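

# An illustrative sketch (hypothetical values) of the broadcasting behavior
# above: a single string mode is expanded to match the structure of the
# targets. The helper function is not used elsewhere in this module.
def _example_broadcast_sample_weight_modes():
  """Illustrative only."""
  targets = {"out_a": np.zeros((4,)), "out_b": np.zeros((4,))}
  modes = broadcast_sample_weight_modes(targets, "temporal")
  return modes  # {"out_a": "temporal", "out_b": "temporal"}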


class DataHandler(object):
  """Handles iterating over epoch-level `tf.data.Iterator` objects."""

  def __init__(self,
               x,
               y=None,
               sample_weight=None,
               batch_size=None,
               steps_per_epoch=None,
               initial_epoch=0,
               epochs=1,
               shuffle=False,
               class_weight=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False,
               model=None,
               steps_per_execution=None,
               distribute=True):
    """Initializes a `DataHandler`.

    Args:
      x: See `Model.fit`.
      y: See `Model.fit`.
      sample_weight: See `Model.fit`.
      batch_size: See `Model.fit`.
      steps_per_epoch: See `Model.fit`.
      initial_epoch: See `Model.fit`.
      epochs: See `Model.fit`.
      shuffle: See `Model.fit`.
      class_weight: See `Model.fit`.
      max_queue_size: See `Model.fit`.
      workers: See `Model.fit`.
      use_multiprocessing: See `Model.fit`.
      model: The `Model` instance. Needed in order to correctly `build` the
        `Model` using generator-like inputs (see `GeneratorDataAdapter`).
      steps_per_execution: See `Model.compile`.
      distribute: Whether to distribute the `tf.data.Dataset`.
        `PreprocessingLayer.adapt` does not support distributed datasets,
        `Model` should always set this to `True`.
    """

    self._initial_epoch = initial_epoch
    self._epochs = epochs
    self._insufficient_data = False
    self._model = model

    # `steps_per_execution_value` is the cached initial value.
    # `steps_per_execution` is mutable and may be changed by the DataAdapter
    # to handle partial executions.
    if steps_per_execution is None:
      self._steps_per_execution = 1
      self._steps_per_execution_value = 1
    else:
      self._steps_per_execution = steps_per_execution
      self._steps_per_execution_value = steps_per_execution.numpy().item()

    adapter_cls = select_data_adapter(x, y)
    self._verify_data_adapter_compatibility(adapter_cls)
    self._adapter = adapter_cls(
        x,
        y,
        batch_size=batch_size,
        steps=steps_per_epoch,
        epochs=epochs - initial_epoch,
        sample_weights=sample_weight,
        shuffle=shuffle,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        distribution_strategy=ds_context.get_strategy(),
        model=model)

    strategy = ds_context.get_strategy()

    self._current_step = 0
    self._step_increment = self._steps_per_execution_value - 1
    self._insufficient_data = False

    self._configure_dataset_and_inferred_steps(strategy, x, steps_per_epoch,
                                               class_weight, distribute)

  def _verify_data_adapter_compatibility(self, adapter_cls):
    if adapter_cls == DatasetCreatorAdapter:
      raise NotImplementedError("`DatasetCreator` input is only supported in "
                                "`ParameterServerStrategy` at this time.")

  def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
                                            class_weight, distribute):
    """Configure the `_dataset` and `_inferred_steps` attributes."""
    del x
    dataset = self._adapter.get_dataset()
    if class_weight:
      dataset = dataset.map(_make_class_weight_map_fn(class_weight))
    self._inferred_steps = self._infer_steps(steps_per_epoch, dataset)

    # `PreprocessingLayer.adapt` does not currently support distributed
    # datasets, so we pass `distribute=False` there.
    if distribute and not _is_distributed_dataset(dataset):
      dataset = strategy.experimental_distribute_dataset(dataset)
    self._dataset = dataset
    self._validate_data_handler()

  def enumerate_epochs(self):
    """Yields `(epoch, tf.data.Iterator)`."""
    with self._truncate_execution_to_epoch():
      data_iterator = iter(self._dataset)
      for epoch in range(self._initial_epoch, self._epochs):
        if self._insufficient_data:  # Set by `catch_stop_iteration`.
          break
        if self._adapter.should_recreate_iterator():
          data_iterator = iter(self._dataset)
        yield epoch, data_iterator
        self._adapter.on_epoch_end()

  @contextlib.contextmanager
  def _truncate_execution_to_epoch(self):
    """Truncates steps per execution to at most one epoch."""
    should_truncate = (
        self._inferred_steps is not None and
        self._steps_per_execution_value > self._inferred_steps)
    original_value = self._steps_per_execution_value
    try:
      if should_truncate:
        self._steps_per_execution.assign(self._inferred_steps)
        self._steps_per_execution_value = self._inferred_steps
      yield
    finally:
      if should_truncate:
        self._steps_per_execution.assign(original_value)
        self._steps_per_execution_value = original_value

  def sync(self):
    context.async_wait()

  @contextlib.contextmanager
  def catch_stop_iteration(self):
    """Catches errors when an iterator runs out of data."""
    try:
      yield
      self.sync()
    except (StopIteration, errors.OutOfRangeError):
      if self._inferred_steps is None:
        self._inferred_steps = self._current_step
      else:
        self._insufficient_data = True
        total_epochs = self._epochs - self._initial_epoch
        logging.warning(
            "Your input ran out of data; interrupting training. "
            "Make sure that your dataset or generator can generate at "
            "least `steps_per_epoch * epochs` batches (in this case, "
            "{} batches). You may need to use the repeat() function "
            "when building your dataset.".format(total_epochs *
                                                 self._inferred_steps))

  def steps(self):
    """Yields steps for the current epoch."""
    self._current_step = 0
    # `self._inferred_steps` can be changed by `catch_stop_iteration`.
    while (self._inferred_steps is None or
           self._current_step < self._inferred_steps):
      if self._insufficient_data:  # Set by `catch_stop_iteration`.
        break

      can_run_full_execution = (
          self._steps_per_execution_value == 1 or
          self._inferred_steps is None or
          self._inferred_steps - self._current_step >=
          self._steps_per_execution_value)

      if can_run_full_execution:
        self._step_increment = self._steps_per_execution_value - 1
        yield self._current_step
        self._current_step += self._steps_per_execution_value
      else:
        # Last partial execution.
        steps_remaining = self._inferred_steps - self._current_step
        self._steps_per_execution.assign(steps_remaining)
        self._step_increment = steps_remaining - 1
        yield self._current_step
        self._current_step += steps_remaining
        self._steps_per_execution.assign(self._steps_per_execution_value)

  @property
  def step_increment(self):
    """The number to increment the step for `on_batch_end` methods."""
    return self._step_increment

  @property
  def inferred_steps(self):
    """The inferred steps per epoch of the created `Dataset`.

    This will be `None` in the case where:

    (1) A `Dataset` of unknown cardinality was passed to the `DataHandler`, and
    (2) `steps_per_epoch` was not provided, and
    (3) The first epoch of iteration has not yet completed.

    Returns:
      The inferred steps per epoch of the created `Dataset`.
    """
    return self._inferred_steps

  @property
  def should_sync(self):
    # Catch OutOfRangeError for Datasets of unknown size.
    # This blocks until the batch has finished executing.
    # TODO(b/150292341): Allow multiple async steps here.
    return self._inferred_steps is None

  def _infer_steps(self, steps, dataset):
    """Infers steps_per_epoch needed to loop through a dataset."""
    if steps is not None:
      return steps

    adapter_steps = self._adapter.get_size()
    if adapter_steps is not None:
      return adapter_steps

    size = cardinality.cardinality(dataset)
    if size == cardinality.INFINITE and steps is None:
      raise ValueError("When passing an infinitely repeating dataset, you "
                       "must specify how many steps to draw.")
    if size >= 0:
      return size.numpy().item()
    return None

  @property
  def _samples(self):
    return self._adapter.get_samples()

  def _validate_data_handler(self):
    # TODO(b/152094471): Support this with DistIter.get_next_as_optional.
    if self._steps_per_execution_value > 1 and self._inferred_steps is None:
      raise ValueError(
          "Could not infer the size of the data. With "
          "`steps_per_execution > 1`, you must specify the number of steps "
          "to run.")

  def resolve_logs(self, logs):
    return logs


class _ClusterCoordinatorDataHandler(DataHandler):
  """A `DataHandler` that is compatible with `ClusterCoordinator`."""

  def _verify_data_adapter_compatibility(self, adapter_cls):
    if adapter_cls != DatasetCreatorAdapter:
      raise NotImplementedError("Only `DatasetCreator` input is supported in "
                                "`ParameterServerStrategy` at this time.")

  def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
                                            class_weight, distribute):
    if not isinstance(x, dataset_creator.DatasetCreator):
      raise TypeError("When using `ParameterServerStrategy`, `x` must be a "
                      "`DatasetCreator`.")

    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(x)

    self._dataset = self._model._cluster_coordinator.create_per_worker_dataset(  # pylint: disable=protected-access
        per_worker_dataset_fn)
    if steps_per_epoch is None:
      raise ValueError(
          "`steps_per_epoch` must be specified with `ParameterServerStrategy`.")
    self._inferred_steps = steps_per_epoch

  def sync(self):
    self._model._cluster_coordinator.join()  # pylint: disable=protected-access

  def resolve_logs(self, logs):
    return logs.fetch()


def get_data_handler(*args, **kwargs):
  if getattr(kwargs["model"], "_cluster_coordinator", None):
    return _ClusterCoordinatorDataHandler(*args, **kwargs)
  return DataHandler(*args, **kwargs)


def _make_class_weight_map_fn(class_weight):
  """Applies class weighting to a `Dataset`.

  The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
  `y` must be a single `Tensor`.

  Args:
    class_weight: A map where the keys are integer class ids and values are
      the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`

  Returns:
    A function that can be used with `tf.data.Dataset.map` to apply class
    weighting.
  """
  class_ids = list(sorted(class_weight.keys()))
  expected_class_ids = list(range(len(class_ids)))
  if class_ids != expected_class_ids:
    error_msg = (
        "Expected `class_weight` to be a dict with keys from 0 to one less "
        "than the number of classes, found {}").format(class_weight)
    raise ValueError(error_msg)

  class_weight_tensor = ops.convert_to_tensor_v2_with_dispatch(
      [class_weight[int(c)] for c in class_ids])

  def _class_weights_map_fn(*data):
    """Convert `class_weight` to `sample_weight`."""
    x, y, sw = unpack_x_y_sample_weight(data)

    if nest.is_nested(y):
      raise ValueError(
          "`class_weight` is only supported for Models with a single output.")

    if y.shape.rank > 2:
      raise ValueError("`class_weight` not supported for "
                       "3+ dimensional targets.")

    y_classes = smart_cond.smart_cond(
        y.shape.rank == 2 and backend.shape(y)[1] > 1,
        lambda: backend.argmax(y, axis=1),
        lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64))

    cw = array_ops.gather_v2(class_weight_tensor, y_classes)
    if sw is not None:
      cw = math_ops.cast(cw, sw.dtype)
      sw, cw = expand_1d((sw, cw))
      # `class_weight` and `sample_weight` are multiplicative.
      sw = sw * cw
    else:
      sw = cw

    return x, y, sw

  return _class_weights_map_fn
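

# An illustrative usage sketch (hypothetical dataset and weights): each
# `(x, y)` element gains a per-sample weight looked up from `class_weight` by
# its label. The helper function is not used elsewhere in this module.
def _example_apply_class_weight(dataset):
  """Illustrative only: weight class 1 three times as heavily as the others."""
  class_weight = {0: 1.0, 1: 3.0, 2: 1.0}
  return dataset.map(_make_class_weight_map_fn(class_weight))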


def expand_1d(data):
  """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s."""

  def _expand_single_1d_tensor(t):
    # Leaves `CompositeTensor`s as-is.
    if (isinstance(t, ops.Tensor) and
        isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1):
      return array_ops.expand_dims_v2(t, axis=-1)
    return t

  return nest.map_structure(_expand_single_1d_tensor, data)


def train_validation_split(arrays, validation_split):
  """Split arrays into train and validation subsets in deterministic order.

  The last part of data will become validation data.

  Args:
    arrays: Tensors to split. Allowed inputs are arbitrarily nested structures
      of Tensors and NumPy arrays.
    validation_split: Float between 0 and 1. The proportion of the dataset to
      include in the validation split. The rest of the dataset will be included
      in the training split.
  Returns:
    `(train_arrays, validation_arrays)`
  """

  def _can_split(t):
    tensor_types = (ops.Tensor, np.ndarray)
    if pd:
      tensor_types = (ops.Tensor, np.ndarray, pd.Series, pd.DataFrame)
    return isinstance(t, tensor_types) or t is None

  flat_arrays = nest.flatten(arrays)
  unsplitable = [type(t) for t in flat_arrays if not _can_split(t)]
  if unsplitable:
    raise ValueError(
        "`validation_split` is only supported for Tensors or NumPy "
        "arrays, found following types in the input: {}".format(unsplitable))

  if all(t is None for t in flat_arrays):
    return arrays, arrays

  first_non_none = None
  for t in flat_arrays:
    if t is not None:
      first_non_none = t
      break

  # Assumes all arrays have the same batch shape or are `None`.
  batch_dim = int(first_non_none.shape[0])
  split_at = int(math.floor(batch_dim * (1. - validation_split)))

  if split_at == 0 or split_at == batch_dim:
    raise ValueError(
        "Training data contains {batch_dim} samples, which is not sufficient "
        "to split it into a validation and training set as specified by "
        "`validation_split={validation_split}`. Either provide more data, or a "
        "different value for the `validation_split` argument.".format(
            batch_dim=batch_dim, validation_split=validation_split))

  def _split(t, start, end):
    if t is None:
      return t
    return t[start:end]

  train_arrays = nest.map_structure(
      functools.partial(_split, start=0, end=split_at), arrays)
  val_arrays = nest.map_structure(
      functools.partial(_split, start=split_at, end=batch_dim), arrays)

  return train_arrays, val_arrays
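

# An illustrative usage sketch (hypothetical arrays): the last 20% of samples
# become validation data, and the nesting structure of the inputs is preserved.
# The helper function is not used elsewhere in this module.
def _example_train_validation_split():
  """Illustrative only."""
  x = np.arange(10).reshape(10, 1)
  y = np.arange(10)
  (x_train, y_train), (x_val, y_val) = train_validation_split(
      (x, y), validation_split=0.2)
  return (x_train, y_train), (x_val, y_val)  # 8 training rows, 2 validation.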


@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[])
def unpack_x_y_sample_weight(data):
  """Unpacks user-provided data tuple.

  This is a convenience utility to be used when overriding
  `Model.train_step`, `Model.test_step`, or `Model.predict_step`.
  This utility makes it easy to support data of the form `(x,)`,
  `(x, y)`, or `(x, y, sample_weight)`.

  Standalone usage:

  >>> features_batch = tf.ones((10, 5))
  >>> labels_batch = tf.zeros((10, 5))
  >>> data = (features_batch, labels_batch)
  >>> # `y` and `sample_weight` will default to `None` if not provided.
  >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
  >>> sample_weight is None
  True

  Example in overridden `Model.train_step`:

  ```python
  class MyModel(tf.keras.Model):

    def train_step(self, data):
      # If `sample_weight` is not provided, all samples will be weighted
      # equally.
      x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

      with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compiled_loss(
          y, y_pred, sample_weight, regularization_losses=self.losses)
        trainable_variables = self.trainable_variables
        gradients = tape.gradient(loss, trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))

      self.compiled_metrics.update_state(y, y_pred, sample_weight)
      return {m.name: m.result() for m in self.metrics}
  ```

  Args:
    data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.

  Returns:
    The unpacked tuple, with `None`s for `y` and `sample_weight` if they are
    not provided.
  """
  if not isinstance(data, tuple):
    return (data, None, None)
  elif len(data) == 1:
    return (data[0], None, None)
  elif len(data) == 2:
    return (data[0], data[1], None)
  elif len(data) == 3:
    return (data[0], data[1], data[2])
  else:
    error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
                 "or `(x, y, sample_weight)`, found: {}").format(data)
    raise ValueError(error_msg)


@keras_export("keras.utils.pack_x_y_sample_weight", v1=[])
def pack_x_y_sample_weight(x, y=None, sample_weight=None):
  """Packs user-provided data into a tuple.

  This is a convenience utility for packing data into the tuple formats
  that `Model.fit` uses.

  Standalone usage:

  >>> x = tf.ones((10, 1))
  >>> data = tf.keras.utils.pack_x_y_sample_weight(x)
  >>> isinstance(data, tf.Tensor)
  True
  >>> y = tf.ones((10, 1))
  >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
  >>> isinstance(data, tuple)
  True
  >>> x, y = data

  Args:
    x: Features to pass to `Model`.
    y: Ground-truth targets to pass to `Model`.
    sample_weight: Sample weight for each element.

  Returns:
    Tuple in the format used in `Model.fit`.
  """
  if y is None:
    # For single x-input, we do no tuple wrapping since in this case
    # there is no ambiguity. This also makes NumPy and Dataset
    # consistent in that the user does not have to wrap their Dataset
    # data in an unnecessary tuple.
    if not nest.is_nested(x):
      return x
    else:
      return (x,)
  elif sample_weight is None:
    return (x, y)
  else:
    return (x, y, sample_weight)


def single_batch_iterator(strategy,
                          x,
                          y=None,
                          sample_weight=None,
                          class_weight=None):
  """Creates a single-batch dataset."""
  x, y, sample_weight = _process_tensorlike((x, y, sample_weight))
  if y is None:
    data = (x,)
  elif sample_weight is None:
    data = (x, y)
  else:
    data = (x, y, sample_weight)

  _check_data_cardinality(data)
  dataset = dataset_ops.DatasetV2.from_tensors(data)
  if class_weight:
    dataset = dataset.map(_make_class_weight_map_fn(class_weight))
  dataset = strategy.experimental_distribute_dataset(dataset)
  return iter(dataset)


def _check_data_cardinality(data):
  num_samples = set(int(i.shape[0]) for i in nest.flatten(data))
  if len(num_samples) > 1:
    msg = "Data cardinality is ambiguous:\n"
    for label, single_data in zip(["x", "y", "sample_weight"], data):
      msg += "  {} sizes: {}\n".format(
          label, ", ".join(str(i.shape[0]) for i in nest.flatten(single_data)))
    msg += "Make sure all arrays contain the same number of samples."
    raise ValueError(msg)


def _scipy_sparse_to_sparse_tensor(t):
  """Converts a SciPy sparse matrix to a SparseTensor."""
  sparse_coo = t.tocoo()
  row, col = sparse_coo.row, sparse_coo.col
  data, shape = sparse_coo.data, sparse_coo.shape
  if issubclass(data.dtype.type, np.floating):
    data = data.astype(backend.floatx())
  indices = np.concatenate(
      (np.expand_dims(row, axis=1), np.expand_dims(col, axis=1)), axis=1)
  return sparse_tensor.SparseTensor(indices, data, shape)
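

# An illustrative sketch (requires `scipy`): a small CSR matrix converts to a
# `SparseTensor` whose indices, values and dense shape match the COO form. The
# helper function is hypothetical and not used elsewhere in this module.
def _example_scipy_sparse_conversion():
  """Illustrative only; assumes `scipy` is installed."""
  mat = scipy_sparse.csr_matrix(np.array([[1.0, 0.0, 0.0],
                                          [0.0, 0.0, 2.0]]))
  return _scipy_sparse_to_sparse_tensor(mat)  # 2x3, two nonzero entries.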


def _is_distributed_dataset(ds):
  return isinstance(ds, input_lib.DistributedDatasetInterface)