# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the input_lib library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from absl.testing import parameterized
import numpy as np

from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
from tensorflow.python.util import nest


class DistributedIteratorTestBase(test.TestCase):

  # The passed `input_context` is used to create a sharded dataset in the
  # between-graph case.
  # TODO(yuefengz): rewrite the following method to make it less repetitive.
  def _wrap_iterator(self,
                     input_type,
                     dataset_or_input_fn,
                     input_workers,
                     devices,
                     num_replicas_in_sync,
                     strategy,
                     input_context=None):
    # The `input_context` passed in is used to shard the dataset for
    # MultiWorkerMirroredStrategy. It doesn't apply to the in-graph case, where
    # multiple InputContexts are needed.
    if input_type == "input_fn":
      self.assertIsNone(
          input_context,
          msg=("The `input_context` arg is only used to shard the dataset in "
               "`MultiWorkerMirroredStrategy` when the input type is dataset."))

      input_contexts = []
      for i in range(input_workers.num_workers):
        input_contexts.append(
            distribute_lib.InputContext(
                # Note: `input_workers.num_workers` is always 1 in the
                # between-graph case.
                num_input_pipelines=input_workers.num_workers,
                input_pipeline_id=i,
                num_replicas_in_sync=len(devices)))

      iterator = input_lib.InputFunctionIterator(
          dataset_or_input_fn,
          input_workers,
          input_contexts,
          strategy)
    else:
      iterator = input_lib.DatasetIterator(
          dataset_or_input_fn,
          input_workers,
          strategy,
          num_replicas_in_sync=num_replicas_in_sync,
          input_context=input_context)
    return iterator

  def _wrap_dataset(self,
                    input_type,
                    dataset,
                    input_workers,
                    num_replicas_in_sync,
                    strategy,
                    input_context=None):
    if input_type == "dataset":
      if tf2.enabled():
        return input_lib.DistributedDataset(
            dataset,
            input_workers,
            strategy,
            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
      else:
        return input_lib.DistributedDatasetV1(
            dataset,
            input_workers,
            strategy,
            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
    else:
      return strategy.distribute_datasets_from_function(dataset)

  def _assert_iterator_values(self,
                              iterator,
                              expected_values,
                              evaluate_fn,
                              devices,
                              enable_get_next_as_optional=False):
    actual_values = []
    for _ in range(len(expected_values)):
      if enable_get_next_as_optional:
        next_element = iterator.get_next_as_optional().get_value()
      else:
        next_element = iterator.get_next()
      computed_value = evaluate_fn([
          distribute_utils.select_replica(r, next_element)
          for r in range(len(devices))
      ])
      actual_values.append(computed_value)
    for expected_value, actual_value in zip(expected_values, actual_values):
      for expected, actual in zip(expected_value, actual_value):
        self.assertAllEqual(expected, actual)

  def _assert_dataset_values_for_loop(self, dataset, expected_values,
                                      evaluate_fn, devices):
    actual_values = []
    for x in dataset:
      computed_value = self.evaluate(
          [distribute_utils.select_replica(r, x) for r in range(len(devices))])
      actual_values.append(computed_value)
    for expected_value, actual_value in zip(expected_values, actual_values):
      for expected, actual in zip(expected_value, actual_value):
        self.assertAllEqual(expected, actual)

  def _test_input_iteration(self,
                            input_type,
                            api_type,
                            iteration_type,
                            dataset_or_input_fn,
                            worker_device_pairs,
                            expected_values,
                            strategy,
                            sess=None,
                            num_replicas_in_sync=None,
                            input_context=None):
    if iteration_type == "for_loop" and not context.executing_eagerly():
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and input_type == "input_fn":
      self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    if api_type == "wrap_into_iterator":
      iterator = self._wrap_iterator(
          input_type,
          dataset_or_input_fn,
          input_workers,
          devices,
          num_replicas_in_sync,
          strategy,
          input_context=input_context)
    else:
      # wrapping into a dataset:
      dataset = self._wrap_dataset(
          input_type,
          dataset_or_input_fn,
          input_workers,
          num_replicas_in_sync,
          strategy,
          input_context=input_context)

      if ops.executing_eagerly_outside_functions():
        iterator = iter(dataset)
      else:
        if isinstance(dataset, input_lib.DistributedDatasetV1):
          iterator = dataset.make_initializable_iterator()
        else:
          self.skipTest("unsupported test combination")

    if isinstance(iterator, composite_tensor.CompositeTensor):
      nest.assert_same_structure(iterator, iterator._type_spec,
                                 expand_composites=True)

    if iteration_type == "get_next":
      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
      if not ops.executing_eagerly_outside_functions():
        evaluate(control_flow_ops.group(iterator.initializer))

      def test_get_next(iterator):
        self._assert_iterator_values(iterator, expected_values, evaluate,
                                     devices)

        with self.assertRaises(errors.OutOfRangeError):
          self._assert_iterator_values(iterator, expected_values, evaluate,
                                       devices)

        # After re-initializing the iterator, we should be able to iterate
        # again.
        if not ops.executing_eagerly_outside_functions():
          evaluate(control_flow_ops.group(iterator.initializer))
        else:
          if api_type == "wrap_into_iterator":
            self.skipTest("unsupported test combination")
          else:
            iterator = iter(dataset)

        self._assert_iterator_values(iterator, expected_values, evaluate,
                                     devices)

      def test_get_next_as_optional(iterator):
        self._assert_iterator_values(
            iterator,
            expected_values,
            evaluate,
            devices,
            enable_get_next_as_optional=True)

        next_element = iterator.get_next_as_optional()
        self.assertFalse(self.evaluate(next_element.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          self._assert_iterator_values(
              iterator, [0],
              evaluate,
              devices,
              enable_get_next_as_optional=True)

      test_get_next(iterator)

      # Re-initialize the iterator before testing `get_next_as_optional`.
      if not tf2.enabled():
        # TODO(yuefengz): we should split this function.
        return
      else:
        if api_type == "wrap_into_iterator":
          return
        else:
          iterator = iter(dataset)

      test_get_next_as_optional(iterator)

    if iteration_type == "for_loop" and context.executing_eagerly():
      self._assert_dataset_values_for_loop(dataset, expected_values,
                                           self.evaluate, devices)

  def _create_dataset_or_input_fn(self, input_type, input_fn):
    if input_type == "input_fn":
      return input_fn
    else:
      return input_fn(distribute_lib.InputContext())


class DistributedIteratorTest(DistributedIteratorTestBase,
                              parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu
          ]))
  def testDisablingOwnedIteratorsInTF2(self, distribution, input_type):
    if not tf2.enabled():
      self.skipTest("unsupported test combination")

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)
    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    if input_type == "dataset":
      dist_dataset = input_lib.get_distributed_dataset(dataset_or_input_fn,
                                                       input_workers,
                                                       distribution)
    else:
      dist_dataset = input_lib.get_distributed_datasets_from_function(
          dataset_or_input_fn, input_workers, [distribute_lib.InputContext()],
          distribution)

    # Default iterator types in TF2.
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIterator)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerOwnedDatasetIterator)

    # Disable creating owned iterators by setting a property on the strategy.
    distribution._enable_legacy_iterators = True
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIteratorV1)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerDatasetIterator)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ]))
  def testMultiDeviceIterInitialize(self, distribution):
    if tf2.enabled():
      self.skipTest("Only V1 is supported.")
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)

    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dist_dataset = input_lib.get_distributed_dataset(
        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)

    iterator = dataset_ops.make_one_shot_iterator(dist_dataset)

    @def_function.function
    def init_func_for_iter():
      self.evaluate(iterator.initializer)

    init_func_for_iter()

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
          ],
          enable_get_next_as_optional=[True, False]))
  def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
                       enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i] for i in range(10)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[strategy_combinations.multi_worker_mirrored_2x1_cpu],
          enable_get_next_as_optional=[True, False]))
  def testOneDeviceCPUMultiWorker(self, input_type, api_type, iteration_type,
                                  distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i] for i in range(10)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type,
                                 distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i, i+1] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[strategy_combinations.tpu_strategy],
          enable_get_next_as_optional=[True, False]))
  def testTPU(self, input_type, api_type, iteration_type, distribution,
              enable_get_next_as_optional):
    worker_device_pairs = collections.OrderedDict()
    for tpu_device in distribution.extended.worker_devices:
      host_device = device_util.get_host_for_device(tpu_device)
      worker_device_pairs.setdefault(host_device, [])
      worker_device_pairs[host_device].append(tpu_device)
    worker_device_pairs = worker_device_pairs.items()
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i, i + 1] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          ],
          enable_get_next_as_optional=[True, False]))
  def testTupleDataset(self, input_type, api_type, iteration_type, distribution,
                       enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]

    def dataset_fn(ctx):
      del ctx
      dataset1 = dataset_ops.Dataset.range(10)
      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
      return dataset_ops.Dataset.zip((dataset1, dataset2))

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[strategy_combinations.multi_worker_mirrored_2x2_gpu],
          enable_get_next_as_optional=[True, False]))
  def testTupleDatasetMultiworker(self, input_type, api_type, iteration_type,
                                  distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]

    def dataset_fn(ctx):
      del ctx
      dataset1 = dataset_ops.Dataset.range(10)
      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
      return dataset_ops.Dataset.zip((dataset1, dataset2))

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [
        [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2)
    ]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    # `input_context` is not passed in, so no sharding is done.
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testIterableIterator(self, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dataset = dataset_ops.Dataset.range(10)
    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
                                                     distribution)

    iterator = iter(dist_dataset)
    for i, element in enumerate(iterator):
      self.assertAllEqual(distribution.experimental_local_results(element), [i])

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ]))
  def testUnevenDatasetBatches(self, input_type, api_type, iteration_type,
                               drop_remainder, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(  # pylint: disable=g-long-lambda
        2, drop_remainder=drop_remainder)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    # The last global batch only contains data for one replica.
    if drop_remainder:
      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
    else:
      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testUnevenDatasetBatchesMultiWorker(self, input_type, api_type,
                                          iteration_type, drop_remainder,
                                          distribution):
    # Actual devices don't matter in this test as long as the number of global
    # replicas is 2.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
                                                  cr.task_type)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(9)

      if input_type == "input_fn":
        # When input_fn is used, there is no automatic rebatching and sharding,
        # so we add them here.
        return dataset.shard(worker_count, id_in_cluster).batch(1)
      else:
        return dataset.batch(2, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    if drop_remainder and input_type == "dataset":
      if id_in_cluster == 0:
        expected_values = [[[0]], [[2]], [[4]], [[6]]]
      else:
        expected_values = [[[1]], [[3]], [[5]], [[7]]]
    else:
      # The last global batch only contains data for one replica.
      if id_in_cluster == 0:
        expected_values = [[[0]], [[2]], [[4]], [[6]], [[8]]]
      else:
        expected_values = [[[1]], [[3]], [[5]], [[7]], [[]]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
          ]))
  def testUnevenDatasetBatchesMultiWorkerFourReplicas(self, input_type,
                                                      api_type, iteration_type,
                                                      drop_remainder,
                                                      distribution):
    # Actual devices don't matter in this test as long as the number of global
    # replicas is 4.
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
                                                  cr.task_type)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(15)

      if input_type == "input_fn":
        # When input_fn is used, there is no automatic rebatching and sharding,
        # so we add them here.
        return dataset.shard(worker_count, id_in_cluster).batch(1)
      else:
        return dataset.batch(4, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    # The last global batch only contains data for one replica.
    if drop_remainder and input_type == "dataset":
      if id_in_cluster == 0:
        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]]]
      else:
        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]]]
    else:
      if id_in_cluster == 0:
        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]], [[12], [14]]]
      else:
        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]], [[13], []]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          num_replicas_in_sync=[None, 2],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testBatchSplitting(self, input_type, api_type, iteration_type,
                         num_replicas_in_sync, distribution,
                         enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    batch_size = 10
    dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    updated_batch_size = (
        batch_size //
        num_replicas_in_sync if num_replicas_in_sync else batch_size)
    expected_values = [[range(i, i+updated_batch_size),
                        range(i+updated_batch_size, i+2*updated_batch_size)]
                       for i in range(0, 100, updated_batch_size*2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        sess=None,
        num_replicas_in_sync=num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          num_replicas_in_sync=[None, 2],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
          ],
          enable_get_next_as_optional=[True, False]))
  def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type,
                                    num_replicas_in_sync, distribution,
                                    enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]
    batch_size = 10
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(100).batch(batch_size)
      return dataset

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    updated_batch_size = (
        batch_size //
        num_replicas_in_sync if num_replicas_in_sync else batch_size)
    expected_values = [
        [  # pylint: disable=g-complex-comprehension
            range(i, i + updated_batch_size),
            range(i + updated_batch_size, i + 2 * updated_batch_size)
        ] for i in range(0, 100, updated_batch_size * 2)
    ]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        sess=None,
        num_replicas_in_sync=num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
      ))
  def testCacheAcrossIteration(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    dataset = dataset_ops.Dataset.range(16).shuffle(16).cache().batch(4)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    first_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)
    second_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)

    self.assertAllEqual(first_epoch, second_epoch)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          reshuffle=[True, False]))
  def testShuffleAcrossIterations(self, distribution, reshuffle):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    if not reshuffle and not compat.forward_compatible(2020, 5, 22):
      self.skipTest("Functionality currently not supported.")

    dataset = dataset_ops.Dataset.range(12).shuffle(
        12, reshuffle_each_iteration=reshuffle).batch(4)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    first_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)
    second_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)

    if reshuffle:
      self.assertNotAllEqual(first_epoch, second_epoch)
    else:
      self.assertAllEqual(first_epoch, second_epoch)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testGetNextOptionalShape(self, distribution):
    batch_size = 8
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "feature": array_ops.ones([batch_size, 10]),
        "label": array_ops.ones([batch_size]),
    })
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync

    @def_function.function
    def train_fn():
      for data in dist_dataset:
        data = nest.map_structure(distribution.experimental_local_results, data)
        feature = data["feature"]
        label = data["label"]

        # Assert that the shapes are still static from all replicas.
        for replica_id in range(len(distribution.extended.worker_devices)):
          self.assertEqual([per_replica_batch_size, 10],
                           feature[replica_id].shape)
          self.assertEqual([per_replica_batch_size], label[replica_id].shape)

    train_fn()

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF]))
  def testAutoshardingOption(self, distribution, input_type, api_type,
                             iteration_type, auto_shard_policy):
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)
    ds_option = dataset_ops.Options()
    ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy
    dataset_fn = (
        lambda _: dataset_ops.Dataset.range(4).with_options(ds_option))
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    if auto_shard_policy == AutoShardPolicy.AUTO:
      if id_in_cluster == 0:
        expected_values = [[0], [2]]
      else:
        expected_values = [[1], [3]]
    else:
      expected_values = [[0], [1], [2], [3]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          input_type=["input_fn"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"]))
  def testDifferentDatasetsMultiWorker(self, distribution, input_type, api_type,
                                       iteration_type):
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(ctx):
      if ctx.input_pipeline_id == 0:
        return dataset_ops.Dataset.range(8).batch(2)
      else:
        return dataset_ops.Dataset.range(9).batch(2)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    if id_in_cluster == 0:
      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[]]]
    else:
      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8]]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"]))
  def testLoopOverDatasetInTFFunction(self, strategy):
    dataset = dataset_ops.Dataset.range(10).map(lambda x: {  # pylint: disable=g-long-lambda
        "y": math_ops.cast(x, dtypes.float32) ** 2,
    }).batch(4)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    with strategy.scope():
      v = variables.Variable(0.0, aggregation=variables.VariableAggregation.SUM)

    @def_function.function
    def iterator_fn(dist_dataset):

      def assign_add_fn(data):
        v.assign_add(math_ops.reduce_sum(data["y"]))

      for data in dist_dataset:
        strategy.run(assign_add_fn, args=(data,))

    iterator_fn(dist_dataset)
    self.assertEqual(v.numpy(), 285.0)


class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
                                        parameterized.TestCase):
  """Tests for DistributedDataset with non-dense tensors."""

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          defun_type=["lambda", "tf_function"],
      ))
  def testRaggedSparse(self, distribution, input_type, drop_remainder,
                       defun_type):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    defun = {"lambda": lambda f: f,
             "tf_function": def_function.function}[defun_type]
    distribution.extended.experimental_enable_get_next_as_optional = True
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          "dense": ragged_tensor.to_tensor(),
          "ragged": ragged_tensor,
          "sparse": ragged_tensor.to_sparse(),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)
    dataset = self._wrap_dataset(input_type, dataset_or_input_fn,
                                 distribution.extended._input_workers,
                                 len(distribution.extended.worker_devices),
                                 distribution)
    # Assert that the tensors are rebatched and sparsity is preserved.
    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
    self.assertAllEqual(
        distribute_utils.select_replica(0, per_replica_batch["dense"]),
        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
    self.assertAllEqual(
        distribute_utils.select_replica(1, per_replica_batch["dense"]),
        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
    # Transitively check the ragged and sparse tensors by densification.
    for i in range(2):
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["ragged"]).values,
          6)
      self.assertAllEqual(
          distribute_utils.select_replica(
              i, per_replica_batch["ragged"]).to_tensor(),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["sparse"]).indices,
          6)
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(
              distribute_utils.select_replica(i, per_replica_batch["sparse"])),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))

    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""

      def map_fn(per_replica_values):
        per_replica_sums = distribution.run(
            (lambda x: math_ops.reduce_sum(x.values)) if all(
                map(sparse_tensor.is_sparse, per_replica_values.values)) else
            math_ops.reduce_sum, (per_replica_values,))
        return distribution.reduce(
            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

      return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
      sums = sum_batch(batch)
      return {name: value + sums[name] for name, value in state.items()}

    def sum_for_loop(dataset):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      for batch in dataset:
        sums = _reduce(sums, batch)
      return sums

    def sum_while_loop(iterator, reduce_fn):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      while True:
        try:
          sums = reduce_fn(sums, iterator)
        except (StopIteration, errors.OutOfRangeError):
          return sums

    while_sums = sum_while_loop(
        iter(dataset),
        defun(lambda state, iterator: _reduce(state, next(iterator))))
    self.assertAllEqual(
        nest.flatten(while_sums),
        # When there's no partial batch, the sum is smaller.
        [200. if drop_remainder else 310.] * 3)
    for_sums = defun(sum_for_loop)(dataset)
    # For loops always call get_next_as_optional inside tf.functions, so we
    # expect 310 here when using an input function (as there are 5 batches of
    # size 4 round-robined over 2 replicas).
    expected_for_sum = 200.
    if (not drop_remainder or (
        defun_type == "tf_function" and input_type == "input_fn")):
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          tensor_type=["sparse", "ragged"],
          enable_get_next_as_optional=[True, False]
      ))
  def testRaggedSparseGetNextAsOptional(
      self, distribution, input_type, drop_remainder, tensor_type,
      enable_get_next_as_optional):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          tensor_type: (ragged_tensor if tensor_type == "ragged" else
                        ragged_tensor.to_sparse()),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
    else:
      ds = distribution.distribute_datasets_from_function(dataset_fn)
    iterator = iter(ds)

    self.assertEqual(iterator._enable_get_next_as_optional,
                     (not drop_remainder) and enable_get_next_as_optional)

  @combinations.generate(
      combinations.combine(
          tf_api_version=2,
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              # TODO(mdan): Add these?
              # strategy_combinations.multi_worker_mirrored_2x1_cpu,
              # strategy_combinations.multi_worker_mirrored_2x1_gpu,
              # strategy_combinations.multi_worker_mirrored_2x2_gpu,
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
      ))
  def testRaggedSparseGetNextAsOptionalInLoop(
      self, distribution, input_type, drop_remainder):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    self.skipTest("b/323359921")

    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          "dense": ragged_tensor.to_tensor(),
          "ragged": ragged_tensor,
          "sparse": ragged_tensor.to_sparse(),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
    else:
      ds = distribution.distribute_datasets_from_function(dataset_fn)

    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""

      def map_fn(per_replica_values):
        per_replica_sums = distribution.run(
            (lambda x: math_ops.reduce_sum(x.values)) if all(
                map(sparse_tensor.is_sparse, per_replica_values.values)) else
            math_ops.reduce_sum, (per_replica_values,))
        return distribution.reduce(
            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

      return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
      sums = sum_batch(batch)
      return {name: value + sums[name] for name, value in state.items()}

    def sum_while_loop(ds):
      iterator = iter(ds)
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      try_next = constant_op.constant(True)

      while try_next:
        opt_iterate = iterator.get_next_as_optional()
        if opt_iterate.has_value():
          sums = _reduce(sums, opt_iterate.get_value())
        else:
          try_next = False
      return sums

    sums = def_function.function(sum_while_loop)(ds)
    # For loops always call get_next_as_optional inside tf.functions, so we
    # expect 310 here when using an input function (as there are 5 batches of
    # size 4 round-robined over 2 replicas).
    expected_for_sum = 200.
    if not drop_remainder or input_type == "input_fn":
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testMWMSPartialBatch(self, input_type, api_type, iteration_type,
                           distribution):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when we have two files each with
    # 12 elements and a global batch size of 8. When we consider the dataset in
    # aggregate (non-distributed), there are 24 elements divided into 3 batches
    # of size 8. Hence, the correct distributed behavior is for each replica to
    # see sub-batches of size 4, over three steps.
    def dataset_fn(ctx):
      del ctx
      dataset = dataset_ops.Dataset.range(12).batch(8)

      # Set the sharding behavior to OFF for simplicity of test setup; namely,
      # `dataset` defines the per-worker dataset and will not be further
      # sharded. Each worker will see a dataset that is
      # tf.data.Dataset.range(12).batch(8).rebatch(...).
      options = dataset_ops.Options()
      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as there is 1 local
    # replica.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. Each worker should rebatch its dataset into
    # smaller batches of size 4.
    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type,
                                            iteration_type, distribution):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when we have two files each with
    # 12 elements and a global batch size of 8. When we consider the dataset in
    # aggregate (non-distributed), there are 24 elements divided into 3 batches
    # of size 8. Hence, the correct distributed behavior is for each replica to
    # see sub-batches of size 4, over three steps. However, when we create a
    # DistributedDataset and cannot statically infer the intended global batch
    # size (e.g. if the user does not use a batching dataset), each worker will
    # rebatch based on the dynamic batch size of the data encountered, even when
    # it encounters partial batches. The last per-worker partial batch (size 4)
    # ends up being split into two replicas, resulting in 4 steps in total, of
    # (global) batch sizes 8, 8, 4, 4.
    def dataset_fn(ctx):
      del ctx
      # The following dataset is equivalent to
      # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset.
      # This causes DistributedDataset to use LegacyRebatch instead.
      batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4])
      offsets = dataset_ops.Dataset.from_tensor_slices([0, 8])
      dataset = dataset_ops.Dataset.zip((offsets, batch_sizes))

      def map_fn(offset, batch_size):
        return math_ops.range(offset, offset + batch_size)

      dataset = dataset.map(map_fn)

      # Set the sharding behavior to OFF for simplicity of test setup; namely,
      # `dataset` defines the per-worker dataset and will not be further
      # sharded. Each worker will see a dataset that is equivalent to
      # tf.data.Dataset.range(12).batch(8).rebatch(...).
      options = dataset_ops.Options()
      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as the number of global
    # replicas is 2.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. Each worker should rebatch its dataset into
    # smaller batches of size 4.
    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA]))
  def testMWMSWithDataSharding(self, input_type, api_type, iteration_type,
                               distribution, auto_shard_policy):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when the dataset is sharded by
    # data and the batch size is indivisible by the number of replicas. This
    # checks that the elements are as expected and that the batch size across
    # all workers adds up to 3. This test will only pass if the autoshard
    # rewrite rewrites RebatchDatasetV2 to legacy RebatchDataset when sharding
    # by data.
    def dataset_fn(ctx):
      del ctx
      dataset = dataset_ops.Dataset.range(8).batch(3)

      # Apply the auto-shard policy under test to the dataset.
      options = dataset_ops.Options()
      options.experimental_distribute.auto_shard_policy = auto_shard_policy
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as there is 1 local
    # replica.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. We expect each worker to see different shards of
    # data.
    cr = distribution.cluster_resolver
    worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type,
                                                cr.task_id)

    if worker_id == 0:
      expected_values = [[[0, 1]], [[3, 4]], [[6]]]
    elif worker_id == 1:
      expected_values = [[[2]], [[5]], [[7]]]

    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())


class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
                                       parameterized.TestCase):
  """Tests for the PER_WORKER and PER_REPLICA InputOptions variants."""

  def setUp(self):
    context._reset_context()
    strategy_combinations.set_virtual_cpus_to_at_least(3)
    super(DistributedIteratorPerDeviceTest, self).setUp()

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution,
                                                        input_options):

    def dataset_fn(input_context):  # pylint: disable=unused-argument
      return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4])

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    for x in ds:
      assert x.values[0].device == distribution.extended.worker_devices[0]
      assert x.values[0].backing_device == distribution.extended.worker_devices[
          0]
      assert x.values[1].device == distribution.extended.worker_devices[1]
      assert x.values[1].backing_device == distribution.extended.worker_devices[
          1]

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER)
          ],
          mode=["eager"],
      ))
  def testDevicePlacementForPerWorkerValuesWithoutPrefetch(
      self, distribution, input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    for x in ds:
      x = distribution.run(lambda inputs: inputs, args=(x,))
      assert x.values[
          0].device == "/job:localhost/replica:0/task:0/device:CPU:0"
      assert x.values[
          0].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
"/job:localhost/replica:0/task:0/device:CPU:0" 1534 assert x.values[ 1535 1].device == "/job:localhost/replica:0/task:0/device:CPU:0" 1536 assert x.values[ 1537 1].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0" 1538 1539 @combinations.generate( 1540 combinations.combine( 1541 input_options=[ 1542 distribute_lib.InputOptions( 1543 experimental_place_dataset_on_device=True, 1544 experimental_prefetch_to_device=False, 1545 experimental_replication_mode=distribute_lib 1546 .InputReplicationMode.PER_WORKER), 1547 distribute_lib.InputOptions( 1548 experimental_place_dataset_on_device=True, 1549 experimental_prefetch_to_device=True, 1550 experimental_replication_mode=distribute_lib 1551 .InputReplicationMode.PER_REPLICA) 1552 ], 1553 mode=["eager"], 1554 distribution=[ 1555 strategy_combinations.mirrored_strategy_with_two_gpus, 1556 strategy_combinations.mirrored_strategy_with_cpu_1_and_2, 1557 strategy_combinations.mirrored_strategy_with_gpu_and_cpu, 1558 ])) 1559 def testDevicePlacementForInvalidCombinations(self, distribution, 1560 input_options): 1561 1562 def dataset_fn(input_context): 1563 return dataset_ops.Dataset.from_tensor_slices( 1564 np.full(4, input_context.input_pipeline_id)) 1565 1566 with self.assertRaises(ValueError): 1567 distribution.experimental_distribute_datasets_from_function( 1568 dataset_fn, input_options) 1569 1570 @combinations.generate( 1571 combinations.combine( 1572 input_options=[ 1573 distribute_lib.InputOptions( 1574 experimental_place_dataset_on_device=False, 1575 experimental_prefetch_to_device=False, 1576 experimental_replication_mode=distribute_lib 1577 .InputReplicationMode.PER_WORKER), 1578 distribute_lib.InputOptions( 1579 experimental_place_dataset_on_device=False, 1580 experimental_prefetch_to_device=True, 1581 experimental_replication_mode=distribute_lib 1582 .InputReplicationMode.PER_WORKER), 1583 ], 1584 mode=["eager"], 1585 distribution=[ 1586 strategy_combinations.mirrored_strategy_with_two_gpus, 1587 strategy_combinations.mirrored_strategy_with_cpu_1_and_2, 1588 strategy_combinations.mirrored_strategy_with_gpu_and_cpu, 1589 ])) 1590 def testOutputValuesForPerWorkerInputOptions(self, distribution, 1591 input_options): 1592 1593 def dataset_fn(input_context): 1594 return dataset_ops.Dataset.from_tensor_slices( 1595 np.arange(1, 11).reshape( 1596 (2, 5)) * (input_context.input_pipeline_id + 1)) 1597 1598 ds = distribution.experimental_distribute_datasets_from_function( 1599 dataset_fn, input_options) 1600 1601 # validating the values 1602 x = next(iter(ds)) 1603 assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5])) 1604 assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10])) 1605 1606 @combinations.generate( 1607 combinations.combine( 1608 input_options=[ 1609 distribute_lib.InputOptions( 1610 experimental_place_dataset_on_device=True, 1611 experimental_prefetch_to_device=False, 1612 experimental_replication_mode=distribute_lib 1613 .InputReplicationMode.PER_REPLICA), 1614 distribute_lib.InputOptions( 1615 experimental_place_dataset_on_device=False, 1616 experimental_prefetch_to_device=False, 1617 experimental_replication_mode=distribute_lib 1618 .InputReplicationMode.PER_REPLICA), 1619 distribute_lib.InputOptions( 1620 experimental_place_dataset_on_device=False, 1621 experimental_prefetch_to_device=True, 1622 experimental_replication_mode=distribute_lib 1623 .InputReplicationMode.PER_REPLICA), 1624 ], 1625 mode=["eager"], 1626 distribution=[ 1627 strategy_combinations.mirrored_strategy_with_two_gpus, 
  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_prefetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testOutputValuesForPerReplicaInputOptions(self, distribution,
                                                input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 10) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)
    expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
    for i, x in enumerate(ds):
      # validating the values
      assert x.values[0].numpy() == expected[i]
      assert x.values[1].numpy() == expected[i] * 2
      loop_num = i
    assert loop_num == len(expected) - 1


class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
                                           parameterized.TestCase):
  """Tests for distributed iterators which read from tf.data service."""

  def setUp(self):
    super(DistributedIteratorTfDataServiceTest, self).setUp()
    self.num_workers = 3
    if combinations.in_main_process():
      self.dispatcher = server_lib.DispatchServer()
      self.workers = []
      for _ in range(self.num_workers):
        self.workers.append(
            server_lib.WorkerServer(
                server_lib.WorkerConfig(
                    dispatcher_address=self.dispatcher.target.split("://")[1],
                    heartbeat_interval_ms=100,
                    dispatcher_timeout_ms=1000)))
      combinations.env().tf_data_service_dispatcher = self.dispatcher.target

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testTfDataService(self, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dataset = dataset_ops.Dataset.range(1, 50)
    dataset = dataset.apply(
        data_service_ops._distribute(
            processing_mode="parallel_epochs",
            service=combinations.env().tf_data_service_dispatcher,
            job_name="foo"))

    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
                                                     distribution)

    iterator = iter(dist_dataset)
    results = []
    for element in iterator:
      local_results = distribution.experimental_local_results(element)
      for result in local_results:
        # input_lib.distributed_dataset may add extra '0' elements to pad
        # per-replica results.
        if result.numpy() != 0:
          results.append(result.numpy())
    self.assertNotEmpty(results)
    gathered = distribution.gather(constant_op.constant(results), axis=0)
    self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)


if __name__ == "__main__":
  test_util.main()
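
# For reference, DistributedIteratorTfDataServiceTest drives tf.data service
# through the internal `server_lib` / `data_service_ops` modules. A rough
# standalone sketch using the public API (illustrative only, not used by the
# tests above):
#
#   dispatcher = tf.data.experimental.service.DispatchServer()
#   workers = [
#       tf.data.experimental.service.WorkerServer(
#           tf.data.experimental.service.WorkerConfig(
#               dispatcher_address=dispatcher.target.split("://")[1]))
#       for _ in range(3)
#   ]
#   ds = tf.data.Dataset.range(1, 50).apply(
#       tf.data.experimental.service.distribute(
#           processing_mode="parallel_epochs",
#           service=dispatcher.target,
#           job_name="foo"))
#
# With "parallel_epochs" every tf.data service worker processes the full
# dataset independently, which is why testTfDataService expects num_workers
# copies of range(1, 50).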