1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the input_lib library."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from absl.testing import parameterized
24import numpy as np
25
26from tensorflow.python import tf2
27from tensorflow.python.compat import compat
28from tensorflow.python.data.experimental.ops import data_service_ops
29from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
30from tensorflow.python.data.experimental.service import server_lib
31from tensorflow.python.data.ops import dataset_ops
32from tensorflow.python.distribute import combinations
33from tensorflow.python.distribute import device_util
34from tensorflow.python.distribute import distribute_lib
35from tensorflow.python.distribute import distribute_utils
36from tensorflow.python.distribute import input_lib
37from tensorflow.python.distribute import multi_worker_util
38from tensorflow.python.distribute import reduce_util
39from tensorflow.python.distribute import strategy_combinations
40from tensorflow.python.distribute import test_util
41from tensorflow.python.eager import context
42from tensorflow.python.eager import def_function
43from tensorflow.python.eager import test
44from tensorflow.python.framework import composite_tensor
45from tensorflow.python.framework import constant_op
46from tensorflow.python.framework import dtypes
47from tensorflow.python.framework import errors
48from tensorflow.python.framework import ops
49from tensorflow.python.framework import sparse_tensor
50from tensorflow.python.ops import array_ops
51from tensorflow.python.ops import control_flow_ops
52from tensorflow.python.ops import math_ops
53from tensorflow.python.ops import sparse_ops
54from tensorflow.python.ops import variables
55from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
56from tensorflow.python.util import nest
57
58
59class DistributedIteratorTestBase(test.TestCase):
60
61  # The passed input_context is to create a sharded dataset in between-graph
62  # case.
63  # TODO(yuefengz): rewrite the following method to make it less DRY.
64  def _wrap_iterator(self,
65                     input_type,
66                     dataset_or_input_fn,
67                     input_workers,
68                     devices,
69                     num_replicas_in_sync,
70                     strategy,
71                     input_context=None):
72    # The `input_context` passed in is to shard dataset for
73    # MultiWorkerMirroredStrategy. It doesn't apply to in-graph case where
74    # multiple InputContexts are needed.
75    if input_type == "input_fn":
76      self.assertIsNone(
77          input_context,
78          msg=("`The input_context` arg is only used to shard dataset in "
79               "`MultiWorkerMirroredStrategy` when the input type is dataset."))
80
81      input_contexts = []
82      for i in range(input_workers.num_workers):
83        input_contexts.append(
84            distribute_lib.InputContext(
85                # Note: `input_workers.num_workers` is always 1 in between-graph
86                # case.
87                num_input_pipelines=input_workers.num_workers,
88                input_pipeline_id=i,
89                num_replicas_in_sync=len(devices)))
90
91      iterator = input_lib.InputFunctionIterator(
92          dataset_or_input_fn,
93          input_workers,
94          input_contexts,
95          strategy)
96    else:
97      iterator = input_lib.DatasetIterator(
98          dataset_or_input_fn,
99          input_workers,
100          strategy,
101          num_replicas_in_sync=num_replicas_in_sync,
102          input_context=input_context)
103    return iterator
104
105  def _wrap_dataset(self,
106                    input_type,
107                    dataset,
108                    input_workers,
109                    num_replicas_in_sync,
110                    strategy,
111                    input_context=None):
112    if input_type == "dataset":
113      if tf2.enabled():
114        return input_lib.DistributedDataset(
115            dataset,
116            input_workers,
117            strategy,
118            num_replicas_in_sync=num_replicas_in_sync,
119            input_context=input_context)
120      else:
121        return input_lib.DistributedDatasetV1(
122            dataset,
123            input_workers,
124            strategy,
125            num_replicas_in_sync=num_replicas_in_sync,
126            input_context=input_context)
127    else:
128      return strategy.distribute_datasets_from_function(dataset)
129
130  def _assert_iterator_values(self,
131                              iterator,
132                              expected_values,
133                              evaluate_fn,
134                              devices,
135                              enable_get_next_as_optional=False):
136    actual_values = []
137    for _ in range(len(expected_values)):
138      if enable_get_next_as_optional:
139        next_element = iterator.get_next_as_optional().get_value()
140      else:
141        next_element = iterator.get_next()
142      computed_value = evaluate_fn([
143          distribute_utils.select_replica(r, next_element)
144          for r in range(len(devices))
145      ])
146      actual_values.append(computed_value)
147    for expected_value, actual_value in zip(expected_values, actual_values):
148      for expected, actual in zip(expected_value, actual_value):
149        self.assertAllEqual(expected, actual)
150
151  def _assert_dataset_values_for_loop(self, dataset, expected_values,
152                                      evaluate_fn, devices):
153    actual_values = []
154    for x in dataset:
155      computed_value = self.evaluate(
156          [distribute_utils.select_replica(r, x) for r in range(len(devices))])
157      actual_values.append(computed_value)
158    for expected_value, actual_value in zip(expected_values, actual_values):
159      for expected, actual in zip(expected_value, actual_value):
160        self.assertAllEqual(expected, actual)
161
162  def _test_input_iteration(self,
163                            input_type,
164                            api_type,
165                            iteration_type,
166                            dataset_or_input_fn,
167                            worker_device_pairs,
168                            expected_values,
169                            strategy,
170                            sess=None,
171                            num_replicas_in_sync=None,
172                            input_context=None):
173    if iteration_type == "for_loop" and not context.executing_eagerly():
174      self.skipTest("unsupported test combination.")
175
176    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
177      self.skipTest("unsupported test combination.")
178
179    if api_type == "wrap_into_iterator" and input_type == "input_fn":
180      self.skipTest("unsupported test combination.")
181
182    devices = nest.flatten([ds for _, ds in worker_device_pairs])
183    input_workers = input_lib.InputWorkers(worker_device_pairs)
184
185    if api_type == "wrap_into_iterator":
186      iterator = self._wrap_iterator(
187          input_type,
188          dataset_or_input_fn,
189          input_workers,
190          devices,
191          num_replicas_in_sync,
192          strategy,
193          input_context=input_context)
194    else:
195      # wrapping into a dataset:
196      dataset = self._wrap_dataset(
197          input_type,
198          dataset_or_input_fn,
199          input_workers,
200          num_replicas_in_sync,
201          strategy,
202          input_context=input_context)
203
204      if ops.executing_eagerly_outside_functions():
205        iterator = iter(dataset)
206      else:
207        if isinstance(dataset, input_lib.DistributedDatasetV1):
208          iterator = dataset.make_initializable_iterator()
209        else:
210          self.skipTest("unsupported test combination")
211
212    if isinstance(iterator, composite_tensor.CompositeTensor):
213      nest.assert_same_structure(iterator, iterator._type_spec,
214                                 expand_composites=True)
215
216    if iteration_type == "get_next":
217      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
218      if not ops.executing_eagerly_outside_functions():
219        evaluate(control_flow_ops.group(iterator.initializer))
220
221      def test_get_next(iterator):
222        self._assert_iterator_values(iterator, expected_values, evaluate,
223                                     devices)
224
225        with self.assertRaises(errors.OutOfRangeError):
226          self._assert_iterator_values(iterator, expected_values, evaluate,
227                                       devices)
228
229        # After re-initializing the iterator, should be able to iterate again.
230        if not ops.executing_eagerly_outside_functions():
231          evaluate(control_flow_ops.group(iterator.initializer))
232        else:
233          if api_type == "wrap_into_iterator":
234            self.skipTest("unsupported test combination")
235          else:
236            iterator = iter(dataset)
237
238        self._assert_iterator_values(iterator, expected_values, evaluate,
239                                     devices)
240
241      def test_get_next_as_optional(iterator):
242        self._assert_iterator_values(
243            iterator,
244            expected_values,
245            evaluate,
246            devices,
247            enable_get_next_as_optional=True)
248
249        next_element = iterator.get_next_as_optional()
250        self.assertFalse(self.evaluate(next_element.has_value()))
251        with self.assertRaises(errors.InvalidArgumentError):
252          self._assert_iterator_values(
253              iterator, [0],
254              evaluate,
255              devices,
256              enable_get_next_as_optional=True)
257
258      test_get_next(iterator)
259
260      # re-initializing the iterator
261      if not tf2.enabled():
262        # TODO(yuefengz): we should split this function.
263        return
264      else:
265        if api_type == "wrap_into_iterator":
266          return
267        else:
268          iterator = iter(dataset)
269
270      test_get_next_as_optional(iterator)
271
272    if iteration_type == "for_loop" and context.executing_eagerly():
273      self._assert_dataset_values_for_loop(dataset, expected_values,
274                                           self.evaluate, devices)
275
276  def _create_dataset_or_input_fn(self, input_type, input_fn):
277    if input_type == "input_fn":
278      return input_fn
279    else:
280      return input_fn(distribute_lib.InputContext())
281
282
283class DistributedIteratorTest(DistributedIteratorTestBase,
284                              parameterized.TestCase):
285
286  @combinations.generate(
287      combinations.combine(
288          mode=["eager"],
289          input_type=["input_fn", "dataset"],
290          distribution=[
291              strategy_combinations.one_device_strategy,
292              strategy_combinations.mirrored_strategy_with_one_cpu,
293              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
294              strategy_combinations.multi_worker_mirrored_2x1_cpu
295          ]))
296  def testDisablingOwnedIteratorsInTF2(self, distribution, input_type):
297    if not tf2.enabled():
298      self.skipTest("unsupported test combination")
299
300    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
301    input_workers = input_lib.InputWorkers(worker_device_pairs)
302    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
303    dataset_or_input_fn = self._create_dataset_or_input_fn(
304        input_type, dataset_fn)
305
306    input_workers = input_lib.InputWorkers(worker_device_pairs)
307    if input_type == "dataset":
308      dist_dataset = input_lib.get_distributed_dataset(dataset_or_input_fn,
309                                                       input_workers,
310                                                       distribution)
311    else:
312      dist_dataset = input_lib.get_distributed_datasets_from_function(
313          dataset_or_input_fn, input_workers, [distribute_lib.InputContext()],
314          distribution)
315
316    # Default Iterator types in TF2.
317    iterator = iter(dist_dataset)
318    self.assertIsInstance(iterator, input_lib.DistributedIterator)
319    self.assertIsInstance(iterator._iterators[0],
320                          input_lib._SingleWorkerOwnedDatasetIterator)
321
322    # Disable creating owned iterators by setting a property on the strategy.
323    distribution._enable_legacy_iterators = True
324    iterator = iter(dist_dataset)
325    self.assertIsInstance(iterator, input_lib.DistributedIteratorV1)
326    self.assertIsInstance(iterator._iterators[0],
327                          input_lib._SingleWorkerDatasetIterator)
328
329  @combinations.generate(
330      combinations.combine(
331          mode=["eager"],
332          distribution=[
333              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
334          ]))
335  def testMultiDeviceIterInitialize(self, distribution):
336    if tf2.enabled():
337      self.skipTest("Only V1 is supported.")
338    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
339                                              "/device:CPU:0"])]
340    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
341
342    input_workers = input_lib.InputWorkers(worker_device_pairs)
343
344    dist_dataset = input_lib.get_distributed_dataset(
345        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)
346
347    iterator = dataset_ops.make_one_shot_iterator(dist_dataset)
348
349    @def_function.function
350    def init_func_for_iter():
351      self.evaluate(iterator.initializer)
352
353    init_func_for_iter()
354
355  @combinations.generate(
356      combinations.combine(
357          mode=["graph", "eager"],
358          input_type=["input_fn", "dataset"],
359          api_type=["wrap_into_iterator", "wrap_into_dataset"],
360          iteration_type=["get_next", "for_loop"],
361          distribution=[
362              strategy_combinations.one_device_strategy,
363              strategy_combinations.mirrored_strategy_with_one_cpu,
364          ],
365          enable_get_next_as_optional=[True, False]))
366  def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
367                       enable_get_next_as_optional):
368    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
369    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
370    dataset_or_input_fn = self._create_dataset_or_input_fn(
371        input_type, dataset_fn)
372
373    expected_values = [[i] for i in range(10)]
374
375    distribution.extended.experimental_enable_get_next_as_optional = (
376        enable_get_next_as_optional)
377    self._test_input_iteration(input_type, api_type, iteration_type,
378                               dataset_or_input_fn, worker_device_pairs,
379                               expected_values, distribution)
380
381  @combinations.generate(
382      combinations.combine(
383          mode=["eager"],
384          input_type=["input_fn", "dataset"],
385          api_type=["wrap_into_dataset"],
386          iteration_type=["get_next", "for_loop"],
387          distribution=[strategy_combinations.multi_worker_mirrored_2x1_cpu],
388          enable_get_next_as_optional=[True, False]))
389  def testOneDeviceCPUMultiWorker(self, input_type, api_type, iteration_type,
390                                  distribution, enable_get_next_as_optional):
391    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
392    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
393    dataset_or_input_fn = self._create_dataset_or_input_fn(
394        input_type, dataset_fn)
395
396    expected_values = [[i] for i in range(10)]
397
398    distribution.extended.experimental_enable_get_next_as_optional = (
399        enable_get_next_as_optional)
400    self._test_input_iteration(
401        input_type,
402        api_type,
403        iteration_type,
404        dataset_or_input_fn,
405        worker_device_pairs,
406        expected_values,
407        distribution)
408
409  @combinations.generate(
410      combinations.combine(
411          mode=["graph", "eager"],
412          input_type=["input_fn", "dataset"],
413          api_type=["wrap_into_iterator", "wrap_into_dataset"],
414          iteration_type=["get_next", "for_loop"],
415          distribution=[
416              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
417              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
418          ],
419          enable_get_next_as_optional=[True, False]))
420  def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type,
421                                 distribution, enable_get_next_as_optional):
422    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
423                                              "/device:CPU:0"])]
424    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
425    dataset_or_input_fn = self._create_dataset_or_input_fn(
426        input_type, dataset_fn)
427
428    expected_values = [[i, i+1] for i in range(0, 10, 2)]
429
430    distribution.extended.experimental_enable_get_next_as_optional = (
431        enable_get_next_as_optional)
432    self._test_input_iteration(
433        input_type,
434        api_type,
435        iteration_type,
436        dataset_or_input_fn,
437        worker_device_pairs,
438        expected_values,
439        distribution)
440
441  @combinations.generate(
442      combinations.combine(
443          mode=["graph", "eager"],
444          input_type=["input_fn", "dataset"],
445          api_type=["wrap_into_iterator", "wrap_into_dataset"],
446          iteration_type=["get_next", "for_loop"],
447          distribution=[strategy_combinations.tpu_strategy],
448          enable_get_next_as_optional=[True, False]))
449  def testTPU(self, input_type, api_type, iteration_type, distribution,
450              enable_get_next_as_optional):
451    worker_device_pairs = collections.OrderedDict()
452    for tpu_device in distribution.extended.worker_devices:
453      host_device = device_util.get_host_for_device(tpu_device)
454      worker_device_pairs.setdefault(host_device, [])
455      worker_device_pairs[host_device].append(tpu_device)
456    worker_device_pairs = worker_device_pairs.items()
457    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
458    dataset_or_input_fn = self._create_dataset_or_input_fn(
459        input_type, dataset_fn)
460
461    expected_values = [[i, i + 1] for i in range(0, 10, 2)]
462
463    distribution.extended.experimental_enable_get_next_as_optional = (
464        enable_get_next_as_optional)
465    self._test_input_iteration(
466        input_type,
467        api_type,
468        iteration_type,
469        dataset_or_input_fn,
470        worker_device_pairs,
471        expected_values,
472        distribution)
473
474  @combinations.generate(
475      combinations.combine(
476          mode=["graph", "eager"],
477          input_type=["input_fn", "dataset"],
478          api_type=["wrap_into_iterator", "wrap_into_dataset"],
479          iteration_type=["get_next", "for_loop"],
480          distribution=[
481              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
482              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
483          ],
484          enable_get_next_as_optional=[True, False]))
485  def testTupleDataset(self, input_type, api_type, iteration_type, distribution,
486                       enable_get_next_as_optional):
487    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
488                                              "/device:CPU:0"])]
489
490    def dataset_fn(ctx):
491      del ctx
492      dataset1 = dataset_ops.Dataset.range(10)
493      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
494      return dataset_ops.Dataset.zip((dataset1, dataset2))
495
496    dataset_or_input_fn = self._create_dataset_or_input_fn(
497        input_type, dataset_fn)
498
499    expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]
500
501    distribution.extended.experimental_enable_get_next_as_optional = (
502        enable_get_next_as_optional)
503    self._test_input_iteration(
504        input_type,
505        api_type,
506        iteration_type,
507        dataset_or_input_fn,
508        worker_device_pairs,
509        expected_values,
510        distribution)
511
512  @combinations.generate(
513      combinations.combine(
514          mode=["eager"],
515          input_type=["input_fn", "dataset"],
516          api_type=["wrap_into_dataset"],
517          iteration_type=["get_next", "for_loop"],
518          distribution=[strategy_combinations.multi_worker_mirrored_2x2_gpu],
519          enable_get_next_as_optional=[True, False]))
520  def testTupleDatasetMultiworker(self, input_type, api_type, iteration_type,
521                                  distribution, enable_get_next_as_optional):
522    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
523                                              "/device:GPU:1"])]
524
525    def dataset_fn(ctx):
526      del ctx
527      dataset1 = dataset_ops.Dataset.range(10)
528      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
529      return dataset_ops.Dataset.zip((dataset1, dataset2))
530
531    dataset_or_input_fn = self._create_dataset_or_input_fn(
532        input_type, dataset_fn)
533
534    expected_values = [
535        [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2)
536    ]
537
538    distribution.extended.experimental_enable_get_next_as_optional = (
539        enable_get_next_as_optional)
540
541    # Input_context is not passed in and thus no sharding.
542    self._test_input_iteration(input_type, api_type, iteration_type,
543                               dataset_or_input_fn, worker_device_pairs,
544                               expected_values, distribution)
545
546  @combinations.generate(
547      combinations.combine(
548          mode=["eager"],
549          distribution=[
550              strategy_combinations.one_device_strategy,
551              strategy_combinations.mirrored_strategy_with_one_cpu,
552              strategy_combinations.multi_worker_mirrored_2x1_cpu,
553          ]))
554  def testIterableIterator(self, distribution):
555    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
556    input_workers = input_lib.InputWorkers(worker_device_pairs)
557
558    dataset = dataset_ops.Dataset.range(10)
559    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
560                                                     distribution)
561
562    iterator = iter(dist_dataset)
563    for i, element in enumerate(iterator):
564      self.assertAllEqual(distribution.experimental_local_results(element), [i])
565
566  @combinations.generate(
567      combinations.combine(
568          mode=["graph", "eager"],
569          input_type=["input_fn", "dataset"],
570          api_type=["wrap_into_iterator", "wrap_into_dataset"],
571          iteration_type=["get_next", "for_loop"],
572          drop_remainder=[True, False],
573          distribution=[
574              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
575              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
576          ]))
577  def testUnevenDatasetBatches(self, input_type, api_type, iteration_type,
578                               drop_remainder, distribution):
579    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
580                                              "/device:CPU:0"])]
581    dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(  # pylint: disable=g-long-lambda
582        2, drop_remainder=drop_remainder)
583    dataset_or_input_fn = self._create_dataset_or_input_fn(
584        input_type, dataset_fn)
585
586    # The last global batch only contains data for one replica.
587    if drop_remainder:
588      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
589    else:
590      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
591    distribution.extended.experimental_enable_get_next_as_optional = True
592    self._test_input_iteration(
593        input_type,
594        api_type,
595        iteration_type,
596        dataset_or_input_fn,
597        worker_device_pairs,
598        expected_values,
599        distribution)
600
601  @combinations.generate(
602      combinations.combine(
603          mode=["eager"],
604          input_type=["input_fn", "dataset"],
605          api_type=["wrap_into_dataset"],
606          iteration_type=["get_next", "for_loop"],
607          drop_remainder=[True, False],
608          distribution=[
609              strategy_combinations.multi_worker_mirrored_2x1_cpu,
610              strategy_combinations.multi_worker_mirrored_2x1_gpu,
611          ]))
612  def testUnevenDatasetBatchesMultiWorker(self, input_type, api_type,
613                                          iteration_type, drop_remainder,
614                                          distribution):
615    # Actual devices don't matter in this test as long as the number of global
616    # repices is 2.
617    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
618    cr = distribution.cluster_resolver
619    self.assertIsNotNone(cr)
620    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
621                                                  cr.task_type)
622    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
623                                                    cr.task_type, cr.task_id)
624
625    def dataset_fn(_):
626      dataset = dataset_ops.Dataset.range(9)
627
628      if input_type == "input_fn":
629        # When input_fn is used, there is no automatic rebatching and sharding,
630        # so we add them here.
631        return dataset.shard(worker_count, id_in_cluster).batch(1)
632      else:
633        return dataset.batch(2, drop_remainder=drop_remainder)
634
635    dataset_or_input_fn = self._create_dataset_or_input_fn(
636        input_type, dataset_fn)
637
638    if drop_remainder and input_type == "dataset":
639      if id_in_cluster == 0:
640        expected_values = [[[0]], [[2]], [[4]], [[6]]]
641      else:
642        expected_values = [[[1]], [[3]], [[5]], [[7]]]
643    else:
644      # The last global batch only contains data for one replica.
645      if id_in_cluster == 0:
646        expected_values = [[[0]], [[2]], [[4]], [[6]], [[8]]]
647      else:
648        expected_values = [[[1]], [[3]], [[5]], [[7]], [[]]]
649    distribution.extended.experimental_enable_get_next_as_optional = True
650    self._test_input_iteration(
651        input_type,
652        api_type,
653        iteration_type,
654        dataset_or_input_fn,
655        worker_device_pairs,
656        expected_values,
657        distribution,
658        num_replicas_in_sync=distribution.num_replicas_in_sync,
659        input_context=distribution.extended._make_input_context())
660
661  @combinations.generate(
662      combinations.combine(
663          mode=["eager"],
664          input_type=["input_fn", "dataset"],
665          api_type=["wrap_into_dataset"],
666          iteration_type=["get_next", "for_loop"],
667          drop_remainder=[True, False],
668          distribution=[
669              strategy_combinations.multi_worker_mirrored_2x2_gpu,
670          ]))
671  def testUnevenDatasetBatchesMultiWorkerFourReplicas(self, input_type,
672                                                      api_type, iteration_type,
673                                                      drop_remainder,
674                                                      distribution):
675    # Actual devices don't matter in this test as long as the number of global
676    # repices is 2.
677    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
678                                              "/device:GPU:1"])]
679    cr = distribution.cluster_resolver
680    self.assertIsNotNone(cr)
681    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
682                                                  cr.task_type)
683    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
684                                                    cr.task_type, cr.task_id)
685
686    def dataset_fn(_):
687      dataset = dataset_ops.Dataset.range(15)
688
689      if input_type == "input_fn":
690        # When input_fn is used, there is no automatic rebatching and sharding,
691        # so we add them here.
692        return dataset.shard(worker_count, id_in_cluster).batch(1)
693      else:
694        return dataset.batch(4, drop_remainder=drop_remainder)
695
696    dataset_or_input_fn = self._create_dataset_or_input_fn(
697        input_type, dataset_fn)
698
699    # The last global batch only contains data for one replica.
700    if drop_remainder and input_type == "dataset":
701      if id_in_cluster == 0:
702        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]]]
703      else:
704        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]]]
705    else:
706      if id_in_cluster == 0:
707        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]], [[12], [14]]]
708      else:
709        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]], [[13], []]]
710    distribution.extended.experimental_enable_get_next_as_optional = True
711    self._test_input_iteration(
712        input_type,
713        api_type,
714        iteration_type,
715        dataset_or_input_fn,
716        worker_device_pairs,
717        expected_values,
718        distribution,
719        num_replicas_in_sync=distribution.num_replicas_in_sync,
720        input_context=distribution.extended._make_input_context())
721
722  @combinations.generate(
723      combinations.combine(
724          mode=["graph", "eager"],
725          input_type=["dataset"],
726          api_type=["wrap_into_iterator", "wrap_into_dataset"],
727          iteration_type=["get_next", "for_loop"],
728          num_replicas_in_sync=[None, 2],
729          distribution=[
730              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
731              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
732          ],
733          enable_get_next_as_optional=[True, False]))
734  def testBatchSplitting(self, input_type, api_type, iteration_type,
735                         num_replicas_in_sync, distribution,
736                         enable_get_next_as_optional):
737    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
738                                              "/device:CPU:0"])]
739    batch_size = 10
740    dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
741    dataset_or_input_fn = self._create_dataset_or_input_fn(
742        input_type, dataset_fn)
743
744    updated_batch_size = (
745        batch_size //
746        num_replicas_in_sync if num_replicas_in_sync else batch_size)
747    expected_values = [[range(i, i+updated_batch_size),
748                        range(i+updated_batch_size, i+2*updated_batch_size)]
749                       for i in range(0, 100, updated_batch_size*2)]
750
751    distribution.extended.experimental_enable_get_next_as_optional = (
752        enable_get_next_as_optional)
753    self._test_input_iteration(
754        input_type,
755        api_type,
756        iteration_type,
757        dataset_or_input_fn,
758        worker_device_pairs,
759        expected_values,
760        distribution,
761        sess=None,
762        num_replicas_in_sync=num_replicas_in_sync)
763
764  @combinations.generate(
765      combinations.combine(
766          mode=["eager"],
767          input_type=["dataset"],
768          api_type=["wrap_into_dataset"],
769          iteration_type=["get_next", "for_loop"],
770          num_replicas_in_sync=[None, 2],
771          distribution=[
772              strategy_combinations.multi_worker_mirrored_2x2_gpu,
773          ],
774          enable_get_next_as_optional=[True, False]))
775  def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type,
776                                    num_replicas_in_sync, distribution,
777                                    enable_get_next_as_optional):
778    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
779                                              "/device:GPU:1"])]
780    batch_size = 10
781    cr = distribution.cluster_resolver
782    self.assertIsNotNone(cr)
783
784    def dataset_fn(_):
785      dataset = dataset_ops.Dataset.range(100).batch(batch_size)
786      return dataset
787
788    dataset_or_input_fn = self._create_dataset_or_input_fn(
789        input_type, dataset_fn)
790
791    updated_batch_size = (
792        batch_size //
793        num_replicas_in_sync if num_replicas_in_sync else batch_size)
794    expected_values = [
795        [  # pylint: disable=g-complex-comprehension
796            range(i, i + updated_batch_size),
797            range(i + updated_batch_size, i + 2 * updated_batch_size)
798        ] for i in range(0, 100, updated_batch_size * 2)
799    ]
800
801    distribution.extended.experimental_enable_get_next_as_optional = (
802        enable_get_next_as_optional)
803    self._test_input_iteration(
804        input_type,
805        api_type,
806        iteration_type,
807        dataset_or_input_fn,
808        worker_device_pairs,
809        expected_values,
810        distribution,
811        sess=None,
812        num_replicas_in_sync=num_replicas_in_sync)
813
814  @combinations.generate(
815      combinations.combine(
816          mode=["eager"],
817          distribution=[
818              strategy_combinations.one_device_strategy,
819              strategy_combinations.mirrored_strategy_with_one_cpu,
820              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
821              strategy_combinations.tpu_strategy,
822              strategy_combinations.central_storage_strategy_with_two_gpus,
823              strategy_combinations.multi_worker_mirrored_2x2_gpu,
824              strategy_combinations.multi_worker_mirrored_2x1_cpu,
825          ],
826      ))
827  def testCacheAcrossIteration(self, distribution):
828    if not tf2.enabled():
829      self.skipTest("Only V2 is supported.")
830
831    dataset = dataset_ops.Dataset.range(16).shuffle(16).cache().batch(4)
832    dist_dataset = distribution.experimental_distribute_dataset(dataset)
833
834    first_epoch = list(
835        distribution.experimental_local_results(x) for x in dist_dataset)
836    second_epoch = list(
837        distribution.experimental_local_results(x) for x in dist_dataset)
838
839    self.assertAllEqual(first_epoch, second_epoch)
840
841  @combinations.generate(
842      combinations.combine(
843          mode=["eager"],
844          distribution=[
845              strategy_combinations.one_device_strategy,
846              strategy_combinations.mirrored_strategy_with_one_cpu,
847              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
848              strategy_combinations.tpu_strategy,
849              strategy_combinations.central_storage_strategy_with_two_gpus,
850              strategy_combinations.multi_worker_mirrored_2x2_gpu,
851              strategy_combinations.multi_worker_mirrored_2x1_cpu,
852          ],
853          reshuffle=[True, False]))
854  def testShuffleAcrossIterations(self, distribution, reshuffle):
855    if not tf2.enabled():
856      self.skipTest("Only V2 is supported.")
857
858    if not reshuffle and not compat.forward_compatible(2020, 5, 22):
859      self.skipTest("Functionality currently not supported.")
860
861    dataset = dataset_ops.Dataset.range(12).shuffle(
862        12, reshuffle_each_iteration=reshuffle).batch(4)
863    dist_dataset = distribution.experimental_distribute_dataset(dataset)
864
865    first_epoch = list(
866        distribution.experimental_local_results(x) for x in dist_dataset)
867    second_epoch = list(
868        distribution.experimental_local_results(x) for x in dist_dataset)
869
870    if reshuffle:
871      self.assertNotAllEqual(first_epoch, second_epoch)
872    else:
873      self.assertAllEqual(first_epoch, second_epoch)
874
875  @combinations.generate(
876      combinations.combine(
877          mode=["eager"],
878          distribution=[
879              strategy_combinations.one_device_strategy,
880              strategy_combinations.mirrored_strategy_with_one_cpu,
881              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
882              strategy_combinations.tpu_strategy,
883              strategy_combinations.central_storage_strategy_with_two_gpus,
884              strategy_combinations.multi_worker_mirrored_2x2_gpu,
885              strategy_combinations.multi_worker_mirrored_2x1_cpu,
886          ]))
887  def testGetNextOptionalShape(self, distribution):
888    batch_size = 8
889    dataset = dataset_ops.DatasetV2.from_tensor_slices({
890        "feature": array_ops.ones([batch_size, 10]),
891        "label": array_ops.ones([batch_size]),
892    })
893    dataset = dataset.batch(batch_size, drop_remainder=True)
894    dist_dataset = distribution.experimental_distribute_dataset(dataset)
895    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync
896
897    @def_function.function
898    def train_fn():
899      for data in dist_dataset:
900        data = nest.map_structure(distribution.experimental_local_results, data)
901        feature = data["feature"]
902        label = data["label"]
903
904        # Assert the shapes are still static from all replicas.
905        for replica_id in range(len(distribution.extended.worker_devices)):
906          self.assertEqual([per_replica_batch_size, 10],
907                           feature[replica_id].shape)
908          self.assertEqual([per_replica_batch_size], label[replica_id].shape)
909
910    train_fn()
911
912  @combinations.generate(
913      combinations.combine(
914          mode=["eager"],
915          distribution=[
916              strategy_combinations.multi_worker_mirrored_2x1_cpu,
917          ],
918          input_type=["dataset"],
919          api_type=["wrap_into_iterator", "wrap_into_dataset"],
920          iteration_type=["get_next", "for_loop"],
921          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF]))
922  def testAutoshardingOption(self, distribution, input_type, api_type,
923                             iteration_type, auto_shard_policy):
924    cr = distribution.cluster_resolver
925    self.assertIsNotNone(cr)
926    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
927                                                    cr.task_type, cr.task_id)
928    ds_option = dataset_ops.Options()
929    ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy
930    dataset_fn = (
931        lambda _: dataset_ops.Dataset.range(4).with_options(ds_option))
932    dataset_or_input_fn = self._create_dataset_or_input_fn(
933        input_type, dataset_fn)
934
935    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
936    if auto_shard_policy == AutoShardPolicy.AUTO:
937      if id_in_cluster == 0:
938        expected_values = [[0], [2]]
939      else:
940        expected_values = [[1], [3]]
941    else:
942      expected_values = [[0], [1], [2], [3]]
943    self._test_input_iteration(
944        input_type,
945        api_type,
946        iteration_type,
947        dataset_or_input_fn,
948        worker_device_pairs,
949        expected_values,
950        distribution,
951        input_context=distribution.extended._make_input_context())
952
953  @combinations.generate(
954      combinations.combine(
955          mode=["eager"],
956          distribution=[
957              strategy_combinations.multi_worker_mirrored_2x1_cpu,
958          ],
959          input_type=["input_fn"],
960          api_type=["wrap_into_dataset"],
961          iteration_type=["get_next", "for_loop"]))
962  def testDifferentDatasetsMultiWorker(self, distribution, input_type, api_type,
963                                       iteration_type):
964    cr = distribution.cluster_resolver
965    self.assertIsNotNone(cr)
966    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
967                                                    cr.task_type, cr.task_id)
968
969    def dataset_fn(ctx):
970      if ctx.input_pipeline_id == 0:
971        return dataset_ops.Dataset.range(8).batch(2)
972      else:
973        return dataset_ops.Dataset.range(9).batch(2)
974
975    dataset_or_input_fn = self._create_dataset_or_input_fn(
976        input_type, dataset_fn)
977
978    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
979
980    if id_in_cluster == 0:
981      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[]]]
982    else:
983      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8]]]
984    distribution.extended.experimental_enable_get_next_as_optional = True
985    self._test_input_iteration(input_type, api_type, iteration_type,
986                               dataset_or_input_fn, worker_device_pairs,
987                               expected_values, distribution)
988
989  @combinations.generate(
990      combinations.combine(
991          strategy=[
992              strategy_combinations.multi_worker_mirrored_2x1_cpu,
993              strategy_combinations.multi_worker_mirrored_2x1_gpu,
994          ],
995          mode=["eager"]))
996  def testLoopOverDatasetInTFFunction(self, strategy):
997    dataset = dataset_ops.Dataset.range(10).map(lambda x: {  # pylint: disable=g-long-lambda
998        "y": math_ops.cast(x, dtypes.float32) ** 2,
999    }).batch(4)
1000    dist_dataset = strategy.experimental_distribute_dataset(dataset)
1001
1002    with strategy.scope():
1003      v = variables.Variable(0.0, aggregation=variables.VariableAggregation.SUM)
1004
1005    @def_function.function
1006    def iterator_fn(dist_dataset):
1007
1008      def assign_add_fn(data):
1009        v.assign_add(math_ops.reduce_sum(data["y"]))
1010
1011      for data in dist_dataset:
1012        strategy.run(assign_add_fn, args=(data,))
1013
1014    iterator_fn(dist_dataset)
1015    self.assertEqual(v.numpy(), 285.0)
1016
1017
1018class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
1019                                        parameterized.TestCase):
1020  """Tests for DistributedDataset with non-dense tensors."""
1021
1022  @combinations.generate(
1023      combinations.combine(
1024          mode=["eager"],
1025          distribution=[
1026              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1027              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
1028          ],
1029          input_type=["dataset", "input_fn"],
1030          drop_remainder=[False, True],
1031          defun_type=["lambda", "tf_function"],
1032      ))
1033  def testRaggedSparse(self, distribution, input_type, drop_remainder,
1034                       defun_type):
1035    """Test with `RaggedTensor`s and `SparseTensor`s."""
1036    if not tf2.enabled():
1037      self.skipTest("Only V2 is supported.")
1038
1039    defun = {"lambda": lambda f: f,
1040             "tf_function": def_function.function}[defun_type]
1041    distribution.extended.experimental_enable_get_next_as_optional = True
1042    global_batch_size = 8
1043
1044    def dataset_fn(ctx=None):
1045      ctx = ctx or distribute_lib.InputContext()
1046      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
1047      # Use 20 which isn't divisible by 8 to test partial batch behavior.
1048      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
1049      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
1050          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
1051      dataset = dataset_ops.DatasetV2.from_tensor_slices({
1052          "dense": ragged_tensor.to_tensor(),
1053          "ragged": ragged_tensor,
1054          "sparse": ragged_tensor.to_sparse(),
1055      })
1056      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
1057      return dataset.batch(batch_size, drop_remainder=drop_remainder)
1058
1059    dataset_or_input_fn = self._create_dataset_or_input_fn(
1060        input_type, dataset_fn)
1061    dataset = self._wrap_dataset(input_type, dataset_or_input_fn,
1062                                 distribution.extended._input_workers,
1063                                 len(distribution.extended.worker_devices),
1064                                 distribution)
1065    # Assert that the tensors are rebatched and sparsity is preserved.
1066    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
1067    self.assertAllEqual(
1068        distribute_utils.select_replica(0, per_replica_batch["dense"]),
1069        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
1070    self.assertAllEqual(
1071        distribute_utils.select_replica(1, per_replica_batch["dense"]),
1072        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
1073    # Transitively check the ragged and sparse tensors by densification.
1074    for i in range(2):
1075      self.assertLen(
1076          distribute_utils.select_replica(i,
1077                                          per_replica_batch["ragged"]).values,
1078          6)
1079      self.assertAllEqual(
1080          distribute_utils.select_replica(
1081              i, per_replica_batch["ragged"]).to_tensor(),
1082          distribute_utils.select_replica(i, per_replica_batch["dense"]))
1083      self.assertLen(
1084          distribute_utils.select_replica(i,
1085                                          per_replica_batch["sparse"]).indices,
1086          6)
1087      self.assertAllEqual(
1088          sparse_ops.sparse_tensor_to_dense(
1089              distribute_utils.select_replica(i, per_replica_batch["sparse"])),
1090          distribute_utils.select_replica(i, per_replica_batch["dense"]))
1091    # Iterate through all the batches and sum them up.
1092    def sum_batch(per_replica_features):
1093      """Sums the `PerReplica` values in the `per_replica_features` map."""
1094
1095      def map_fn(per_replica_values):
1096        per_replica_sums = distribution.run(
1097            (lambda x: math_ops.reduce_sum(x.values)) if all(
1098                map(sparse_tensor.is_sparse, per_replica_values.values)) else
1099            math_ops.reduce_sum, (per_replica_values,))
1100        return distribution.reduce(
1101            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)
1102
1103      return nest.map_structure(map_fn, per_replica_features)
1104
1105    def _reduce(state, batch):
1106      sums = sum_batch(batch)
1107      return {name: value + sums[name] for name, value in state.items()}
1108
1109    def sum_for_loop(dataset):
1110      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
1111      for batch in dataset:
1112        sums = _reduce(sums, batch)
1113      return sums
1114
1115    def sum_while_loop(iterator, reduce_fn):
1116      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
1117      while True:
1118        try:
1119          sums = reduce_fn(sums, iterator)
1120        except (StopIteration, errors.OutOfRangeError):
1121          return sums
1122
1123    while_sums = sum_while_loop(
1124        iter(dataset),
1125        defun(lambda state, iterator: _reduce(state, next(iterator))))
1126    self.assertAllEqual(
1127        nest.flatten(while_sums),
1128        # When there's no partial batch, the sum is smaller.
1129        [200. if drop_remainder else 310.] * 3)
1130    for_sums = defun(sum_for_loop)(dataset)
1131    # For loops always call get next as optional inside tf functions, so we
1132    # expect 310 here when using an input function (as there are 5 batches of
1133    # size 4 round robined over 2 replicas.
1134    expected_for_sum = 200.
1135    if (not drop_remainder or (
1136        defun_type == "tf_function" and input_type == "input_fn")):
1137      expected_for_sum = 310.
1138    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
1139
1140  @combinations.generate(
1141      combinations.combine(
1142          mode=["eager"],
1143          distribution=[
1144              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1145              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
1146              strategy_combinations.one_device_strategy,
1147              strategy_combinations.mirrored_strategy_with_one_cpu
1148          ],
1149          input_type=["dataset", "input_fn"],
1150          drop_remainder=[False, True],
1151          tensor_type=["sparse", "ragged"],
1152          enable_get_next_as_optional=[True, False]
1153      ))
1154  def testRaggedSparseGetNextAsOptional(
1155      self, distribution, input_type, drop_remainder, tensor_type,
1156      enable_get_next_as_optional):
1157    """Test with `RaggedTensor`s and `SparseTensor`s."""
1158    if not tf2.enabled():
1159      self.skipTest("Only V2 is supported.")
1160
1161    distribution.extended.experimental_enable_get_next_as_optional = (
1162        enable_get_next_as_optional)
1163    global_batch_size = 8
1164
1165    def dataset_fn(ctx=None):
1166      ctx = ctx or distribute_lib.InputContext()
1167      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
1168      # Use 20 which isn't divisible by 8 to test partial batch behavior.
1169      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
1170      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
1171          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
1172      dataset = dataset_ops.DatasetV2.from_tensor_slices({
1173          tensor_type: (ragged_tensor if tensor_type == "ragged" else
1174                        ragged_tensor.to_sparse()),
1175      })
1176      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
1177      return dataset.batch(batch_size, drop_remainder=drop_remainder)
1178
1179    if input_type == "dataset":
1180      ds = distribution.experimental_distribute_dataset(
1181          dataset_fn(distribute_lib.InputContext()))
1182    else:
1183      ds = distribution.distribute_datasets_from_function(dataset_fn)
1184    iterator = iter(ds)
1185
1186    self.assertEqual(iterator._enable_get_next_as_optional,
1187                     (not drop_remainder) and enable_get_next_as_optional)
1188
1189  @combinations.generate(
1190      combinations.combine(
1191          tf_api_version=2,
1192          mode=["eager"],
1193          distribution=[
1194              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1195              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
1196              strategy_combinations.one_device_strategy,
1197              strategy_combinations.mirrored_strategy_with_one_cpu,
1198              # TODO(mdan): Add these?
1199              # strategy_combinations.multi_worker_mirrored_2x1_cpu,
1200              # strategy_combinations.multi_worker_mirrored_2x1_gpu,
1201              # strategy_combinations.multi_worker_mirrored_2x2_gpu,
1202          ],
1203          input_type=["dataset", "input_fn"],
1204          drop_remainder=[False, True],
1205      ))
1206  def testRaggedSparseGetNextAsOptionalInLoop(
1207      self, distribution, input_type, drop_remainder):
1208    """Test with `RaggedTensor`s and `SparseTensor`s."""
1209    self.skipTest("b/323359921")
1210
1211    global_batch_size = 8
1212
1213    def dataset_fn(ctx=None):
1214      ctx = ctx or distribute_lib.InputContext()
1215      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
1216      # Use 20 which isn't divisible by 8 to test partial batch behavior.
1217      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
1218      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
1219          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
1220      dataset = dataset_ops.DatasetV2.from_tensor_slices({
1221          "dense": ragged_tensor.to_tensor(),
1222          "ragged": ragged_tensor,
1223          "sparse": ragged_tensor.to_sparse(),
1224      })
1225      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
1226      return dataset.batch(batch_size, drop_remainder=drop_remainder)
1227
1228    if input_type == "dataset":
1229      ds = distribution.experimental_distribute_dataset(
1230          dataset_fn(distribute_lib.InputContext()))
1231    else:
1232      ds = distribution.distribute_datasets_from_function(dataset_fn)
1233
1234    # Iterate through all the batches and sum them up.
1235    def sum_batch(per_replica_features):
1236      """Sums the `PerReplica` values in the `per_replica_features` map."""
1237
1238      def map_fn(per_replica_values):
1239        per_replica_sums = distribution.run(
1240            (lambda x: math_ops.reduce_sum(x.values)) if all(
1241                map(sparse_tensor.is_sparse, per_replica_values.values)) else
1242            math_ops.reduce_sum, (per_replica_values,))
1243        return distribution.reduce(
1244            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)
1245
1246      return nest.map_structure(map_fn, per_replica_features)
1247
1248    def _reduce(state, batch):
1249      sums = sum_batch(batch)
1250      return {name: value + sums[name] for name, value in state.items()}
1251
1252    def sum_while_loop(ds):
1253      iterator = iter(ds)
1254      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
1255      try_next = constant_op.constant(True)
1256
1257      while try_next:
1258        opt_iterate = iterator.get_next_as_optional()
1259        if opt_iterate.has_value():
1260          sums = _reduce(sums, opt_iterate.get_value())
1261        else:
1262          try_next = False
1263      return sums
1264
1265    sums = def_function.function(sum_while_loop)(ds)
1266    # For loops always call get next as optional inside tf functions, so we
1267    # expect 310 here when using an input function (as there are 5 batches of
1268    # size 4 round robined over 2 replicas.
1269    expected_for_sum = 200.
1270    if not drop_remainder or input_type == "input_fn":
1271      expected_for_sum = 310.
1272    self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3)
1273
1274  @combinations.generate(
1275      combinations.combine(
1276          mode=["eager"],
1277          input_type=["dataset"],
1278          api_type=["wrap_into_iterator", "wrap_into_dataset"],
1279          iteration_type=["get_next", "for_loop"],
1280          distribution=[
1281              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1282              strategy_combinations.multi_worker_mirrored_2x1_gpu,
1283          ]))
1284  def testMWMSPartialBatch(self, input_type, api_type, iteration_type,
1285                           distribution):
1286    # Test case: 2 workers, 1 replica each.
1287    # This test simulates the sharded behavior when we have two files each with
1288    # 12 elements and a global batch size of 8. When we consider the dataset in
1289    # aggregate (non-distributed), there are 24 elements divided into 3 batches
1290    # of size 8. Hence, the correct distributed behavior is for each replica to
1291    # see sub-batches of size 4, over three steps.
1292    def dataset_fn(ctx):
1293      del ctx
1294      dataset = dataset_ops.Dataset.range(12).batch(8)
1295
1296      # Set the sharding behavior to OFF for simplicity of test setup; namely,
1297      # `dataset` defines the per-worker dataset and will not be further
1298      # sharded. Each worker will see a dataset that is
1299      # tf.data.Dataset.range(12).batch(8).rebatch(...).
1300      options = dataset_ops.Options()
1301      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
1302      dataset = dataset.with_options(options)
1303      return dataset
1304
1305    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
1306
1307    # Actual devices don't matter in this test as long as there is 1 local
1308    # replica.
1309    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
1310
1311    # Each test runs individually on each worker, so we compare the
1312    # values on each worker. Each worker should rebatch its dataset into
1313    # smaller batches of size 4.
1314    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
1315    self._test_input_iteration(
1316        input_type,
1317        api_type,
1318        iteration_type,
1319        dataset,
1320        worker_device_pairs,
1321        expected_values,
1322        distribution,
1323        num_replicas_in_sync=distribution.num_replicas_in_sync,
1324        input_context=distribution.extended._make_input_context())
1325
1326  @combinations.generate(
1327      combinations.combine(
1328          mode=["eager"],
1329          input_type=["dataset"],
1330          api_type=["wrap_into_iterator", "wrap_into_dataset"],
1331          iteration_type=["get_next", "for_loop"],
1332          distribution=[
1333              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1334              strategy_combinations.multi_worker_mirrored_2x1_gpu,
1335          ]))
1336  def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type,
1337                                            iteration_type, distribution):
1338    # Test case: 2 workers, 1 replica each.
1339    # This test simulates the sharded behavior when we have two files each with
1340    # 12 elements and a global batch size of 8. When we consider the dataset in
1341    # aggregate (non-distributed), there are 24 elements divided into 3 batches
1342    # of size 8. Hence, the correct distributed behavior is for each replica to
1343    # see sub-batches of size 4, over three steps. However, when we create a
1344    # DistributedDataset and cannot statically infer the intended global batch
1345    # size (e.g. if the user does not use a batching dataset), each worker will
1346    # rebatch based on the dynamic batch size of the data encountered, even when
1347    # it encounters partial batches. The last per-worker partial batch (size 4)
1348    # ends up being split into two replicas, resulting in 4 steps in total, of
1349    # (global) batch sizes 8, 8, 4, 4.
1350    def dataset_fn(ctx):
1351      del ctx
1352      # The following dataset is equivalent to
1353      # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset.
1354      # This causes DistributedDataset to use LegacyRebatch instead.
1355      batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4])
1356      offsets = dataset_ops.Dataset.from_tensor_slices([0, 8])
1357      dataset = dataset_ops.Dataset.zip((offsets, batch_sizes))
1358
1359      def map_fn(offset, batch_size):
1360        return math_ops.range(offset, offset + batch_size)
1361
1362      dataset = dataset.map(map_fn)
1363
1364      # Set the sharding behavior to OFF for simplicity of test setup; namely,
1365      # `dataset` defines the per-worker dataset and will not be further
1366      # sharded. Each worker will see a dataset that is equivalent to
1367      # tf.data.Dataset.range(12).batch(8).rebatch(...).
1368      options = dataset_ops.Options()
1369      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
1370      dataset = dataset.with_options(options)
1371      return dataset
1372
1373    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
1374
1375    # Actual devices don't matter in this test as long as the number of global
1376    # replicas is 2.
1377    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
1378
1379    # Each test runs individually on each worker, so we compare the
1380    # values on each worker. Each worker should rebatch its dataset into
1381    # smaller batches of size 4.
1382    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]]
1383    self._test_input_iteration(
1384        input_type,
1385        api_type,
1386        iteration_type,
1387        dataset,
1388        worker_device_pairs,
1389        expected_values,
1390        distribution,
1391        num_replicas_in_sync=distribution.num_replicas_in_sync,
1392        input_context=distribution.extended._make_input_context())
1393
1394  @combinations.generate(
1395      combinations.combine(
1396          mode=["eager"],
1397          input_type=["dataset"],
1398          api_type=["wrap_into_iterator", "wrap_into_dataset"],
1399          iteration_type=["get_next", "for_loop"],
1400          distribution=[
1401              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1402              strategy_combinations.multi_worker_mirrored_2x1_gpu,
1403          ],
1404          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA]))
1405  def testMWMSWithDataSharding(self, input_type, api_type, iteration_type,
1406                               distribution, auto_shard_policy):
1407    # Test case: 2 workers, 1 replica each.
1408    # This test simulates the sharded behavior the dataset is sharded by data
1409    # and the batch size is indivisible by the number of replicas. This checks
1410    # that the elements are as expected and the batch size across all workers
1411    # adds up to 3. This test will only pass if the autoshard rewrite rewrites
1412    # RebatchDatasetV2 to legacy RebatchDataset when sharding by data.
1413    def dataset_fn(ctx):
1414      del ctx
1415      dataset = dataset_ops.Dataset.range(8).batch(3)
1416
1417      # Set the sharding behavior to OFF for simplicity of test setup; namely,
1418      # `dataset` defines the per-worker dataset and will not be further
1419      # sharded. Each worker will see a dataset that is
1420      # tf.data.Dataset.range(12).batch(8).rebatch(...).
1421      options = dataset_ops.Options()
1422      options.experimental_distribute.auto_shard_policy = auto_shard_policy
1423      dataset = dataset.with_options(options)
1424      return dataset
1425
1426    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
1427
1428    # Actual devices don't matter in this test as long as there is 1 local
1429    # replica.
1430    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
1431
1432    # Each test runs individually on each worker, so we compare the
1433    # values on each worker. We expect each worker to see different shards of
1434    # data.
1435    cr = distribution.cluster_resolver
1436    worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type,
1437                                                cr.task_id)
1438
1439    if worker_id == 0:
1440      expected_values = [[[0, 1]], [[3, 4]], [[6]]]
1441    elif worker_id == 1:
1442      expected_values = [[[2]], [[5]], [[7]]]
1443
1444    self._test_input_iteration(
1445        input_type,
1446        api_type,
1447        iteration_type,
1448        dataset,
1449        worker_device_pairs,
1450        expected_values,
1451        distribution,
1452        num_replicas_in_sync=distribution.num_replicas_in_sync,
1453        input_context=distribution.extended._make_input_context())
1454
1455
1456class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
1457                                       parameterized.TestCase):
1458  """Tests for PER_WORKER and PER_REPLICA's InputOptions variants."""
1459
1460  def setUp(self):
1461    context._reset_context()
1462    strategy_combinations.set_virtual_cpus_to_at_least(3)
1463    super(DistributedIteratorPerDeviceTest, self).setUp()
1464
1465  @combinations.generate(
1466      combinations.combine(
1467          input_options=[
1468              distribute_lib.InputOptions(
1469                  experimental_place_dataset_on_device=False,
1470                  experimental_prefetch_to_device=True,
1471                  experimental_replication_mode=distribute_lib
1472                  .InputReplicationMode.PER_WORKER),
1473              distribute_lib.InputOptions(
1474                  experimental_place_dataset_on_device=False,
1475                  experimental_prefetch_to_device=True,
1476                  experimental_replication_mode=distribute_lib
1477                  .InputReplicationMode.PER_REPLICA),
1478          ],
1479          mode=["eager"],
1480          distribution=[
1481              strategy_combinations.mirrored_strategy_with_two_gpus,
1482              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
1483              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1484          ]))
1485  def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution,
1486                                                        input_options):
1487
1488    def dataset_fn(input_context):  # pylint: disable=[unused-argument]
1489      return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4])
1490
1491    ds = distribution.experimental_distribute_datasets_from_function(
1492        dataset_fn, input_options)
1493
1494    for x in ds:
1495      assert x.values[0].device == distribution.extended.worker_devices[0]
1496      assert x.values[0].backing_device == distribution.extended.worker_devices[
1497          0]
1498      assert x.values[1].device == distribution.extended.worker_devices[1]
1499      assert x.values[1].backing_device == distribution.extended.worker_devices[
1500          1]
1501
1502  @combinations.generate(
1503      combinations.combine(
1504          distribution=[
1505              strategy_combinations.mirrored_strategy_with_two_gpus,
1506              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
1507              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1508          ],
1509          input_options=[
1510              distribute_lib.InputOptions(
1511                  experimental_place_dataset_on_device=False,
1512                  experimental_prefetch_to_device=False,
1513                  experimental_replication_mode=distribute_lib
1514                  .InputReplicationMode.PER_WORKER)
1515          ],
1516          mode=["eager"],
1517      ))
1518  def testDevicePlacementForPerWorkerValuesWithoutPrefetch(
1519      self, distribution, input_options):
1520
1521    def dataset_fn(input_context):
1522      return dataset_ops.Dataset.from_tensor_slices(
1523          np.full(4, input_context.input_pipeline_id))
1524
1525    ds = distribution.experimental_distribute_datasets_from_function(
1526        dataset_fn, input_options)
1527
1528    for x in ds:
1529      x = distribution.run(lambda inputs: inputs, args=(x,))
1530      assert x.values[
1531          0].device == "/job:localhost/replica:0/task:0/device:CPU:0"
1532      assert x.values[
1533          0].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
1534      assert x.values[
1535          1].device == "/job:localhost/replica:0/task:0/device:CPU:0"
1536      assert x.values[
1537          1].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
1538
1539  @combinations.generate(
1540      combinations.combine(
1541          input_options=[
1542              distribute_lib.InputOptions(
1543                  experimental_place_dataset_on_device=True,
1544                  experimental_prefetch_to_device=False,
1545                  experimental_replication_mode=distribute_lib
1546                  .InputReplicationMode.PER_WORKER),
1547              distribute_lib.InputOptions(
1548                  experimental_place_dataset_on_device=True,
1549                  experimental_prefetch_to_device=True,
1550                  experimental_replication_mode=distribute_lib
1551                  .InputReplicationMode.PER_REPLICA)
1552          ],
1553          mode=["eager"],
1554          distribution=[
1555              strategy_combinations.mirrored_strategy_with_two_gpus,
1556              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
1557              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1558          ]))
1559  def testDevicePlacementForInvalidCombinations(self, distribution,
1560                                                input_options):
1561
1562    def dataset_fn(input_context):
1563      return dataset_ops.Dataset.from_tensor_slices(
1564          np.full(4, input_context.input_pipeline_id))
1565
1566    with self.assertRaises(ValueError):
1567      distribution.experimental_distribute_datasets_from_function(
1568          dataset_fn, input_options)
1569
1570  @combinations.generate(
1571      combinations.combine(
1572          input_options=[
1573              distribute_lib.InputOptions(
1574                  experimental_place_dataset_on_device=False,
1575                  experimental_prefetch_to_device=False,
1576                  experimental_replication_mode=distribute_lib
1577                  .InputReplicationMode.PER_WORKER),
1578              distribute_lib.InputOptions(
1579                  experimental_place_dataset_on_device=False,
1580                  experimental_prefetch_to_device=True,
1581                  experimental_replication_mode=distribute_lib
1582                  .InputReplicationMode.PER_WORKER),
1583          ],
1584          mode=["eager"],
1585          distribution=[
1586              strategy_combinations.mirrored_strategy_with_two_gpus,
1587              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
1588              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1589          ]))
1590  def testOutputValuesForPerWorkerInputOptions(self, distribution,
1591                                               input_options):
1592
1593    def dataset_fn(input_context):
1594      return dataset_ops.Dataset.from_tensor_slices(
1595          np.arange(1, 11).reshape(
1596              (2, 5)) * (input_context.input_pipeline_id + 1))
1597
1598    ds = distribution.experimental_distribute_datasets_from_function(
1599        dataset_fn, input_options)
1600
1601    # validating the values
1602    x = next(iter(ds))
1603    assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
1604    assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
1605
1606  @combinations.generate(
1607      combinations.combine(
1608          input_options=[
1609              distribute_lib.InputOptions(
1610                  experimental_place_dataset_on_device=True,
1611                  experimental_prefetch_to_device=False,
1612                  experimental_replication_mode=distribute_lib
1613                  .InputReplicationMode.PER_REPLICA),
1614              distribute_lib.InputOptions(
1615                  experimental_place_dataset_on_device=False,
1616                  experimental_prefetch_to_device=False,
1617                  experimental_replication_mode=distribute_lib
1618                  .InputReplicationMode.PER_REPLICA),
1619              distribute_lib.InputOptions(
1620                  experimental_place_dataset_on_device=False,
1621                  experimental_prefetch_to_device=True,
1622                  experimental_replication_mode=distribute_lib
1623                  .InputReplicationMode.PER_REPLICA),
1624          ],
1625          mode=["eager"],
1626          distribution=[
1627              strategy_combinations.mirrored_strategy_with_two_gpus,
1628              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
1629              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1630          ]))
1631  def testOutputValuesForPerReplicaInputOptions(self, distribution,
1632                                                input_options):
1633
1634    def dataset_fn(input_context):
1635      return dataset_ops.Dataset.from_tensor_slices(
1636          np.arange(1, 10) * (input_context.input_pipeline_id + 1))
1637
1638    ds = distribution.experimental_distribute_datasets_from_function(
1639        dataset_fn, input_options)
1640    expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
1641    for i, x in enumerate(ds):
1642      # validating the values
1643      assert x.values[0].numpy() == expected[i]
1644      assert x.values[1].numpy() == expected[i] * 2
1645      loop_num = i
1646    assert loop_num == len(expected) - 1
1647
1648
1649class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
1650                                           parameterized.TestCase):
1651  """Tests for distributed iterators which read from tf.data service."""
1652
1653  def setUp(self):
1654    super(DistributedIteratorTfDataServiceTest, self).setUp()
1655    self.num_workers = 3
1656    if combinations.in_main_process():
1657      self.dispatcher = server_lib.DispatchServer()
1658      self.workers = []
1659      for _ in range(self.num_workers):
1660        self.workers.append(
1661            server_lib.WorkerServer(
1662                server_lib.WorkerConfig(
1663                    dispatcher_address=self.dispatcher.target.split("://")[1],
1664                    heartbeat_interval_ms=100,
1665                    dispatcher_timeout_ms=1000)))
1666      combinations.env().tf_data_service_dispatcher = self.dispatcher.target
1667
1668  @combinations.generate(
1669      combinations.combine(
1670          mode=["eager"],
1671          distribution=[
1672              strategy_combinations.one_device_strategy,
1673              strategy_combinations.mirrored_strategy_with_one_cpu,
1674              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1675              strategy_combinations.tpu_strategy,
1676              strategy_combinations.central_storage_strategy_with_two_gpus,
1677              strategy_combinations.multi_worker_mirrored_2x2_gpu,
1678              strategy_combinations.multi_worker_mirrored_2x1_cpu,
1679          ]))
1680  def testTfDataService(self, distribution):
1681    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
1682    input_workers = input_lib.InputWorkers(worker_device_pairs)
1683
1684    dataset = dataset_ops.Dataset.range(1, 50)
1685    dataset = dataset.apply(
1686        data_service_ops._distribute(
1687            processing_mode="parallel_epochs",
1688            service=combinations.env().tf_data_service_dispatcher,
1689            job_name="foo"))
1690
1691    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
1692                                                     distribution)
1693
1694    iterator = iter(dist_dataset)
1695    results = []
1696    for element in iterator:
1697      local_results = distribution.experimental_local_results(element)
1698      for result in local_results:
1699        # input_lib.distributed_dataset may add extra '0' elements to pad
1700        # per-replica results.
1701        if result.numpy() != 0:
1702          results.append(result.numpy())
1703    self.assertNotEmpty(results)
1704    gathered = distribution.gather(constant_op.constant(results), axis=0)
1705    self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)
1706
1707
1708if __name__ == "__main__":
1709  test_util.main()
1710