1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library for testing DistributionStrategy descendants."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import os
23import tempfile
24
25import numpy as np
26
27from tensorflow.core.protobuf import config_pb2
28from tensorflow.core.util import event_pb2
29from tensorflow.python.client import session as session_lib
30from tensorflow.python.data.ops import dataset_ops
31from tensorflow.python.distribute import distribute_lib
32from tensorflow.python.distribute import distribute_utils
33from tensorflow.python.distribute import distribution_strategy_context as ds_context
34from tensorflow.python.distribute import reduce_util
35from tensorflow.python.eager import backprop
36from tensorflow.python.eager import context
37from tensorflow.python.eager import def_function
38from tensorflow.python.eager import test
39from tensorflow.python.framework import dtypes
40from tensorflow.python.framework import errors
41from tensorflow.python.framework import ops
42from tensorflow.python.framework import test_util
43from tensorflow.python.lib.io import tf_record
44from tensorflow.python.ops import array_ops
45from tensorflow.python.ops import gen_math_ops
46from tensorflow.python.ops import gradients_impl
47from tensorflow.python.ops import init_ops
48from tensorflow.python.ops import init_ops_v2
49from tensorflow.python.ops import summary_ops_v2 as summary_ops
50from tensorflow.python.ops import variable_scope
51from tensorflow.python.ops import variables
52from tensorflow.python.platform import gfile
53from tensorflow.python.training import optimizer
54from tensorflow.python.training import training_util
55from tensorflow.python.util import nest
56from tensorflow.python.util import tf_inspect
57
58
59class _TestException(Exception):
60  pass
61
62
63# Conditionally wrap the fn in a def_function.function (so it runs in graph
64# mode).
65def _maybe_run_in_function(fn, run_in_function=False):
66  if not run_in_function or not context.executing_eagerly():
67    return fn
68  else:
69    return def_function.function()(fn)
70
71
def _raise_exception_fn(_=None):
  """Unconditionally raises `_TestException`.

  May be the argument to either
  `distribution.extended.call_for_each_replica()` or
  `get_replica_context().merge_call()`; the optional positional argument is
  accepted and ignored.

  Raises:
    _TestException: always.
  """
  raise _TestException()
76
77
def _merge_raises_fn():
  """Issues a `merge_call()` whose callee raises `_TestException`.

  Must be the argument to a `distribution.extended.call_for_each_replica()`
  call, because it relies on the current replica context.
  """
  replica_ctx = ds_context.get_replica_context()
  replica_ctx.merge_call(_raise_exception_fn)
82
83
def _call_raises_fn(dist):
  """Runs a raising function through `call_for_each_replica()`.

  Must be the argument to a `get_replica_context().merge_call()` call.

  Args:
    dist: object whose `extended.call_for_each_replica` is invoked with
      `_raise_exception_fn`.
  """
  extended = dist.extended
  extended.call_for_each_replica(_raise_exception_fn)
89
90
def _merge_call_raises_fn():
  """Nests merge_call -> call_for_each_replica -> raise.

  Must be the argument to a `distribution.extended.call_for_each_replica()`
  call. It issues a `merge_call()` with `_call_raises_fn`, which in turn runs a
  `call_for_each_replica()` that raises an exception.
  """
  replica_ctx = ds_context.get_replica_context()
  replica_ctx.merge_call(_call_raises_fn)
96
97
def _call_merge_raises_fn(dist):
  """Nests call_for_each_replica -> merge_call -> raise.

  Must be the argument to a `get_replica_context().merge_call()` call.

  Args:
    dist: object whose `extended.call_for_each_replica` is invoked with
      `_merge_raises_fn`, a function that issues a raising `merge_call()`.
  """
  extended = dist.extended
  extended.call_for_each_replica(_merge_raises_fn)
103
104
def _merge_call_merge_raises_fn():
  """Nests merge_call -> call_for_each_replica -> merge_call -> raise.

  Must be the argument to a `distribution.extended.call_for_each_replica()`
  call. It issues a `merge_call()` with `_call_merge_raises_fn`, which runs a
  `call_for_each_replica()` whose function issues another `merge_call()` that
  raises an exception.
  """
  replica_ctx = ds_context.get_replica_context()
  replica_ctx.merge_call(_call_merge_raises_fn)
111
112
def _events_from_logdir(test_case, logdir):
  """Parses the summary events stored in the single file under `logdir`.

  Args:
    test_case: object whose `assertTrue` and `assertLen` are used to check
      that the directory exists and holds exactly one file.
    logdir: path of the log directory to read.

  Returns:
    A list of `event_pb2.Event` protos, one per record, in file order.
  """
  test_case.assertTrue(gfile.Exists(logdir))
  filenames = gfile.ListDirectory(logdir)
  test_case.assertLen(filenames, 1)
  event_file = os.path.join(logdir, filenames[0])
  # Read every record up front, then decode them one by one.
  serialized_records = list(tf_record.tf_record_iterator(event_file))

  def _parse(serialized):
    proto = event_pb2.Event()
    proto.ParseFromString(serialized)
    return proto

  return [_parse(serialized) for serialized in serialized_records]
125
126
def create_variable_like_keras_layer(name, shape, dtype):
  """Utility that creates a variable the way a Keras layer does.

  The initial value is passed as a callable: a `GlorotUniform` initializer
  with `shape` and `dtype` already bound.

  Args:
    name: name given to the variable.
    shape: shape forwarded to the initializer.
    dtype: dtype forwarded to the initializer.

  Returns:
    A trainable `variables.Variable`.
  """
  glorot_uniform = init_ops_v2.GlorotUniform()
  initial_value_fn = functools.partial(glorot_uniform, shape, dtype=dtype)
  return variables.Variable(
      initial_value=initial_value_fn, name=name, trainable=True)
133
134
def is_optimizer_v2_instance(optimizer_obj):
  """Returns whether `optimizer_obj.minimize` requires a `var_list` argument.

  For an optimizer instance, the v2 implementation has `var_list` as a
  required argument, so its presence among the arguments without a default is
  used as the distinguishing signal.

  Args:
    optimizer_obj: object exposing a `minimize` method.

  Returns:
    True if `var_list` is one of the `minimize` arguments that has no default
    value, False otherwise.
  """
  arg_spec = tf_inspect.getfullargspec(optimizer_obj.minimize)
  # As with `inspect.getfullargspec`, `defaults` is None when no argument has a
  # default. Count explicitly instead of slicing with a negative length: a
  # `[:-0]` slice would be empty and `len(None)` would raise.
  num_defaults = len(arg_spec.defaults or ())
  num_required = len(arg_spec.args) - num_defaults
  return "var_list" in arg_spec.args[:num_required]
140
141
class DistributionTestBase(test.TestCase):
  """Shared checks that should pass for any DistributionStrategy."""

  def _test_minimize_loss_eager(self, d):
    """Takes ten manual gradient steps eagerly and checks the error drops.

    A 1x1 kernel is fitted so that `[[1.]] @ kernel` moves towards 1, using a
    hand-written `assign_sub` update with a fixed step size of 0.2.

    Args:
      d: the distribution strategy under test.
    """
    with d.scope():
      kernel = create_variable_like_keras_layer(
          name="kernel", shape=(1, 1), dtype=dtypes.float32)

      def loss(x):
        product = array_ops.reshape(gen_math_ops.mat_mul(x, kernel), [])
        residual = product - array_ops.identity(1.)
        return residual * residual

      # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a
      # common `implicit_grad` function and put it in DistributionStrategy.
      grad_fn = optimizer.get_filtered_grad_fn(backprop.implicit_grad(loss))

      def update(v, g):
        return v.assign_sub(0.2 * g)

      one = array_ops.identity([[1.]])

      def step():
        """Perform one optimization step."""
        extended = d.extended
        # Forward & backward pass: yields (gradient, variable) pairs.
        grads_and_vars = extended.call_for_each_replica(grad_fn, args=(one,))

        # Apply update() to every variable, recording its value on both sides.
        values_before = []
        values_after = []
        for grad, var in grads_and_vars:
          value_before = extended.read_var(var)
          values_before.append(value_before)
          # control_dependencies irrelevant but harmless in eager execution
          with ops.control_dependencies([value_before]):
            summed_grad = extended.reduce_to(
                reduce_util.ReduceOp.SUM, grad, destinations=var)
            update_ops = extended.update(
                var, update, args=(summed_grad,), group=False)
            with ops.control_dependencies(update_ops):
              values_after.append(extended.read_var(var))
        return values_before, values_after

      initial_value = None
      final_value = None
      for iteration in range(10):
        values_before, values_after = step()
        if iteration == 0:
          (initial_value,) = values_before  # pylint: disable=unbalanced-tuple-unpacking
        (final_value,) = values_after  # pylint: disable=unbalanced-tuple-unpacking

      initial_error = abs(initial_value.numpy() - 1)
      final_error = abs(final_value.numpy() - 1)
      # Error should go down
      self.assertLess(final_error, initial_error)

  def _test_minimize_loss_graph(self,
                                d,
                                soft_placement=False,
                                learning_rate=0.2):
    """Graph-mode variant of the manual minimization check.

    The step is built once in a fresh graph and then run ten times in a
    session; the kernel's error relative to 1 must be smaller at the end.

    Args:
      d: the distribution strategy under test.
      soft_placement: value assigned to `allow_soft_placement` in the session
        config.
      learning_rate: factor applied to the gradient in the update.
    """
    session_config = config_pb2.ConfigProto()
    session_config.allow_soft_placement = soft_placement
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.3
    with context.graph_mode(), ops.Graph().as_default():
      with self.cached_session(config=session_config) as sess, d.scope():
        kernel = create_variable_like_keras_layer(
            name="kernel", shape=(1, 1), dtype=dtypes.float32)

        def loss(x):
          product = array_ops.reshape(gen_math_ops.mat_mul(x, kernel), [])
          residual = product - array_ops.identity(1.)
          return residual * residual

        grad_fn = backprop.implicit_grad(loss)

        def update(v, g):
          return v.assign_sub(learning_rate * g)

        one = array_ops.identity([[1.]])

        def step():
          """Perform one optimization step."""
          extended = d.extended
          # Forward & backward pass: yields (gradient, variable) pairs.
          grads_and_vars = extended.call_for_each_replica(
              grad_fn, args=(one,))

          # Apply update() to every variable, recording both sides.
          values_before = []
          values_after = []
          for grad, var in grads_and_vars:
            value_before = extended.read_var(var)
            values_before.append(value_before)
            with ops.control_dependencies([value_before]):
              summed_grad = extended.reduce_to(
                  reduce_util.ReduceOp.SUM, grad, destinations=var)
              update_ops = extended.update(
                  var, update, args=(summed_grad,), group=False)
              with ops.control_dependencies(update_ops):
                values_after.append(extended.read_var(var))
          return values_before, values_after

        before_fetches, after_fetches = step()
        variables.global_variables_initializer().run()

        initial_value = None
        final_value = None
        for iteration in range(10):
          values_before, values_after = sess.run(
              (before_fetches, after_fetches))
          if iteration == 0:
            (initial_value,) = values_before
          (final_value,) = values_after

        initial_error = abs(initial_value - 1)
        final_error = abs(final_value - 1)
        # Error should go down
        self.assertLess(final_error, initial_error)

  def _test_summary_for_replica_zero_only(self, d):
    """Checks that exactly one summary value, equal to 0, gets written.

    Every replica writes its replica id under tag "a"; the log directory is
    then expected to hold a file header event plus a single value event.

    Args:
      d: the distribution strategy under test.
    """
    logdir = tempfile.mkdtemp()

    def run_fn():
      """Function executed for each replica."""
      with writer.as_default():
        replica_ctx = ds_context.get_replica_context()
        return summary_ops.write("a", replica_ctx.replica_id_in_sync_group)

    with self.cached_session() as sess, d.scope(), \
        summary_ops.always_record_summaries():
      # We need global_step because summary writing op *always* has global_step
      # as input, even when we always record summary or never record summary.
      global_step = training_util.get_or_create_global_step()
      in_graph_mode = not context.executing_eagerly()
      if in_graph_mode:
        # When executing eagerly, variables are initialized immediately after
        # creation, and its initializer will be None.
        global_step.initializer.run()
      summary_ops.set_step(0)
      writer = summary_ops.create_file_writer(logdir)
      per_replica_output = d.extended.call_for_each_replica(run_fn)
      unwrapped = d.unwrap(per_replica_output)
      if in_graph_mode:
        sess.run(writer.init())
        sess.run(unwrapped)
        sess.run(writer.close())

      events = _events_from_logdir(self, logdir)
      # There will be 2 entries: 1 summary file header entry, and 1 entry
      # written by replica 0.
      self.assertLen(events, 2)
      written = events[1].summary.value[0]
      self.assertEqual(written.tag, "a")
      self.assertEqual(written.simple_value, 0.0)

  def _test_replica_id(self, d):
    """Checks that each replica id is below the device count and seen once.

    Args:
      d: the distribution strategy under test.
    """
    with d.scope():
      num_devices = len(d.extended.worker_devices)
      seen = [False] * num_devices

      def mark_devices_fn():
        replica_ctx = ds_context.get_replica_context()
        replica_id = self.evaluate(replica_ctx.replica_id_in_sync_group)
        self.assertLess(replica_id, num_devices)
        # A replica id must not show up twice.
        self.assertFalse(seen[replica_id])
        seen[replica_id] = True

      d.extended.call_for_each_replica(mark_devices_fn)
      self.assertAllEqual(seen, [True] * num_devices)

  def _test_call_and_merge_exceptions(self, dist):
    """Checks `_TestException` propagates through nested call/merge chains.

    Args:
      dist: the distribution strategy under test.
    """
    raising_fns = (
        _raise_exception_fn,
        _merge_raises_fn,
        _merge_call_raises_fn,
        _merge_call_merge_raises_fn,
    )
    with dist.scope():
      for raising_fn in raising_fns:
        with self.assertRaises(_TestException):
          dist.extended.call_for_each_replica(raising_fn)

  def _input_fn_to_test_input_context(self, dataset_or_callable_fn,
                                      expected_num_replicas_in_sync,
                                      expected_num_input_pipelines,
                                      expected_input_pipeline_id):
    """Builds an input fn that validates the input context it receives.

    Args:
      dataset_or_callable_fn: zero-argument callable whose result the input fn
        returns.
      expected_num_replicas_in_sync: value the context must report for
        `num_replicas_in_sync`.
      expected_num_input_pipelines: value the context must report for
        `num_input_pipelines`.
      expected_input_pipeline_id: expected `input_pipeline_id`, or None to
        expect the ids 0, 1, 2, ... on successive calls.

    Returns:
      A function taking an input context and returning
      `dataset_or_callable_fn()`.
    """
    # Use a list of one element as counter so that it can be captured by the
    # `_input_fn`. This counter is incremented by 1 each time an input_fn is
    # called. We use this counter to check whether the `input_pipeline_id`
    # matches the counter in the in-graph replication.
    worker_id_counter = [0]

    def _input_fn(input_context):
      """Input fn for testing."""
      self.assertIsNotNone(input_context)
      self.assertEqual(expected_num_replicas_in_sync,
                       input_context.num_replicas_in_sync)
      self.assertEqual(expected_num_input_pipelines,
                       input_context.num_input_pipelines)
      if expected_input_pipeline_id is None:
        self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id)
        worker_id_counter[0] += 1
      else:
        self.assertEqual(expected_input_pipeline_id,
                         input_context.input_pipeline_id)

      return dataset_or_callable_fn()

    return _input_fn

  def _test_input_fn_iterable(
      self, strategy, input_fn, expected_values, ignore_order=False):
    """Checks the values from `distribute_datasets_from_function(input_fn)`.

    When executing eagerly the iterable is consumed with `iter`/`next`, must
    raise `StopIteration` at the end and must be iterable a second time.
    Otherwise an initializable iterator is created and checked through
    `_test_input_fn_iterator`.

    Args:
      strategy: the distribution strategy under test.
      input_fn: input function handed to the strategy.
      expected_values: expected per-step values, one entry per step.
      ignore_order: compare with `assertCountEqual` instead of `assertEqual`.
    """
    compare = self.assertCountEqual if ignore_order else self.assertEqual

    iterable = strategy.distribute_datasets_from_function(input_fn)
    if not context.executing_eagerly():
      iterator = dataset_ops.make_initializable_iterator(iterable)
      self._test_input_fn_iterator(iterator, strategy.extended.worker_devices,
                                   expected_values, test_reinitialize=True,
                                   ignore_order=ignore_order)
      return

    def check_one_pass(it):
      for expected in expected_values:
        local_results = strategy.experimental_local_results(next(it))
        actual = self.evaluate(list(local_results))
        compare(expected, actual)

    iterator = iter(iterable)
    check_one_pass(iterator)

    with self.assertRaises(StopIteration):
      self.evaluate(strategy.experimental_local_results(next(iterator)))

    # After re-initializing the iterator, should be able to iterate again.
    check_one_pass(iter(iterable))

  def _test_input_fn_iterator(self,
                              iterator,
                              devices,
                              expected_values,
                              sess=None,
                              test_reinitialize=True,
                              ignore_order=False):
    """Checks an initializable iterator against the expected values.

    The iterator is initialized and drained, must then raise
    `errors.OutOfRangeError`, and optionally is initialized and drained again.

    Args:
      iterator: iterator exposing `initializer` and `get_next()`.
      devices: sequence whose length gives the number of replicas to select.
      expected_values: expected per-step values, one entry per step.
      sess: session used for evaluation; `self.evaluate` is used when falsy.
      test_reinitialize: whether to run the second pass.
      ignore_order: compare with `assertCountEqual` instead of `assertEqual`.
    """

    def evaluate(fetches):
      if sess:
        return sess.run(fetches)
      return self.evaluate(fetches)

    compare = self.assertCountEqual if ignore_order else self.assertEqual

    def next_values_for_all_replicas():
      next_element = iterator.get_next()
      return evaluate([
          distribute_utils.select_replica(replica_id, next_element)
          for replica_id in range(len(devices))
      ])

    def check_one_pass():
      evaluate(iterator.initializer)
      for expected in expected_values:
        compare(expected, next_values_for_all_replicas())

    check_one_pass()

    with self.assertRaises(errors.OutOfRangeError):
      next_values_for_all_replicas()

    # After re-initializing the iterator, should be able to iterate again.
    if test_reinitialize:
      check_one_pass()

  def _test_global_step_update(self, strategy):
    """Checks an ONLY_FIRST_REPLICA global step reads 1 after one increment.

    Args:
      strategy: the distribution strategy under test.
    """
    with strategy.scope():
      global_step = variable_scope.get_variable(
          "global_step",
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self.evaluate(variables.global_variables_initializer())

      def model_fn():
        increment = global_step.assign_add(1)
        current = global_step.read_value()
        return increment, current

      train_ops, value = strategy.extended.call_for_each_replica(model_fn)
      self.evaluate(strategy.group(train_ops))
      step_tensors = strategy.experimental_local_results(value)
      step_values = self.evaluate(step_tensors)
      self.assertEqual((1,) * len(step_tensors), step_values)

  def _test_numpy_dataset(self, strategy, session=None, run_in_function=False):
    """Checks a numpy-backed dataset yields the arrays back, twice.

    Skipped unless `strategy` is a `distribute_lib.StrategyV1`. The dataset is
    repeated for two epochs; each of the first two runs must return the full
    arrays and a third run must raise `errors.OutOfRangeError`.

    Args:
      strategy: the distribution strategy under test.
      session: optional session; `self.cached_session()` is used when falsy.
      run_in_function: forwarded to `_maybe_run_in_function` for the step fn.
    """
    if not isinstance(strategy, distribute_lib.StrategyV1):
      self.skipTest("n/a: V1 only")
    session_cm = session or self.cached_session()
    with strategy.scope(), session_cm as sess:
      features = np.asarray([[1, 2], [6, 12], [2, 4], [5, 10], [3, 6], [4, 8]])
      labels = np.asarray([5, 4, 3, 2, 1, 0])
      batch_size = 6
      if not strategy.extended._global_batch_size:  # pylint: disable=protected-access
        batch_size //= strategy.num_replicas_in_sync

      dataset = strategy.extended.experimental_make_numpy_dataset(
          (features, labels), session=sess or self.cached_session())
      dataset = dataset.repeat(2)  # 2 epochs
      # We need to use the drop_remainder argument to get a known static
      # input shape which is required for TPUs.
      dataset = dataset.batch(
          batch_size,
          drop_remainder=strategy.extended.experimental_require_static_shapes)
      iterator = strategy.make_dataset_iterator(dataset)

      self.evaluate(iterator.initializer)

      def run_and_concatenate(strat, it):
        per_replica_x, per_replica_y = strat.experimental_run(
            _maybe_run_in_function(lambda z: z, run_in_function), it)
        local_x, local_y = self.evaluate(
            (strat.experimental_local_results(per_replica_x),
             strat.experimental_local_results(per_replica_y)))
        return np.concatenate(local_x), np.concatenate(local_y)

      for _ in range(2):
        epoch_x, epoch_y = run_and_concatenate(strategy, iterator)
        self.assertAllEqual(features, epoch_x)
        self.assertAllEqual(labels, epoch_y)
      with self.assertRaises(errors.OutOfRangeError):
        run_and_concatenate(strategy, iterator)

  def _test_trainable_variable(self, strategy):
    """Checks the `trainable` value for several variable constructions.

    For both `VariableV1` and `Variable`: a plain variable is trainable, an
    ON_READ variable is not, and an explicit `trainable` argument wins.

    Args:
      strategy: the distribution strategy under test.
    """
    on_read = variables.VariableSynchronization.ON_READ
    cases = (
        ({}, True),
        ({"synchronization": on_read}, False),
        ({"synchronization": on_read, "trainable": True}, True),
        ({"synchronization": on_read, "trainable": False}, False),
    )
    for variable_cls in [variables.VariableV1, variables.Variable]:
      with strategy.scope():
        # Hold on to every variable until the scope is left.
        created = []
        for kwargs, expected_trainable in cases:
          var = variable_cls(1.0, **kwargs)
          created.append(var)
          self.assertEqual(expected_trainable, var.trainable)
486
487
class OneDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any one-device DistributionStrategy."""

  def _test_run(self, strategy):
    """Checks `strategy.run` with no args, positional args and kwargs.

    Args:
      strategy: the distribution strategy under test.
    """
    out1 = strategy.run(lambda: array_ops.identity(4.))
    self.assertAllEqual([4.], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([8.], out2_vals["a"])
    self.assertAllEqual([16.], out2_vals["b"])

    # The dict from the previous run supplies the keyword arguments a and b.
    out3 = strategy.run(lambda b, a: a + 2 * b + 2, kwargs=out2)
    self.assertAllEqual([42.], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy):
    """Checks that a SUM all-reduce returns the inputs unchanged."""
    self._test_collective_comms(
        strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.]))

  def _test_all_reduce_sum_gradients(self, strategy):
    """Checks `tf.gradients` through a SUM all-reduce."""
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_sum_gradient_tape(self, strategy):
    """Checks `GradientTape` gradients through a SUM all-reduce."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_mean(self, strategy):
    """Checks that a MEAN all-reduce returns the inputs unchanged."""
    self._test_collective_comms(
        strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.]))

  def _test_all_reduce_mean_gradients(self, strategy):
    """Checks `tf.gradients` through a MEAN all-reduce."""
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_all_reduce_mean_gradient_tape(self, strategy):
    """Checks `GradientTape` gradients through a MEAN all-reduce."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
    """Runs `comm_fn` on `inputs` and compares the two local results.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: function run through `strategy.experimental_run`.
      inputs: pair of values served as a single dataset element.
      expected: pair of expected values, one per output.
    """
    # The input fn closes over `inputs`, so the iterator gets its own name:
    # rebinding `inputs` here would make the closure's result depend on when
    # the strategy happens to call it.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(iterator.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(comm_fn, iterator))))
    self.assertAllEqual([expected[0]], outputs[0])
    self.assertAllEqual([expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads):
    """Checks `tf.gradients` of `comm_fn(x) * c` with respect to `x`.

    Skipped when executing eagerly.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: collective function applied to `x` inside the step.
      inputs: values served as a single dataset element and fed in as `c`.
      expected_grads: expected local gradient values.
    """
    if context.executing_eagerly():
      self.skipTest("`tf.gradients` is not supported with eager execution.")

    def step(c):
      x = array_ops.identity(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    # Separate name for the iterator: the lambda below reads `inputs`.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, iterator))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads):
    """Checks `GradientTape` gradients of `comm_fn(x) * c` w.r.t. `x`.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: collective function applied to `x` inside the step.
      inputs: values served as a single dataset element and fed in as `c`.
      expected_grads: expected local gradient values.
    """

    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        # `x` is a plain tensor, so it has to be watched explicitly.
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    # Separate name for the iterator: the lambda below reads `inputs`.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, iterator))))

  def _test_device_and_input_device_are_colocated(self, strategy):
    """Runs a step in a session connected to the second local worker.

    Skipped when executing eagerly.

    Args:
      strategy: the distribution strategy under test.
    """
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    run_op = strategy.experimental_run(comm_fn, iterator)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(iterator.initialize())
      sess.run(run_op)

  def _test_device_and_input_device_are_colocated_with_function(self, strategy):
    """Same as the colocation test, with the run wrapped in a function.

    Skipped when executing eagerly.

    Args:
      strategy: the distribution strategy under test.
    """
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    experimental_run = def_function.function()(strategy.experimental_run)
    with ops.device("/job:worker/replica:0/task:1/device:CPU:0"):
      # The tf.function must be defined on the right device as well.
      run_op = experimental_run(comm_fn, iterator)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(iterator.initialize())
      sess.run(run_op)
607
608
class TwoDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any two-device DistributionStrategy."""

  def _test_run(self, strategy, run_in_function=False):
    """Checks `strategy.run` with no args, positional args and kwargs.

    Args:
      strategy: the distribution strategy under test.
      run_in_function: forwarded to `_maybe_run_in_function` for each step fn.
    """
    # Each replica returns its id plus one, so two replicas give [1, 2].
    out1 = strategy.run(_maybe_run_in_function(
        lambda: ds_context.get_replica_context().replica_id_in_sync_group + 1,
        run_in_function))
    self.assertAllEqual([1, 2], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(_maybe_run_in_function(
        lambda x: {"a": x * 2, "b": x * x}, run_in_function), args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([2, 4], out2_vals["a"])
    self.assertAllEqual([1, 4], out2_vals["b"])

    # The dict from the previous run supplies the keyword arguments a and b.
    out3 = strategy.run(_maybe_run_in_function(
        lambda b, a: a + 2 * b + 2, run_in_function), kwargs=out2)
    self.assertAllEqual([6, 14], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy, run_in_function=False):
    """Checks a SUM all-reduce over the two input slices."""
    self._test_collective_comms(
        strategy,
        _all_sum,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(4., [42., 43.]),
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradients(self, strategy, run_in_function=False):
    """Checks `tf.gradients` through a SUM all-reduce."""
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradient_tape(self, strategy, run_in_function=False):
    """Checks `GradientTape` gradients through a SUM all-reduce."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean(self, strategy, run_in_function=False):
    """Checks a MEAN all-reduce over the two input slices."""
    self._test_collective_comms(
        strategy,
        _all_mean,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(2., [21., 21.5]),
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradients(self, strategy, run_in_function=False):
    """Checks `tf.gradients` through a MEAN all-reduce."""
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradient_tape(self, strategy,
                                          run_in_function=False):
    """Checks `GradientTape` gradients through a MEAN all-reduce."""
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected,
                             run_in_function=False):
    """Runs `comm_fn` on sliced `inputs`; both replicas must agree.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: function run through `strategy.experimental_run`.
      inputs: pair of sequences that is sliced into dataset elements.
      expected: pair of expected values; each is expected on both replicas.
      run_in_function: forwarded to `_maybe_run_in_function` for `comm_fn`.
    """
    # The input fn closes over `inputs`, so the iterator gets its own name:
    # rebinding `inputs` here would make the closure's result depend on when
    # the strategy happens to call it.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(iterator.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(
                    _maybe_run_in_function(comm_fn, run_in_function),
                    iterator))))
    self.assertAllEqual([expected[0], expected[0]], outputs[0])
    self.assertAllEqual([expected[1], expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads, run_in_function=False):
    """Checks `tf.gradients` of `comm_fn(x) * c` with respect to `x`.

    Skipped when executing eagerly unless `run_in_function` is set.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: collective function applied to `x` inside the step.
      inputs: sequence sliced into dataset elements and fed in as `c`.
      expected_grads: expected local gradient values.
      run_in_function: forwarded to `_maybe_run_in_function` for the step.
    """
    if context.executing_eagerly() and not run_in_function:
      self.skipTest("`tf.gradients` is not supported with eager execution "
                    "without using tf.functions.")

    def step(c):
      x = array_ops.identity(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    # Separate name for the iterator: the lambda below reads `inputs`.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function),
                    iterator))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads,
                                           run_in_function=False):
    """Checks `GradientTape` gradients of `comm_fn(x) * c` w.r.t. `x`.

    Args:
      strategy: the distribution strategy under test.
      comm_fn: collective function applied to `x` inside the step.
      inputs: sequence sliced into dataset elements and fed in as `c`.
      expected_grads: expected local gradient values.
      run_in_function: forwarded to `_maybe_run_in_function` for the step.
    """

    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        # `x` is a plain tensor, so it has to be watched explicitly.
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    # Separate name for the iterator: the lambda below reads `inputs`.
    iterator = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(iterator.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function),
                    iterator))))
723
724
class RemoteSingleWorkerMirroredStrategyBase(DistributionTestBase):
  """Tests for a Remote single worker."""

  def _get_num_gpus(self):
    """Hook for the GPU count used by the tests; returns None here."""

  def _testNumReplicasInSync(self, distribution):
    """Checks `num_replicas_in_sync` equals `_get_num_gpus()`."""
    expected_replicas = self._get_num_gpus()
    self.assertEqual(expected_replicas, distribution.num_replicas_in_sync)

  def _testMinimizeLoss(self, distribution):
    """Runs the eager or graph minimization check for the current mode."""
    if not context.executing_eagerly():
      self._test_minimize_loss_graph(distribution, learning_rate=0.05)
      return
    self._test_minimize_loss_eager(distribution)

  def _testDeviceScope(self, distribution):
    """Checks the device strings of tensors created under the scope."""
    worker_cpu = "/job:worker/replica:0/task:0/device:CPU:0"
    with distribution.scope():
      default_placed = array_ops.identity(1.)
      with ops.device("/cpu:0"):
        cpu_placed = array_ops.identity(1.)
      # Without an explicit device the expected string differs between eager
      # and graph execution.
      expected_default = (
          "/job:worker/replica:0/task:0/device:CPU:0"
          if context.executing_eagerly() else "/job:worker/replica:0/task:0")
      self.assertEqual(default_placed.device, expected_default)
      self.assertEqual(cpu_placed.device, worker_cpu)

  def _testMakeInputFnIteratorWithDataset(self, distribution):
    """Checks `make_input_fn_iterator` with an input fn returning a dataset."""

    def dataset_fn():
      return dataset_ops.Dataset.range(100)

    num_gpus = self._get_num_gpus()
    num_workers = 1

    # One entry per step: `num_gpus` consecutive integers.
    expected_values = []
    for start in range(0, 100, num_gpus):
      step_values = [start + offset for offset in range(num_gpus)]
      expected_values.append(step_values * num_workers)

    # Dummy cached_session is used in Eager
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be called
      # multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          dataset_fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator,
          distribution.extended.worker_devices,
          expected_values,
          sess=sess)

  def _testMakeInputFnIteratorWithCallable(self, distribution):
    """Checks `make_input_fn_iterator` with an input fn returning a callable."""

    def fn():
      one_shot = dataset_ops.make_one_shot_iterator(
          dataset_ops.Dataset.range(100))
      return one_shot.get_next

    num_gpus = self._get_num_gpus()
    num_workers = 1

    # One entry per step: `num_gpus` consecutive integers.
    expected_values = [
        [start + offset for offset in range(num_gpus)] * num_workers
        for start in range(0, 100, num_gpus)
    ]

    # Dummy cached_session is used in Eager
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be called
      # multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator,
          distribution.extended.worker_devices,
          expected_values,
          sess=sess,
          test_reinitialize=False,
          ignore_order=True)
798
799
def _all_sum(value):
  """All-reduces `value` with `ReduceOp.SUM` in the current replica context."""
  return ds_context.get_replica_context().all_reduce(
      reduce_util.ReduceOp.SUM, value)
803
804
def _all_mean(value):
  """All-reduces `value` with `ReduceOp.MEAN` in the current replica context."""
  return ds_context.get_replica_context().all_reduce(
      reduce_util.ReduceOp.MEAN, value)
808