1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Dataset.shuffle()`."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21
22from absl.testing import parameterized
23import numpy as np
24
25from tensorflow.python.data.kernel_tests import test_base
26from tensorflow.python.data.ops import dataset_ops
27
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import errors
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import random_seed
32from tensorflow.python.framework import test_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.platform import test
35
36
@test_util.run_all_in_graph_and_eager_modes
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests seeding, buffering and reshuffling of `Dataset.shuffle()`."""

  def testShuffleDataset(self):
    """Checks element preservation and seed handling of `shuffle()`."""
    components = (
        np.array([1, 2, 3, 4]),
        np.array([5, 6, 7, 8]),
        np.array([9.0, 10.0, 11.0, 12.0]),
    )

    def make_dataset(count=5, buffer_size=None, seed=0):
      """Repeats the sliced components; shuffles when `buffer_size` is set."""
      repeated = dataset_ops.Dataset.from_tensor_slices(components).repeat(
          count)
      if not buffer_size:
        return repeated

      shuffled_dataset = repeated.shuffle(buffer_size, seed)
      # The shuffled dataset must report the per-slice component shapes.
      self.assertEqual(
          tuple(c.shape[1:] for c in components),
          dataset_ops.get_legacy_output_shapes(shuffled_dataset))
      return shuffled_dataset

    def take(next_fn, num_elements=20):
      """Evaluates `num_elements` consecutive elements into a list."""
      return [self.evaluate(next_fn()) for _ in range(num_elements)]

    def assert_exhausted(next_fn):
      """Expects the next evaluation to raise `OutOfRangeError`."""
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_fn())

    # Unshuffled pass: 5 repeats of 4 slices give the 20 "ground truth"
    # elements.
    next_fn = self.getNext(make_dataset())
    ground_truth = take(next_fn)
    assert_exhausted(next_fn)

    # A shuffle with a large buffer yields exactly the ground-truth elements.
    # End-of-sequence is checked twice in a row on the same iterator.
    next_fn = self.getNext(make_dataset(buffer_size=100, seed=37))
    shuffled_elems = take(next_fn)
    assert_exhausted(next_fn)
    assert_exhausted(next_fn)
    self.assertAllEqual(sorted(ground_truth), sorted(shuffled_elems))

    # Re-running with the same seed reproduces the same sequence.
    next_fn = self.getNext(make_dataset(buffer_size=100, seed=37))
    same_seed_elems = take(next_fn)
    assert_exhausted(next_fn)
    self.assertEqual(shuffled_elems, same_seed_elems)

    # A different seed gives a different permutation of the same elements.
    next_fn = self.getNext(make_dataset(buffer_size=100, seed=137))
    other_seed_elems = take(next_fn)
    assert_exhausted(next_fn)
    self.assertNotEqual(shuffled_elems, other_seed_elems)
    self.assertAllEqual(sorted(shuffled_elems), sorted(other_seed_elems))

    # A buffer smaller than the input still yields every ground-truth element.
    next_fn = self.getNext(make_dataset(buffer_size=2, seed=37))
    small_buffer_elems = take(next_fn)
    assert_exhausted(next_fn)
    self.assertAllEqual(sorted(ground_truth), sorted(small_buffer_elems))

    # Shuffling an empty dataset (zero repeats) ends immediately.
    assert_exhausted(
        self.getNext(make_dataset(count=0, buffer_size=100, seed=37)))

  @test_util.run_deprecated_v1
  def testSkipEagerSeedZero(self):
    """Test for same behavior when the seed is a Python or Tensor zero."""
    # Reference order: the seed is passed as a Python integer.
    python_seed_dataset = dataset_ops.Dataset.range(10).shuffle(10, seed=0)
    next_op = dataset_ops.make_one_shot_iterator(
        python_seed_dataset).get_next()

    with self.cached_session() as sess:
      expected = [sess.run(next_op) for _ in range(10)]
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_op)

    # Same pipeline, but the seed is fed through a scalar int64 placeholder.
    seed_ph = array_ops.placeholder(dtypes.int64, shape=[])
    tensor_seed_dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=seed_ph)
    iterator = dataset_ops.make_initializable_iterator(tensor_seed_dataset)
    next_op = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(iterator.initializer, feed_dict={seed_ph: 0})
      for want in expected:
        self.assertEqual(want, sess.run(next_op))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_op)

  def testDefaultArguments(self):
    """Checks 50 draws from a repeated 5-element shuffle hit each value 10x."""
    values = [0, 1, 2, 3, 4]
    dataset = dataset_ops.Dataset.from_tensor_slices(values)
    dataset = dataset.shuffle(5).repeat()
    next_fn = self.getNext(dataset)

    num_passes = 10
    # Tally how often each value is produced over `num_passes` passes.
    tally = collections.Counter(
        self.evaluate(next_fn()) for _ in range(num_passes * len(values)))

    for value in values:
      self.assertEqual(num_passes, tally[value])

  def testShuffleNoReshuffleEachIteration(self):
    """Checks each repetition replays the first permutation."""
    shuffled = dataset_ops.Dataset.range(10).shuffle(
        10, reshuffle_each_iteration=False)
    next_fn = self.getNext(shuffled.batch(10).repeat(3))

    first_epoch = self.evaluate(next_fn())
    for _ in range(2):
      self.assertAllEqual(first_epoch, self.evaluate(next_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_fn())

  def testShuffleReshuffleEachIteration(self):
    """Checks later repetitions reorder the same elements differently."""
    shuffled = dataset_ops.Dataset.range(10).shuffle(
        10, seed=3, reshuffle_each_iteration=True)
    next_fn = self.getNext(shuffled.batch(10).repeat(3))

    first_epoch = list(self.evaluate(next_fn()))
    for _ in range(2):
      later_epoch = list(self.evaluate(next_fn()))
      self.assertNotEqual(first_epoch, later_epoch)
      self.assertAllEqual(sorted(first_epoch), sorted(later_epoch))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_fn())

  @parameterized.named_parameters(
      ("ReshuffleGraphLevelSeed", True, 38, None),
      ("ReshuffleOpLevelSeed", True, None, 42),
      ("ReshuffleGraphAndOpLevelSeed", True, 38, 42),
      ("NoReshuffleGraphLevelSeed", False, 38, None),
      ("NoReshuffleOpLevelSeed", False, None, 42),
      ("NoReshuffleGraphAndOpLevelSeed", False, 38, 42),
  )
  def testSkipEagerShuffleSeed(self, reshuffle, graph_level_seed,
                               op_level_seed):
    """Checks two fresh graphs with the same seeds shuffle identically."""

    def run_in_fresh_graph():
      """Builds the pipeline in a new graph and returns its 30 outputs."""
      with ops.Graph().as_default() as graph:
        random_seed.set_random_seed(graph_level_seed)
        dataset = dataset_ops.Dataset.range(10).shuffle(
            10, seed=op_level_seed, reshuffle_each_iteration=reshuffle)
        dataset = dataset.repeat(3)
        next_op = dataset_ops.make_one_shot_iterator(dataset).get_next()

        with self.session(graph=graph) as sess:
          outputs = [sess.run(next_op) for _ in range(30)]
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_op)
      return outputs

    first_run, second_run = run_in_fresh_graph(), run_in_fresh_graph()
    self.assertAllEqual(first_run, second_run)

  # TODO(b/117581999): fails in eager mode because both iterators produce the
  # same sequence; debug.
  @parameterized.named_parameters(
      ("ReshuffleOneShot", True, False),
      ("ReshuffleInitializable", True, True),
      ("NoReshuffleOneShot", False, False),
      ("NoReshuffleInitializable", False, True),
  )
  def testSkipEagerMultipleIterators(self, reshuffle, initializable):
    """Checks two iterators over one unseeded shuffle differ in order."""
    with ops.Graph().as_default() as graph:
      dataset = dataset_ops.Dataset.range(100).shuffle(
          10, reshuffle_each_iteration=reshuffle).repeat(3)

      make_iterator = (
          dataset_ops.make_initializable_iterator
          if initializable else dataset_ops.make_one_shot_iterator)
      iterators = [make_iterator(dataset) for _ in range(2)]

      runs = []
      with self.session(graph=graph) as sess:
        for iterator in iterators:
          if initializable:
            sess.run(iterator.initializer)
          next_op = iterator.get_next()
          # 100 elements repeated 3 times: 300 outputs, then end-of-sequence.
          outputs = [sess.run(next_op) for _ in range(300)]
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_op)
          runs.append(outputs)

        self.assertNotEqual(runs[0], runs[1])
246
247
# Script entry point: run this module's tests through `test.main()` from
# `tensorflow.python.platform`.
if __name__ == "__main__":
  test.main()
250