# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import threading

from absl.testing import parameterized

from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as tracking
from tensorflow.python.util import nest

# When running collectives asynchronously, we need to give each parallel device
# execution a unique ID so the collectives don't interfere. Since the op is
# replicated with group/instance key intact, the replicated nodes will
# communicate.
# TODO(allenl): Switch to using a collective manager.
_COUNTER_LOCK = threading.Lock()
_COUNTER = 100


def _collective_reduce(inputs, operation, num_replicas):
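  """Reduces each tensor in `inputs` across replicas using collective ops."""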

  def _reduce_tensor(tensor):
    with _COUNTER_LOCK:
      global _COUNTER
      keys = _COUNTER
      _COUNTER += 1
    return collective_ops.all_reduce(
        t=tensor,
        group_size=num_replicas,
        merge_op=operation,
        group_key=keys,
        instance_key=keys,
        final_op="Id")

  return nest.map_structure(_reduce_tensor, inputs)


def _collective_sum(inputs, num_replicas):
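  """Sums `inputs` elementwise across `num_replicas` replicas."""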
  return _collective_reduce(
      inputs=inputs, operation="Add", num_replicas=num_replicas)


class _Dense(module.Module):
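  """Minimal dense layer which creates its variables on first call."""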

  def __init__(self, output_size):
    self.output_size = output_size
    self.kernel = None
    self.bias = None

  def __call__(self, x):
    if self.kernel is None:
      self.kernel = variables.Variable(
          array_ops.ones(
              array_ops.stack([self.output_size,
                               array_ops.shape(x)[-1]])))
      self.bias = variables.Variable(array_ops.ones([self.output_size]))
    return math_ops.matmul(x, self.kernel, transpose_b=True) + self.bias


class _VirtualDeviceTestCase(test.TestCase):
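  """Test case which sets up a two-component ParallelDevice in `setUp`."""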

  def setUp(self):
    super(_VirtualDeviceTestCase, self).setUp()
    ctx = context.context()
    if ctx.list_physical_devices("TPU"):
      self.device_type = "TPU"
    elif ctx.list_physical_devices("GPU"):
      self.device_type = "GPU"
      gpus = ctx.list_physical_devices(self.device_type)
      ctx.set_logical_device_configuration(gpus[0], [
          context.LogicalDeviceConfiguration(memory_limit=100),
          context.LogicalDeviceConfiguration(memory_limit=100),
      ])
    else:
      self.device_type = "CPU"
      cpus = ctx.list_physical_devices("CPU")
      ctx.set_logical_device_configuration(cpus[0], [
          context.LogicalDeviceConfiguration(),
          context.LogicalDeviceConfiguration(),
      ])

    self.device = parallel_device.ParallelDevice(components=[
        "/job:localhost/device:{}:0".format(self.device_type),
        self.device_type + ":1"
    ])
    self.assertIn(self.device_type + ":0", self.device.components[0])
    self.assertIn(self.device_type + ":1", self.device.components[1])


class ParallelDeviceTests(_VirtualDeviceTestCase, parameterized.TestCase):
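  """Tests of packing, unpacking, and running ops on a ParallelDevice."""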

  def test_register_parallel_device(self):
    with self.device:
      c = constant_op.constant(1.)
      d = constant_op.constant(2.)
      e = c + d
      outputs = self.device.unpack(e)
    self.assertAllClose([3., 3.], outputs)

    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_device_id(self):
    device_ids = self.device.unpack(self.device.device_ids)
    self.assertAllClose([0, 1], device_ids)
    # TODO(allenl): Should device IDs be int64 so they can be placed on GPUs?
    # Currently backing_device is CPU.
    self.assertIn(self.device.components[0], device_ids[0].device)
    self.assertIn(self.device.components[1], device_ids[1].device)

  def test_collective_reduce(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    with self.device:
      x = self.device.pack(
          [constant_op.constant(-1.5),
           constant_op.constant(3.5)])
      reduced = _collective_sum(x, num_replicas=2)
      outputs = self.device.unpack(reduced)
    self.assertAllClose([2., 2.], outputs)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_collective_reduce_async_scope(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    # Note that ops on the parallel device currently don't execute
    # asynchronously; this test just checks that we don't deadlock.
    with context.async_scope(), self.device:
      x = self.device.pack(
          [constant_op.constant(-1.5),
           constant_op.constant(3.5)])
      reduced = _collective_sum(x, num_replicas=2)
      outputs = self.device.unpack(reduced)
    self.assertAllClose([2., 2.], outputs)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_collective_reduce_async_context(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    previous = config.get_synchronous_execution()
    try:
      context._reset_context()
      config.set_synchronous_execution(False)
      self.setUp()
      # Note that ops on the parallel device currently don't execute
      # asynchronously; this test just checks that we don't deadlock.
      with self.device:
        x = self.device.pack(
            [constant_op.constant(-1.5),
             constant_op.constant(3.5)])
        reduced = _collective_sum(x, num_replicas=2)
        outputs = self.device.unpack(reduced)
      self.assertAllClose([2., 2.], outputs)
      self.assertIn(self.device.components[0], outputs[0].backing_device)
      self.assertIn(self.device.components[1], outputs[1].backing_device)
    finally:
      context._reset_context()
      config.set_synchronous_execution(previous)

  @parameterized.named_parameters(
      [("RunFunctionsEagerly", True),
       ("", False)])
  def test_cond(self, run_functions_eagerly):
    try:
      def_function.run_functions_eagerly(run_functions_eagerly)
      with self.device:
        pred = self.device.pack([True, False])
        capture = self.device.pack([[1.], [2.]])
        result = control_flow_ops.cond(
            pred,
            def_function.function(lambda: capture * 2.),
            def_function.function(lambda: capture * 4.))
      self.assertAllClose(
          [[2.], [8.]], self.device.unpack(result))
    finally:
      def_function.run_functions_eagerly(False)

  def test_cond_with_variable(self):
    with self.device:
      pred = self.device.pack([True, False])
      capture = self.device.pack([[1.], [2.]])
      v = None
      @def_function.function
      def true_branch():
        nonlocal v
        if v is None:
          v = variables.Variable(constant_op.constant(2.))
        return v * capture
      result = control_flow_ops.cond(
          pred, true_branch, def_function.function(lambda: capture * 4.))
    self.assertAllClose(
        [[2.], [8.]], self.device.unpack(result))
    self.assertAllClose(
        [2., 2.], self.device.unpack(v))
    # There are two unique variable handles with separate storage.
    h1, _ = self.device.unpack(v.handle)
    gen_resource_variable_ops.assign_variable_op(h1, constant_op.constant(3.))
    self.assertAllClose(
        [3., 2.], self.device.unpack(v))

  def test_collective_in_function(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    c = constant_op.constant([2])

    @def_function.function
    def broadcast_send_recv(device_id):

      @def_function.function
      def send():
        s0 = collective_ops.broadcast_send(
            c * 3, c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
        with ops.control_dependencies([s0.op]):
          return array_ops.identity(c)

      @def_function.function
      def recv():
        r0 = collective_ops.broadcast_recv(
            c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
        return r0

      return control_flow_ops.switch_case(
          device_id, branch_fns={0: send, 1: recv})

    with self.device:
      result = broadcast_send_recv(self.device.device_ids)
    self.assertAllClose([[2], [6]], self.device.unpack(result))

  def test_use_in_graph_error_is_informative(self):
    @def_function.function
    def uses_parallel():
      with self.device:
        return self.device.unpack(array_ops.ones([]))

    with self.assertRaisesRegex(NotImplementedError, "inside `tf.function`"):
      uses_parallel()

  def test_checkpointing(self):
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.device:
      different_values = self.device.pack(
          [constant_op.constant(-1.),
           constant_op.constant(3.)])
      v = variables.Variable(different_values)
      checkpoint = tracking.Checkpoint(v=v)
    save_path = checkpoint.save(prefix)
    with self.device:
      v.assign(constant_op.constant(0.))
    checkpoint.restore(save_path).assert_consumed()
    with self.device:
      outputs = self.device.unpack(v)
    self.assertAllClose([-1., 3.], outputs)

    with self.device:
      restore_on_create = tracking.Checkpoint()
      restore_on_create.restore(save_path)
      restore_on_create.v = variables.Variable(0.)
      outputs = self.device.unpack(restore_on_create.v)
    self.assertAllClose([-1., 3.], outputs)

    # Changing the number of devices / restoring into a single-device copy is OK
    single_device = tracking.Checkpoint(v=variables.Variable(0.))
    status = single_device.restore(save_path)
    status.assert_existing_objects_matched()
    self.assertAllClose(-1., single_device.v)
    with self.assertRaisesRegex(AssertionError, "parallel_component_1"):
      # There are parts of the variable that aren't restored into a
      # single-device copy.
      status.assert_consumed()

  def test_saved_model(self):
    with self.device:
      different_values = self.device.pack(
          [constant_op.constant(-1.),
           constant_op.constant(3.)])
      m = module.Module()
      m.v = variables.Variable(different_values)
      m.f = def_function.function(lambda: m.v * 2.)
      self.assertAllClose([-2., 6.], self.device.unpack(m.f()))
    saved_model_path = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(m, saved_model_path)

    context._reset_context()
    self.setUp()

    single_device_loaded = load.load(saved_model_path)
    self.assertAllClose(-2., single_device_loaded.f())
    with self.device:
      parallel_loaded = load.load(saved_model_path)
      self.assertAllClose([-2., 6.], self.device.unpack(parallel_loaded.f()))
      self.assertAllClose([-1., 3.], self.device.unpack(parallel_loaded.v))
      parallel_loaded.v.assign(self.device.pack([.1, .2]))
      self.assertAllClose([.2, .4], self.device.unpack(parallel_loaded.f()))

  def _assert_close_to_non_parallel(self, computation):
    """Asserts that replication of `computation` works and is equivalent."""
    with self.device:
      parallel_result = computation()
    non_parallel_result = computation()
    # The computations should have the same number and structure of Tensor
    # objects, even though the tensors themselves will be on different devices
    # and represent different numbers of values.
    nest.assert_same_structure(parallel_result, non_parallel_result)
    non_parallel_flat = nest.flatten(non_parallel_result)
    parallel_flat = nest.flatten(parallel_result)
    self.assertGreater(len(parallel_flat), 0)
    for non_parallel, parallel in zip(non_parallel_flat, parallel_flat):
      self.assertEqual(self.device._name, parallel.device)
      self.assertNotEqual(self.device._name, non_parallel.device)
      for parallel_component in self.device.unpack(parallel):
        self.assertAllClose(non_parallel, parallel_component)

  def test_capturing(self):
    with self.device:
      x = constant_op.constant([1., 2.])
      x = array_ops.identity(x)

      @def_function.function
      def f(y):
        return x + y

      y = array_ops.ones([2])
      parallel_result = f(y)
    self.assertAllClose([[2., 3.]] * 2, self.device.unpack(parallel_result))

  def test_euclidean_norm(self):
    def _test_fn():
      with backprop.GradientTape() as tape:
        x = array_ops.ones([5, 5])
        tape.watch(x)
        y = math_ops.reduce_euclidean_norm(x, axis=constant_op.constant(1))
      return y, tape.gradient(y, x)
    self._assert_close_to_non_parallel(_test_fn)

  def test_reduce_sum(self):
    def _test_fn():
      with backprop.GradientTape() as tape:
        x = array_ops.ones([5, 5])
        tape.watch(x)
        y = math_ops.reduce_sum(x, axis=constant_op.constant(1))
      return y, tape.gradient(y, x)
    self._assert_close_to_non_parallel(_test_fn)

  def test_variable_created_in_function(self):

    class M(module.Module):

      def __init__(self):
        self.v = None
        self.w = None
        self.x = None
        self.z = None

      @def_function.function(autograph=False)
      def __call__(self, x):
        if self.v is None:
          with ops.init_scope():
            initial_value = constant_op.constant(2.)
            self.z = variables.Variable(initial_value)
          self.x = variables.Variable(initial_value)
          self.w = variables.Variable(lambda: constant_op.constant(2.))
          self.v = variables.Variable(constant_op.constant(2.))
        return x * self.v * self.w * self.x * self.z

    with self.device:
      m = M()
      packed_outputs = m(array_ops.ones([]))
      outputs = self.device.unpack(packed_outputs)
    self.assertAllClose([16., 16.], outputs)

  def test_different_shapes(self):
    with self.device:
      x = self.device.pack(
          [constant_op.constant([1., 2.]),
           constant_op.constant([5.])])
      y = x * 2.
    with self.assertRaisesRegex(Exception,
                                "components do not all have the same shape"):
      y.shape  # pylint: disable=pointless-statement
    self.assertAllClose([[2., 4.], [10.]], self.device.unpack(y))

    different_axes = self.device.pack(
        [constant_op.constant([1., 2.]),
         constant_op.constant([[5.]])])
    with self.assertRaisesRegex(Exception,
                                "components do not all have the same shape"):
      different_axes.shape  # pylint: disable=pointless-statement


class LayerTests(_VirtualDeviceTestCase):
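  """Tests of layer construction and training on a ParallelDevice."""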

  def test_layer_forward(self):
    with self.device:
      layer = _Dense(5)
      x = constant_op.constant([[2.]])
      y = layer(x)
      outputs = self.device.unpack(y)
    self.assertAllClose([[3.] * 5], outputs[0])
    self.assertAllClose([[3.] * 5], outputs[1])
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

    # With different Layer inputs we get different outputs
    with self.device:
      x = self.device.pack(
          [constant_op.constant([[-0.5]]),
           constant_op.constant([[0.5]])])
      y = layer(x)
      outputs = self.device.unpack(y)
    self.assertGreater(
        math_ops.reduce_max(math_ops.abs(outputs[0] - outputs[1])), 1e-5)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_layer_sync_training(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    with self.device:
      layer = _Dense(5)

      with backprop.GradientTape() as tape:
        x = self.device.pack(
            [constant_op.constant([[-0.5]]),
             constant_op.constant([[0.5]])])
        y = layer(x)
        loss = (y - math_ops.range(5.))**2.
      parameters = layer.trainable_variables
      unreduced_gradients = tape.gradient(loss, parameters)
      reduced_gradients = _collective_sum(unreduced_gradients, num_replicas=2)
      for grad, param in zip(reduced_gradients, parameters):
        param.assign_sub(0.01 * grad)
    final_kernels = self.device.unpack(layer.kernel)
    self.assertAllClose(final_kernels[0], final_kernels[1])
    final_bias = self.device.unpack(layer.bias)
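    # With bias = 1. and y = x + 1., each replica's bias gradient is
    # 2. * (y - range(5)); the collective sum applies both replicas'
    # gradients (x = -.5 and x = .5) to every component.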
    expected_bias = (1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)) -
                     0.01 * 2. * (1. - .5 - math_ops.range(5.)))
    self.assertAllClose(expected_bias, final_bias[0])
    self.assertAllClose(expected_bias, final_bias[1])
    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
    self.assertIn(self.device.components[1], final_kernels[1].backing_device)

  def test_layer_divergent_buffer_training(self):
    with self.device:
      layer = _Dense(5)

      with backprop.GradientTape() as tape:
        x = self.device.pack(
            [constant_op.constant([[-0.5]]),
             constant_op.constant([[0.5]])])
        y = layer(x)
        loss = (y - math_ops.range(5.))**2.
      parameters = layer.trainable_variables
      unreduced_gradients = tape.gradient(loss, parameters)
      for grad, param in zip(unreduced_gradients, parameters):
        param.assign_sub(0.01 * grad)
    final_kernels = self.device.unpack(layer.kernel)
    self.assertNotAllClose(final_kernels[0], final_kernels[1])
    final_bias = self.device.unpack(layer.bias)
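    # With no collective reduction, each component applies only its own
    # replica's gradient, so the biases diverge across components.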
    self.assertAllClose(1. - 0.01 * 2. * (1. - .5 - math_ops.range(5.)),
                        final_bias[0])
    self.assertAllClose(1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)),
                        final_bias[1])
    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
    self.assertIn(self.device.components[1], final_kernels[1].backing_device)

  def test_training_loop(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    for _ in range(5):
      layer = _Dense(5)
      checkpoint = tracking.Checkpoint(layer=layer)
      manager = checkpoint_management.CheckpointManager(
          checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
      manager.restore_or_initialize()

      for _ in range(10):
        with self.device:
          with backprop.GradientTape() as tape:
            x = self.device.pack(
                [constant_op.constant([[-0.5]]),
                 constant_op.constant([[0.5]])])
            y = layer(x)
            loss = (y - math_ops.range(5.))**2.
          parameters = layer.trainable_variables
          unreduced_gradients = tape.gradient(loss, parameters)
          reduced_gradients = _collective_sum(
              unreduced_gradients, num_replicas=len(self.device.components))
          for grad, param in zip(reduced_gradients, parameters):
            param.assign_sub(0.01 * grad)

        manager.save()


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()