1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Dataset.repeat()`."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.data.kernel_tests import test_base
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.framework import test_util
25from tensorflow.python.platform import test
26
27
@test_util.run_all_in_graph_and_eager_modes
class RepeatTest(test_base.DatasetTestBase):
  """Tests for `tf.data.Dataset.repeat()`."""

  def _assertShapesPreserved(self, components, dataset):
    """Asserts that `dataset` reports the same static shapes as `components`.

    `repeat()` must not change the element structure, so each legacy output
    shape of `dataset` should equal the shape of the matching numpy component.

    Args:
      components: tuple of numpy arrays the dataset was built from.
      dataset: the `tf.data.Dataset` under test.
    """
    self.assertEqual(
        [c.shape for c in components],
        list(dataset_ops.get_legacy_output_shapes(dataset)))

  def testRepeatTensorDataset(self):
    """Test a dataset that repeats its input multiple times."""
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))

    def do_test(count):
      # Build a fresh dataset that repeats `components` `count` times, then
      # check both its static shapes and the exact sequence of elements.
      dataset = dataset_ops.Dataset.from_tensors(components).repeat(count)
      self._assertShapesPreserved(components, dataset)
      self.assertDatasetProduces(dataset, [components] * count)

    # Test a finite repetition.
    do_test(3)

    # Test a different finite repetition.
    do_test(7)

    # Test an empty repetition.
    do_test(0)

    # Test an infinite repetition.
    # NOTE(mrry): There's not a good way to test that the sequence
    # actually is infinite.
    dataset = dataset_ops.Dataset.from_tensors(components).repeat(-1)
    self._assertShapesPreserved(components, dataset)
    get_next = self.getNext(dataset)
    # Draw more elements than the largest finite count above (7) yields, as a
    # practical proxy for "does not stop".
    for _ in range(17):
      results = self.evaluate(get_next())
      # `zip` silently truncates, so first make sure no component is missing.
      self.assertEqual(len(components), len(results))
      for component, result_component in zip(components, results):
        self.assertAllEqual(component, result_component)

  def testRepeatRepeatTensorDataset(self):
    """Test the composition of repeat datasets."""
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    inner_count, outer_count = 7, 14

    dataset = dataset_ops.Dataset.from_tensors(components).repeat(
        inner_count).repeat(outer_count)
    self._assertShapesPreserved(components, dataset)
    # Nested repeats multiply: every inner pass is replayed `outer_count` times.
    self.assertDatasetProduces(dataset,
                               [components] * (inner_count * outer_count))

  def testRepeatEmptyDataset(self):
    """Test that repeating an empty dataset does not hang."""
    # `repeat(10)` yields 10 elements and `skip(10)` drops all of them, so the
    # outer `repeat(-1)` is applied to an empty input. It must terminate with
    # no output rather than loop forever.
    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10).repeat(-1)
    self.assertDatasetProduces(dataset, [])
84
85
if __name__ == "__main__":
  # Run the tests in this module through TensorFlow's test entry point.
  test.main()
88