1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `tf.data.Dataset.shuffle()`.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import functools 22 23from absl.testing import parameterized 24import numpy as np 25 26from tensorflow.python.data.kernel_tests import test_base 27from tensorflow.python.data.ops import dataset_ops 28from tensorflow.python.eager import function 29from tensorflow.python.framework import combinations 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import errors 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import random_seed 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import check_ops 36from tensorflow.python.ops import variables 37from tensorflow.python.platform import test 38from tensorflow.python.training import checkpoint_management 39from tensorflow.python.training.tracking import util as trackable_utils 40 41 42class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): 43 44 @combinations.generate(test_base.default_test_combinations()) 45 def testBasic(self): 46 components = ( 47 np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), 48 np.array([9.0, 10.0, 11.0, 12.0]) 49 ) 50 51 def dataset_fn(count=5, buffer_size=None, seed=0): 52 repeat_dataset = ( 53 dataset_ops.Dataset.from_tensor_slices(components).repeat(count)) 54 if buffer_size: 55 shuffle_dataset = repeat_dataset.shuffle(buffer_size, seed) 56 57 self.assertEqual( 58 tuple([c.shape[1:] for c in components]), 59 dataset_ops.get_legacy_output_shapes(shuffle_dataset)) 60 return shuffle_dataset 61 else: 62 return repeat_dataset 63 64 # First run without shuffling to collect the "ground truth". 65 get_next = self.getNext(dataset_fn()) 66 unshuffled_elements = [] 67 for _ in range(20): 68 unshuffled_elements.append(self.evaluate(get_next())) 69 with self.assertRaises(errors.OutOfRangeError): 70 self.evaluate(get_next()) 71 72 # Assert that the shuffled dataset has the same elements as the 73 # "ground truth". 74 get_next = self.getNext(dataset_fn(buffer_size=100, seed=37)) 75 shuffled_elements = [] 76 for _ in range(20): 77 shuffled_elements.append(self.evaluate(get_next())) 78 with self.assertRaises(errors.OutOfRangeError): 79 self.evaluate(get_next()) 80 with self.assertRaises(errors.OutOfRangeError): 81 self.evaluate(get_next()) 82 self.assertAllEqual(sorted(unshuffled_elements), sorted(shuffled_elements)) 83 84 # Assert that shuffling twice with the same seeds gives the same sequence. 85 get_next = self.getNext(dataset_fn(buffer_size=100, seed=37)) 86 reshuffled_elements_same_seed = [] 87 for _ in range(20): 88 reshuffled_elements_same_seed.append(self.evaluate(get_next())) 89 with self.assertRaises(errors.OutOfRangeError): 90 self.evaluate(get_next()) 91 self.assertEqual(shuffled_elements, reshuffled_elements_same_seed) 92 93 # Assert that shuffling twice with a different seed gives a different 94 # permutation of the same elements. 95 get_next = self.getNext(dataset_fn(buffer_size=100, seed=137)) 96 reshuffled_elements_different_seed = [] 97 for _ in range(20): 98 reshuffled_elements_different_seed.append(self.evaluate(get_next())) 99 with self.assertRaises(errors.OutOfRangeError): 100 self.evaluate(get_next()) 101 self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed) 102 self.assertAllEqual( 103 sorted(shuffled_elements), sorted(reshuffled_elements_different_seed)) 104 105 # Assert that the shuffled dataset has the same elements as the 106 # "ground truth" when the buffer size is smaller than the input 107 # dataset. 108 get_next = self.getNext(dataset_fn(buffer_size=2, seed=37)) 109 reshuffled_elements_small_buffer = [] 110 for _ in range(20): 111 reshuffled_elements_small_buffer.append(self.evaluate(get_next())) 112 with self.assertRaises(errors.OutOfRangeError): 113 self.evaluate(get_next()) 114 self.assertAllEqual( 115 sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer)) 116 117 # Test the case of shuffling an empty dataset. 118 get_next = self.getNext(dataset_fn(count=0, buffer_size=100, seed=37)) 119 120 with self.assertRaises(errors.OutOfRangeError): 121 self.evaluate(get_next()) 122 123 @combinations.generate(combinations.combine(tf_api_version=1, mode="graph")) 124 def testSeedZero(self): 125 """Test for same behavior when the seed is a Python or Tensor zero.""" 126 iterator = dataset_ops.make_one_shot_iterator( 127 dataset_ops.Dataset.range(10).shuffle(10, seed=0)) 128 get_next = iterator.get_next() 129 130 elems = [] 131 with self.cached_session() as sess: 132 for _ in range(10): 133 elems.append(sess.run(get_next)) 134 with self.assertRaises(errors.OutOfRangeError): 135 sess.run(get_next) 136 137 seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) 138 iterator = dataset_ops.make_initializable_iterator( 139 dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder)) 140 get_next = iterator.get_next() 141 142 with self.cached_session() as sess: 143 sess.run(iterator.initializer, feed_dict={seed_placeholder: 0}) 144 for elem in elems: 145 self.assertEqual(elem, sess.run(get_next)) 146 with self.assertRaises(errors.OutOfRangeError): 147 sess.run(get_next) 148 149 @combinations.generate(test_base.default_test_combinations()) 150 def testDefaultArguments(self): 151 components = [0, 1, 2, 3, 4] 152 dataset = dataset_ops.Dataset.from_tensor_slices(components).shuffle( 153 5).repeat() 154 get_next = self.getNext(dataset) 155 counts = collections.defaultdict(lambda: 0) 156 for _ in range(10): 157 for _ in range(5): 158 counts[self.evaluate(get_next())] += 1 159 160 for i in range(5): 161 self.assertEqual(10, counts[i]) 162 163 @combinations.generate( 164 combinations.times( 165 test_base.graph_only_combinations(), 166 combinations.combine(reshuffle=[True, False]), 167 combinations.combine(graph_seed=38, op_seed=None) + 168 combinations.combine(graph_seed=None, op_seed=42) + 169 combinations.combine(graph_seed=38, op_seed=42))) 170 def testShuffleSeed(self, reshuffle, graph_seed, op_seed): 171 results = [] 172 for _ in range(2): 173 with ops.Graph().as_default() as g: 174 random_seed.set_random_seed(graph_seed) 175 dataset = dataset_ops.Dataset.range(10).shuffle( 176 10, seed=op_seed, reshuffle_each_iteration=reshuffle).repeat(3) 177 iterator = dataset_ops.make_one_shot_iterator(dataset) 178 next_element = iterator.get_next() 179 180 run_results = [] 181 with self.session(graph=g) as sess: 182 for _ in range(30): 183 run_results.append(sess.run(next_element)) 184 with self.assertRaises(errors.OutOfRangeError): 185 sess.run(next_element) 186 results.append(run_results) 187 188 self.assertAllEqual(results[0], results[1]) 189 190 # TODO(b/117581999): enable this test for eager-mode. 191 @combinations.generate( 192 combinations.times( 193 test_base.graph_only_combinations(), 194 combinations.combine( 195 reshuffle=[True, False], initializable=[True, False]))) 196 def testMultipleIterators(self, reshuffle, initializable): 197 with ops.Graph().as_default() as g: 198 dataset = dataset_ops.Dataset.range(100).shuffle( 199 10, reshuffle_each_iteration=reshuffle).repeat(3) 200 201 if initializable: 202 iterators = [dataset_ops.make_initializable_iterator(dataset) 203 for _ in range(2)] 204 else: 205 iterators = [dataset_ops.make_one_shot_iterator(dataset) 206 for _ in range(2)] 207 208 results = [] 209 with self.session(graph=g) as sess: 210 for iterator in iterators: 211 if initializable: 212 sess.run(iterator.initializer) 213 next_element = iterator.get_next() 214 run_results = [] 215 for _ in range(300): 216 run_results.append(sess.run(next_element)) 217 with self.assertRaises(errors.OutOfRangeError): 218 sess.run(next_element) 219 220 results.append(run_results) 221 222 self.assertNotEqual(results[0], results[1]) 223 224 @combinations.generate( 225 combinations.times( 226 test_base.default_test_combinations(), 227 combinations.combine(reshuffle=[True, False], seed=[None, 42]))) 228 def testReshuffleRepeatEpochs(self, reshuffle, seed): 229 dataset = dataset_ops.Dataset.range(10).shuffle( 230 10, seed=seed, reshuffle_each_iteration=reshuffle).repeat(2) 231 next_element = self.getNext(dataset) 232 233 first_epoch = [] 234 for _ in range(10): 235 first_epoch.append(self.evaluate(next_element())) 236 237 second_epoch = [] 238 for _ in range(10): 239 second_epoch.append(self.evaluate(next_element())) 240 241 self.assertEqual(first_epoch == second_epoch, not reshuffle) 242 243 @combinations.generate( 244 combinations.times( 245 combinations.combine(tf_api_version=2, mode="eager"), 246 combinations.combine(reshuffle=[True, False], seed=[None, 42]))) 247 def testReshuffleIterationEpochs(self, reshuffle, seed): 248 # TensorFlow unit tests set the global graph seed. We unset it here so that 249 # we can control determinism via the `seed` parameter. 250 random_seed.set_random_seed(None) 251 dataset = dataset_ops.Dataset.range(10).shuffle( 252 10, seed=seed, reshuffle_each_iteration=reshuffle) 253 254 first_epoch = self.getDatasetOutput(dataset) 255 second_epoch = self.getDatasetOutput(dataset) 256 257 self.assertEqual(first_epoch == second_epoch, not reshuffle) 258 259 @combinations.generate(combinations.combine(tf_api_version=2, mode="eager")) 260 def testShuffleV2ResourceCapture(self): 261 262 def make_dataset(): 263 ids = dataset_ops.Dataset.range(10) 264 ids = ids.shuffle(1) 265 266 def interleave_fn(dataset, _): 267 return dataset 268 269 dataset = dataset_ops.Dataset.range(1) 270 dataset = dataset.interleave(functools.partial(interleave_fn, ids)) 271 return dataset 272 273 results = [] 274 for elem in make_dataset(): 275 results.append(elem.numpy()) 276 277 self.assertAllEqual(results, range(10)) 278 279 @combinations.generate( 280 combinations.times( 281 test_base.eager_only_combinations(), 282 combinations.combine(reshuffle=[True, False], seed=[None, 42]))) 283 def testReshuffleSeparateTransformations(self, reshuffle, seed): 284 dataset = dataset_ops.Dataset.range(10) 285 286 first_epoch = [] 287 for elem in dataset.shuffle( 288 10, seed=seed, reshuffle_each_iteration=reshuffle): 289 first_epoch.append(elem.numpy()) 290 291 second_epoch = [] 292 for elem in dataset.shuffle( 293 10, seed=seed, reshuffle_each_iteration=reshuffle): 294 second_epoch.append(elem.numpy()) 295 296 self.assertEqual(first_epoch != second_epoch, seed is None) 297 298 @combinations.generate(combinations.combine(tf_api_version=2, mode="eager")) 299 def testShuffleV2InFunction(self): 300 counter_var = variables.Variable(0) 301 302 @function.defun 303 def consume(): 304 ds = dataset_ops.Dataset.range(10) 305 ds = ds.shuffle(1) 306 for _ in ds: 307 counter_var.assign(counter_var + 1) 308 309 consume() 310 self.assertAllEqual(self.evaluate(counter_var), 10) 311 312 @combinations.generate(test_base.default_test_combinations()) 313 def testEmptyDataset(self): 314 dataset = dataset_ops.Dataset.from_tensors(1) 315 316 def map_fn(x): 317 with ops.control_dependencies([check_ops.assert_equal(x, 0)]): 318 return x 319 320 dataset = dataset.map(map_fn) 321 dataset = dataset.cache() 322 dataset = dataset.shuffle(buffer_size=10).repeat() 323 324 get_next = self.getNext(dataset) 325 326 # First time around, we get an error for the failed assertion. 327 with self.assertRaises(errors.InvalidArgumentError): 328 self.evaluate(get_next()) 329 330 # Second time around, we get an EOF because the cached dataset is empty. 331 with self.assertRaises(errors.OutOfRangeError): 332 self.evaluate(get_next()) 333 334 @combinations.generate( 335 combinations.times( 336 test_base.default_test_combinations(), 337 combinations.combine(reshuffle=[True, False]))) 338 def testRerandomizeOnReplicate(self, reshuffle): 339 random_seed.set_random_seed(None) 340 # When no seeds are fixed, each instantiation of the shuffle dataset should 341 # produce elements in a different order. 342 num_elements = 100 343 dataset = dataset_ops.Dataset.range(num_elements) 344 dataset = dataset.shuffle(num_elements, reshuffle_each_iteration=reshuffle) 345 346 shuffle_1 = self.getDatasetOutput(dataset) 347 dataset = self.graphRoundTrip(dataset, allow_stateful=True) 348 shuffle_2 = self.getDatasetOutput(dataset) 349 350 self.assertCountEqual(shuffle_1, shuffle_2) 351 self.assertNotEqual(shuffle_1, shuffle_2) 352 353 @combinations.generate(test_base.eager_only_combinations()) 354 def testCheckpointLargeShuffleBuffer(self): 355 # Tensor of size 100M 356 dataset = dataset_ops.Dataset.from_tensors( 357 array_ops.ones((25, 1000, 1000), dtype=dtypes.float32)) 358 dataset = dataset.repeat() 359 # Shuffle 25 tensors to exceed the 2GB protocol buffer limit 360 dataset = dataset.shuffle(25) 361 362 iterator = iter(dataset) 363 next(iterator) # request an element to fill the shuffle buffer 364 ckpt = trackable_utils.Checkpoint(iterator=iterator) 365 manager = checkpoint_management.CheckpointManager( 366 ckpt, self.get_temp_dir(), max_to_keep=1) 367 manager.save() 368 ckpt.restore(manager.latest_checkpoint) 369 370 371if __name__ == "__main__": 372 test.main() 373