1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Dataset.shuffle()`."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import functools
22
23from absl.testing import parameterized
24import numpy as np
25
26from tensorflow.python.data.kernel_tests import test_base
27from tensorflow.python.data.ops import dataset_ops
28from tensorflow.python.eager import function
29from tensorflow.python.framework import combinations
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import errors
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import random_seed
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import check_ops
36from tensorflow.python.ops import variables
37from tensorflow.python.platform import test
38from tensorflow.python.training import checkpoint_management
39from tensorflow.python.training.tracking import util as trackable_utils
40
41
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests element preservation, seeding and reshuffling of `shuffle()`."""

  def _collectElements(self, get_next, num_elements):
    """Evaluates `get_next()` `num_elements` times and returns the results.

    Args:
      get_next: A callable, as returned by `self.getNext()`, whose result is
        passed to `self.evaluate()` to obtain the next dataset element.
      num_elements: The number of elements to pull.

    Returns:
      A list of the `num_elements` evaluated elements, in iteration order.
    """
    return [self.evaluate(get_next()) for _ in range(num_elements)]

  def _assertExhausted(self, get_next):
    """Asserts that pulling one more element raises `OutOfRangeError`.

    Args:
      get_next: A callable, as returned by `self.getNext()`.
    """
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testBasic(self):
    """Checks that shuffling permutes elements reproducibly per seed."""
    components = (
        np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
        np.array([9.0, 10.0, 11.0, 12.0])
    )
    # Four slices repeated `count=5` times (the default of `dataset_fn`).
    num_elements = 20

    def dataset_fn(count=5, buffer_size=None, seed=0):
      repeat_dataset = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if not buffer_size:
        return repeat_dataset

      shuffle_dataset = repeat_dataset.shuffle(buffer_size, seed)
      # Shuffling must not alter the per-element shapes.
      self.assertEqual(
          tuple(c.shape[1:] for c in components),
          dataset_ops.get_legacy_output_shapes(shuffle_dataset))
      return shuffle_dataset

    # First run without shuffling to collect the "ground truth".
    get_next = self.getNext(dataset_fn())
    unshuffled_elements = self._collectElements(get_next, num_elements)
    self._assertExhausted(get_next)

    # Assert that the shuffled dataset has the same elements as the
    # "ground truth".
    get_next = self.getNext(dataset_fn(buffer_size=100, seed=37))
    shuffled_elements = self._collectElements(get_next, num_elements)
    self._assertExhausted(get_next)
    # Pulling past the end a second time must raise again.
    self._assertExhausted(get_next)
    self.assertAllEqual(sorted(unshuffled_elements), sorted(shuffled_elements))

    # Assert that shuffling twice with the same seeds gives the same sequence.
    get_next = self.getNext(dataset_fn(buffer_size=100, seed=37))
    reshuffled_elements_same_seed = self._collectElements(
        get_next, num_elements)
    self._assertExhausted(get_next)
    self.assertEqual(shuffled_elements, reshuffled_elements_same_seed)

    # Assert that shuffling twice with a different seed gives a different
    # permutation of the same elements.
    get_next = self.getNext(dataset_fn(buffer_size=100, seed=137))
    reshuffled_elements_different_seed = self._collectElements(
        get_next, num_elements)
    self._assertExhausted(get_next)
    self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed)
    self.assertAllEqual(
        sorted(shuffled_elements), sorted(reshuffled_elements_different_seed))

    # Assert that the shuffled dataset has the same elements as the
    # "ground truth" when the buffer size is smaller than the input
    # dataset.
    get_next = self.getNext(dataset_fn(buffer_size=2, seed=37))
    reshuffled_elements_small_buffer = self._collectElements(
        get_next, num_elements)
    self._assertExhausted(get_next)
    self.assertAllEqual(
        sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer))

    # Test the case of shuffling an empty dataset.
    get_next = self.getNext(dataset_fn(count=0, buffer_size=100, seed=37))
    self._assertExhausted(get_next)

  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
  def testSeedZero(self):
    """Test for same behavior when the seed is a Python or Tensor zero."""
    # Reference order: the seed is the Python integer 0.
    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.range(10).shuffle(10, seed=0))
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      elems = [sess.run(get_next) for _ in range(10)]
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

    # Same pipeline, but the zero seed is fed through a scalar placeholder.
    seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = dataset_ops.make_initializable_iterator(
        dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder))
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
      for elem in elems:
        self.assertEqual(elem, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @combinations.generate(test_base.default_test_combinations())
  def testDefaultArguments(self):
    """Checks that ten epochs of `shuffle(5).repeat()` are evenly counted."""
    components = [0, 1, 2, 3, 4]
    num_epochs = 10
    dataset = dataset_ops.Dataset.from_tensor_slices(components).shuffle(
        5).repeat()
    get_next = self.getNext(dataset)

    # Tally how often each value shows up over `num_epochs` full passes.
    counts = collections.Counter(
        self.evaluate(get_next())
        for _ in range(num_epochs * len(components)))

    for i in range(5):
      self.assertEqual(10, counts[i])

  @combinations.generate(
      combinations.times(
          test_base.graph_only_combinations(),
          combinations.combine(reshuffle=[True, False]),
          combinations.combine(graph_seed=38, op_seed=None) +
          combinations.combine(graph_seed=None, op_seed=42) +
          combinations.combine(graph_seed=38, op_seed=42)))
  def testShuffleSeed(self, reshuffle, graph_seed, op_seed):
    """Checks that a fixed graph and/or op seed makes two runs identical."""
    results = []
    for _ in range(2):
      # Build each run in a fresh graph so both start from the same seed state.
      with ops.Graph().as_default() as g:
        random_seed.set_random_seed(graph_seed)
        dataset = dataset_ops.Dataset.range(10).shuffle(
            10, seed=op_seed, reshuffle_each_iteration=reshuffle).repeat(3)
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

        with self.session(graph=g) as sess:
          run_results = [sess.run(next_element) for _ in range(30)]
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)
        results.append(run_results)

    self.assertAllEqual(results[0], results[1])

  # TODO(b/117581999): enable this test for eager-mode.
  @combinations.generate(
      combinations.times(
          test_base.graph_only_combinations(),
          combinations.combine(
              reshuffle=[True, False], initializable=[True, False])))
  def testMultipleIterators(self, reshuffle, initializable):
    """Checks that two iterators over one shuffled dataset differ in order."""
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.range(100).shuffle(
          10, reshuffle_each_iteration=reshuffle).repeat(3)

      make_iterator = (
          dataset_ops.make_initializable_iterator
          if initializable else dataset_ops.make_one_shot_iterator)
      iterators = [make_iterator(dataset) for _ in range(2)]

      results = []
      with self.session(graph=g) as sess:
        for iterator in iterators:
          if initializable:
            sess.run(iterator.initializer)
          next_element = iterator.get_next()
          # 100 elements repeated 3 times.
          run_results = [sess.run(next_element) for _ in range(300)]
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)

          results.append(run_results)

        self.assertNotEqual(results[0], results[1])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleRepeatEpochs(self, reshuffle, seed):
    """Checks that `repeat()` epochs differ iff `reshuffle` is set."""
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle).repeat(2)
    next_element = self.getNext(dataset)

    first_epoch = self._collectElements(next_element, 10)
    second_epoch = self._collectElements(next_element, 10)

    self.assertEqual(first_epoch == second_epoch, not reshuffle)

  @combinations.generate(
      combinations.times(
          combinations.combine(tf_api_version=2, mode="eager"),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleIterationEpochs(self, reshuffle, seed):
    """Checks that separate iterations differ iff `reshuffle` is set."""
    # TensorFlow unit tests set the global graph seed. We unset it here so that
    # we can control determinism via the `seed` parameter.
    random_seed.set_random_seed(None)
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle)

    first_epoch = self.getDatasetOutput(dataset)
    second_epoch = self.getDatasetOutput(dataset)

    self.assertEqual(first_epoch == second_epoch, not reshuffle)

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testShuffleV2ResourceCapture(self):
    """Checks that a shuffled dataset can be captured by an interleave fn."""

    def make_dataset():
      ids = dataset_ops.Dataset.range(10)
      # A buffer of size one leaves the element order unchanged.
      ids = ids.shuffle(1)

      def interleave_fn(dataset, _):
        return dataset

      dataset = dataset_ops.Dataset.range(1)
      dataset = dataset.interleave(functools.partial(interleave_fn, ids))
      return dataset

    results = [elem.numpy() for elem in make_dataset()]

    self.assertAllEqual(results, range(10))

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleSeparateTransformations(self, reshuffle, seed):
    """Checks that separate `shuffle()` calls agree only if a seed is given."""
    dataset = dataset_ops.Dataset.range(10)

    def shuffled_epoch():
      # Applies a brand-new `shuffle()` transformation on every call.
      return [
          elem.numpy() for elem in dataset.shuffle(
              10, seed=seed, reshuffle_each_iteration=reshuffle)
      ]

    first_epoch = shuffled_epoch()
    second_epoch = shuffled_epoch()

    self.assertEqual(first_epoch != second_epoch, seed is None)

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testShuffleV2InFunction(self):
    """Checks that a shuffled dataset can be iterated inside a `defun`."""
    counter_var = variables.Variable(0)

    @function.defun
    def consume():
      ds = dataset_ops.Dataset.range(10)
      ds = ds.shuffle(1)
      for _ in ds:
        counter_var.assign(counter_var + 1)

    consume()
    # One increment per element of the 10-element dataset.
    self.assertAllEqual(self.evaluate(counter_var), 10)

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyDataset(self):
    """Checks `shuffle().repeat()` over a cache whose only element failed."""
    dataset = dataset_ops.Dataset.from_tensors(1)

    def map_fn(x):
      # The sole element is 1, so this assertion fails when it is produced.
      with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
        return x

    dataset = dataset.map(map_fn)
    dataset = dataset.cache()
    dataset = dataset.shuffle(buffer_size=10).repeat()

    get_next = self.getNext(dataset)

    # First time around, we get an error for the failed assertion.
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(get_next())

    # Second time around, we get an EOF because the cached dataset is empty.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(reshuffle=[True, False])))
  def testRerandomizeOnReplicate(self, reshuffle):
    """Checks that a graph round trip yields a differently ordered shuffle."""
    random_seed.set_random_seed(None)
    # When no seeds are fixed, each instantiation of the shuffle dataset should
    # produce elements in a different order.
    num_elements = 100
    dataset = dataset_ops.Dataset.range(num_elements)
    dataset = dataset.shuffle(num_elements, reshuffle_each_iteration=reshuffle)

    shuffle_1 = self.getDatasetOutput(dataset)
    dataset = self.graphRoundTrip(dataset, allow_stateful=True)
    shuffle_2 = self.getDatasetOutput(dataset)

    # Same multiset of elements, different order.
    self.assertCountEqual(shuffle_1, shuffle_2)
    self.assertNotEqual(shuffle_1, shuffle_2)

  @combinations.generate(test_base.eager_only_combinations())
  def testCheckpointLargeShuffleBuffer(self):
    """Checks that an iterator with a >2GB shuffle buffer can be saved."""
    # Tensor of size 100MB (25 * 1000 * 1000 float32 values).
    dataset = dataset_ops.Dataset.from_tensors(
        array_ops.ones((25, 1000, 1000), dtype=dtypes.float32))
    dataset = dataset.repeat()
    # Shuffle 25 tensors to exceed the 2GB protocol buffer limit
    dataset = dataset.shuffle(25)

    iterator = iter(dataset)
    next(iterator)  # request an element to fill the shuffle buffer
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=1)
    manager.save()
    ckpt.restore(manager.latest_checkpoint)
369
370
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()
373