1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `tf.data.Dataset.repeat()`.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.python.data.kernel_tests import test_base 23from tensorflow.python.data.ops import dataset_ops 24from tensorflow.python.framework import test_util 25from tensorflow.python.platform import test 26 27 28@test_util.run_all_in_graph_and_eager_modes 29class RepeatTest(test_base.DatasetTestBase): 30 31 def testRepeatTensorDataset(self): 32 """Test a dataset that repeats its input multiple times.""" 33 components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) 34 # This placeholder can be fed when dataset-definition subgraph 35 # runs (i.e. `init_op` below) to configure the number of 36 # repetitions used in a particular iterator. 37 38 def do_test(count): 39 dataset = dataset_ops.Dataset.from_tensors(components).repeat(count) 40 self.assertEqual( 41 [c.shape for c in components], 42 [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)]) 43 self.assertDatasetProduces(dataset, [components] * count) 44 45 # Test a finite repetition. 46 do_test(3) 47 48 # test a different finite repetition. 49 do_test(7) 50 51 # Test an empty repetition. 52 do_test(0) 53 54 # Test an infinite repetition. 55 # NOTE(mrry): There's not a good way to test that the sequence 56 # actually is infinite. 57 dataset = dataset_ops.Dataset.from_tensors(components).repeat(-1) 58 self.assertEqual( 59 [c.shape for c in components], 60 [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)]) 61 get_next = self.getNext(dataset) 62 for _ in range(17): 63 results = self.evaluate(get_next()) 64 for component, result_component in zip(components, results): 65 self.assertAllEqual(component, result_component) 66 67 def testRepeatRepeatTensorDataset(self): 68 """Test the composition of repeat datasets.""" 69 components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) 70 inner_count, outer_count = 7, 14 71 72 dataset = dataset_ops.Dataset.from_tensors(components).repeat( 73 inner_count).repeat(outer_count) 74 self.assertEqual( 75 [c.shape for c in components], 76 [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)]) 77 self.assertDatasetProduces(dataset, 78 [components] * (inner_count * outer_count)) 79 80 def testRepeatEmptyDataset(self): 81 """Test that repeating an empty dataset does not hang.""" 82 dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10).repeat(-1) 83 self.assertDatasetProduces(dataset, []) 84 85 86if __name__ == "__main__": 87 test.main() 88