1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-classes-have-attributes
16"""Input dataset creator for `model.fit`."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20from tensorflow.python.data.ops import dataset_ops
21
22
class DatasetCreator(object):
  """Callable wrapper around a function that builds a `tf.data.Dataset`.

  An instance of `DatasetCreator` is a supported type for `x` (the input) in
  `tf.keras.Model.fit`. Wrap your dataset-building callable (one that accepts
  an `input_context` argument and returns a `tf.data.Dataset`) in this class
  and hand the resulting object to `fit`.

  Example:

  ```python
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  model.compile(tf.keras.optimizers.SGD(), loss="mse")

  def dataset_fn(input_context):
    global_batch_size = 64
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
    dataset = dataset.shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(2)
    return dataset

  model.fit(DatasetCreator(dataset_fn), epochs=10, steps_per_epoch=10)
  ```

  Args:
    dataset_fn: A callable returning a `tf.data.Dataset`. It takes a single
      argument of type `tf.distribute.InputContext`, which serves batch size
      calculation and cross-worker input pipeline sharding. A `dataset_fn`
      that needs neither may simply ignore that argument.

  Raises:
    TypeError: On construction, if `dataset_fn` is not callable. On
      invocation, if the value returned by `dataset_fn` is not a dataset.
  """

  def __init__(self, dataset_fn):
    """Stores `dataset_fn` after verifying that it can be called."""
    if callable(dataset_fn):
      self.dataset_fn = dataset_fn
    else:
      raise TypeError('`dataset_fn` for `DatasetCreator` must be a `callable`.')

  def __call__(self, *args, **kwargs):
    """Invokes the wrapped callable and returns the dataset it produces.

    Args:
      *args: Positional arguments, passed through to `dataset_fn` unchanged.
      **kwargs: Keyword arguments, passed through to `dataset_fn` unchanged.

    Returns:
      The object returned by `dataset_fn`, once confirmed to be a dataset.

    Raises:
      TypeError: If `dataset_fn` returns anything other than a dataset.
    """
    produced = self.dataset_fn(*args, **kwargs)
    # Only hand back genuine dataset objects; anything else is a user error.
    if isinstance(produced, dataset_ops.DatasetV2):
      return produced
    raise TypeError(
        'The `callable` provided to `DatasetCreator` must return a Dataset.')
68