1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for ParameterServerStrategy."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import threading
23
24from absl.testing import parameterized
25from tensorflow.core.protobuf import config_pb2
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.distribute import central_storage_strategy
28from tensorflow.python.distribute import combinations
29from tensorflow.python.distribute import device_util
30from tensorflow.python.distribute import distribute_lib
31from tensorflow.python.distribute import distribute_utils
32from tensorflow.python.distribute import distribution_strategy_context as ds_context
33from tensorflow.python.distribute import input_lib
34from tensorflow.python.distribute import multi_worker_test_base
35from tensorflow.python.distribute import multi_worker_util
36from tensorflow.python.distribute import parameter_server_strategy
37from tensorflow.python.distribute import ps_values
38from tensorflow.python.distribute import reduce_util
39from tensorflow.python.distribute import strategy_test_lib
40from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
41from tensorflow.python.eager import backprop
42from tensorflow.python.eager import context
43from tensorflow.python.estimator import run_config
44from tensorflow.python.framework import constant_op
45from tensorflow.python.framework import device as tf_device
46from tensorflow.python.framework import dtypes
47from tensorflow.python.framework import errors
48from tensorflow.python.framework import ops
49from tensorflow.python.framework import tensor_util
50from tensorflow.python.ops import array_ops
51from tensorflow.python.ops import control_flow_ops
52from tensorflow.python.ops import gradients
53from tensorflow.python.ops import math_ops
54from tensorflow.python.ops import partitioned_variables
55from tensorflow.python.ops import resource_variable_ops
56from tensorflow.python.ops import variable_scope
57from tensorflow.python.ops import variables
58from tensorflow.python.platform import test
59from tensorflow.python.training import training_util
60
# Short module-level aliases for the task-type constants defined on
# run_config.TaskType.
CHIEF = run_config.TaskType.CHIEF
WORKER = run_config.TaskType.WORKER
PS = run_config.TaskType.PS
64
65
def _get_replica_id_integer():
  """Returns the current replica id, unwrapping it when it is a tensor.

  The id is read from the active replica context. When it is an `ops.Tensor`
  it is passed through `tensor_util.constant_value`; any other value is
  returned unchanged.
  """
  current_id = ds_context.get_replica_context().replica_id_in_sync_group
  if not isinstance(current_id, ops.Tensor):
    return current_id
  return tensor_util.constant_value(current_id)
71
72
def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        sess_config=None):
  """Builds the strategy, session target and session config used by tests.

  Args:
    cluster_spec: cluster description; together with `task_type` and `task_id`
      it selects the multi-worker `ParameterServerStrategyV1` path.
    task_type: task type of the current task.
    task_id: index of the current task (0 is a valid value).
    num_gpus: number of GPUs to use; defaults to `context.num_gpus()`.
    sess_config: base `ConfigProto`; a fresh one is used when falsy. The
      caller's proto is deep-copied before the strategy updates it.

  Returns:
    A `(distribution, target, sess_config)` tuple. `target` is a grpc address
    of the selected worker in the multi-worker case and `''` otherwise.
  """
  base_config = sess_config or config_pb2.ConfigProto()
  if num_gpus is None:
    num_gpus = context.num_gpus()

  if cluster_spec and task_type and task_id is not None:
    # Multi-worker: resolve the cluster and talk to the selected worker.
    resolver = SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type=task_type,
        task_id=task_id,
        num_accelerators={'GPU': num_gpus})
    strategy = parameter_server_strategy.ParameterServerStrategyV1(resolver)
    session_target = 'grpc://' + cluster_spec[WORKER][task_id]
  else:
    # Local: a central-storage strategy and an in-process session.
    strategy = central_storage_strategy.CentralStorageStrategy._from_num_gpus(
        num_gpus)
    session_target = ''

  updated_config = strategy.update_config_proto(copy.deepcopy(base_config))
  return strategy, session_target, updated_config
100
101
102class ParameterServerStrategyTestBase(
103    multi_worker_test_base.MultiWorkerTestBase):
104
105  def setUp(self):
106    self._result = 0
107    self._lock = threading.Lock()
108    self._init_condition = threading.Condition()
109    self._init_reached = 0
110    self._finish_condition = threading.Condition()
111    self._finish_reached = 0
112    self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True)
113    super(ParameterServerStrategyTestBase, self).setUp()
114
115  def _get_test_objects(self, task_type, task_id, num_gpus):
116    return create_test_objects(
117        cluster_spec=self._cluster_spec,
118        task_type=task_type,
119        task_id=task_id,
120        num_gpus=num_gpus,
121        sess_config=self._sess_config)
122
  def _test_device_assignment_distributed(self, task_type, task_id, num_gpus):
    """Checks device placement of variables and ops in the test cluster.

    Builds a graph under the strategy scope and asserts on the `.device` of
    variables (expected on `/job:ps` tasks) and of regular ops (expected on the
    worker device unless an explicit device scope or colocation applies). The
    order in which variables are created matters for the expected ps task, so
    statements here should not be reordered.

    Args:
      task_type: task type used to build the strategy and the worker device.
      task_id: task index used to build the strategy and the worker device.
      num_gpus: number of GPUs per worker; 0 means ops are expected on CPU:0.
    """
    worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id)
    d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus)
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=sess_config) as sess, \
         d.scope():

      # Define a variable outside the call_for_each_replica scope.
      n = variable_scope.get_variable('n', initializer=10.0)
      self.assertEqual(n.device, '/job:ps/task:0')

      def model_fn():
        # Device suffix expected for ops created by this replica.
        if num_gpus == 0:
          last_part_device = 'device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          last_part_device = ('device:GPU:%d' % replica_id)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, worker_device + '/' + last_part_device)
        self.assertEqual(b.device, worker_device + '/' + last_part_device)
        self.assertEqual(c.device, worker_device + '/' + last_part_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/job:worker/task:0'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        # The variable x is on the task 1 since the device_function has been
        # called once before the model_fn.
        self.assertEqual(x.device, '/job:ps/task:1')
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device,
                         '/job:worker/replica:0/task:0/%s' % last_part_device)

        # The colocate_vars_with can override the distribution's device.
        with d.extended.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
        # We add an identity here to avoid complaints about summing
        # non-distributed values.
        y_add = y.assign_add(array_ops.identity(x_add))
        self.assertEqual(y.device, '/job:ps/task:1')
        self.assertEqual(y_add.device, y.device)
        self.assertEqual(y.device, x.device)

        # z is created without colocation, so it is expected on a different ps
        # task than x.
        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(z.device, '/job:ps/task:0')
        self.assertNotEqual(z.device, x.device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, worker_device + '/' + last_part_device)

        # The device scope would merge with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, worker_device + '/device:CPU:1')

        # This ops.colocate_with will be ignored when defining a variable but not
        # for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          v = variable_scope.get_variable('v', initializer=30.0)
          h = f + 1.0
        self.assertIn('/job:ps/', u.device)
        self.assertIn('/job:ps/', v.device)
        # u and v are on different parameter servers: exactly one of them
        # shares a device with x.
        self.assertTrue(u.device != x.device or v.device != x.device)
        self.assertTrue(u.device == x.device or v.device == x.device)
        # h is colocated with x, so it is placed on a parameter server rather
        # than on a worker. Only the job is checked here because h.device is
        # canonical while x.device is not.
        self.assertIn('/job:ps/', h.device)
        return y_add, z_add, f

      y, z, f = d.extended.call_for_each_replica(model_fn)
      self.assertNotEqual(y, None)
      self.assertNotEqual(z, None)
      self.assertNotEqual(f, None)

      # Values are only evaluated when a GPU is present and at most one GPU is
      # requested for the strategy.
      if context.num_gpus() >= 1 and num_gpus <= 1:
        self.evaluate(variables.global_variables_initializer())
        y_val, z_val, f_val = sess.run([y, z, f])
        self.assertEqual(y_val, 33.0)
        self.assertEqual(z_val, 43.0)
        self.assertEqual(f_val, 46.0)
221
  def _test_device_assignment_distributed_enable_partitioner(
      self, task_type, task_id, num_gpus):
    """Checks device placement of partitioned variables in the test cluster.

    Variables are created with a fixed-size partitioner that has one shard per
    parameter device of the strategy, and each shard is expected on the ps task
    whose index equals the shard's partition id.

    Args:
      task_type: task type used to build the strategy.
      task_id: task index used to build the strategy.
      num_gpus: number of GPUs per worker; also scales the expected value of
        the summed `assign_add` when it is at least 1.
    """
    d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus)
    # One shard per parameter device.
    num_shards = len(d.extended.parameter_devices)
    partitioner = partitioned_variables.fixed_size_partitioner(num_shards)
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=sess_config) as sess, \
         d.scope():

      # A partitioned variable created outside call_for_each_replica.
      n = variable_scope.get_variable(
          'n',
          initializer=constant_op.constant([10.0, 20.0]),
          aggregation=variable_scope.VariableAggregation.SUM,
          partitioner=partitioner)

      for part_id, var in enumerate(n):
        self.assertEqual(var.device, '/job:ps/task:%d' % part_id)

      def model_fn():
        a = constant_op.constant([3.0, 5.0])
        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/job:worker/task:0'):
          x = variable_scope.get_variable(
              'x',
              initializer=constant_op.constant([10.0, 20.0]),
              aggregation=variable_scope.VariableAggregation.SUM,
              partitioner=partitioner)
          x_add = x.assign_add(a, name='x_add')
        # Each shard of x is expected on the ps task whose index equals its
        # partition id, and the matching assign_add op on the same device.
        for part_id, var in enumerate(x):
          self.assertEqual(var.device, '/job:ps/task:%d' % part_id)
          self.assertEqual(var.device, x_add[part_id].device)

        return x_add

      x = d.extended.call_for_each_replica(model_fn)

      # Values are only evaluated when at least one GPU is present.
      if context.num_gpus() >= 1:
        self.evaluate(variables.global_variables_initializer())
        x_val = sess.run(x)
        if num_gpus < 1:
          self.assertEqual(x_val, [13.0, 25.0])
        else:
          # With GPUs the increment is expected once per GPU.
          x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus]
          self.assertEqual(x_val, x_expect)
269
  def _test_device_assignment_local(self,
                                    d,
                                    compute_device='CPU',
                                    variable_device='CPU',
                                    num_gpus=0):
    """Checks device placement of variables and ops for a local strategy.

    Args:
      d: the strategy under test.
      compute_device: 'CPU' or 'GPU'; device type on which regular ops are
        expected.
      variable_device: 'CPU' or 'GPU'; device type on which variables are
        expected.
      num_gpus: number of GPUs the strategy was built with; values are only
        evaluated when this is at most 1 and a GPU is present.
    """
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=self._sess_config) as sess, \
         d.scope():

      def model_fn():
        # Canonical device on which this replica's regular ops are expected.
        if 'CPU' in compute_device:
          replica_compute_device = '/device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          replica_compute_device = ('/device:GPU:%d' % replica_id)
        replica_compute_device = device_util.canonicalize(
            replica_compute_device)

        # Canonical device on which this replica's variables are expected.
        if 'CPU' in variable_device:
          replica_variable_device = '/device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          replica_variable_device = ('/device:GPU:%d' % replica_id)
        replica_variable_device = device_util.canonicalize(
            replica_variable_device)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, replica_compute_device)
        self.assertEqual(b.device, replica_compute_device)
        self.assertEqual(c.device, replica_compute_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/device:GPU:2'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        self.assertEqual(
            device_util.canonicalize(x.device), replica_variable_device)
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2'))

        # The colocate_vars_with can override the distribution's device.
        with d.extended.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
        # We add an identity here to avoid complaints about summing
        # non-distributed values.
        y_add = y.assign_add(array_ops.identity(x_add))
        self.assertEqual(
            device_util.canonicalize(y.device), replica_variable_device)
        self.assertEqual(y_add.device, y.device)
        self.assertEqual(y.device, x.device)

        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(
            device_util.canonicalize(z.device), replica_variable_device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, replica_compute_device)

        # The device scope would merge with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1'))

        # This ops.colocate_with will be ignored when defining a variable but not
        # for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          h = f + 1.0
        self.assertEqual(
            device_util.canonicalize(u.device), replica_variable_device)
        # h is a regular tensor, so it follows the colocation with x.
        self.assertEqual(
            device_util.canonicalize(x.device),
            device_util.canonicalize(h.device))
        return y_add, z_add, f

      y, z, f = d.extended.call_for_each_replica(model_fn)
      self.assertNotEqual(y, None)
      self.assertNotEqual(z, None)
      self.assertNotEqual(f, None)

      # Values are only evaluated when a GPU is present and at most one GPU is
      # requested for the strategy.
      if context.num_gpus() >= 1 and num_gpus <= 1:
        self.evaluate(variables.global_variables_initializer())
        y_val, z_val, f_val = sess.run([y, z, f])
        self.assertEqual(y_val, 33.0)
        self.assertEqual(z_val, 43.0)
        self.assertEqual(f_val, 46.0)
371
372  def _test_simple_increment(self, task_type, task_id, num_gpus):
373    d, master_target, sess_config = self._get_test_objects(
374        task_type, task_id, num_gpus)
375    if d.extended._cluster_spec:
376      num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER))
377      if 'chief' in d.extended._cluster_spec.as_dict():
378        num_workers += 1
379    else:
380      num_workers = 1
381    with ops.Graph().as_default(), \
382         self.cached_session(target=master_target,
383                             config=sess_config) as sess, \
384         d.scope():
385
386      def model_fn():
387        x = variable_scope.get_variable(
388            'x', initializer=10.0,
389            aggregation=variable_scope.VariableAggregation.SUM)
390        y = variable_scope.get_variable(
391            'y', initializer=20.0,
392            aggregation=variable_scope.VariableAggregation.SUM)
393        z = variable_scope.get_variable(
394            'z', initializer=30.0,
395            aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)
396
397        # We explicitly make a constant tensor here to avoid complaints about
398        # summing non-distributed values.
399        one = constant_op.constant(1.0)
400        x_add = x.assign_add(one, use_locking=True)
401        y_add = y.assign_add(one, use_locking=True)
402        z_add = z.assign_add(one, use_locking=True)
403
404        train_op = control_flow_ops.group(x_add, y_add, z_add)
405        return x, y, z, train_op
406
407      x, y, z, train_op = d.extended.call_for_each_replica(model_fn)
408      train_op = d.group(train_op)
409
410      if task_id == 0:
411        self.evaluate(variables.global_variables_initializer())
412
413      # Workers waiting for chief worker's initializing variables.
414      self._init_condition.acquire()
415      self._init_reached += 1
416      while self._init_reached != num_workers:
417        self._init_condition.wait()
418      self._init_condition.notify_all()
419      self._init_condition.release()
420
421      sess.run(train_op)
422
423      # Wait for other workers to finish training.
424      self._finish_condition.acquire()
425      self._finish_reached += 1
426      while self._finish_reached != num_workers:
427        self._finish_condition.wait()
428      self._finish_condition.notify_all()
429      self._finish_condition.release()
430
431      x_val, y_val, z_val = sess.run([x, y, z])
432      self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_replicas_in_sync)
433      self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas_in_sync)
434      self.assertEqual(z_val, 30.0 + 1.0 * num_workers)
435
436  def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
437    d, master_target, sess_config = self._get_test_objects(
438        task_type, task_id, num_gpus)
439    if task_type:
440      # Multi-worker
441      assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec
442      num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER))
443      if CHIEF in d.extended._cluster_spec.as_dict():
444        num_workers += 1
445    else:
446      # local
447      num_workers = 1
448
449    with ops.Graph().as_default(), \
450         self.cached_session(target=master_target,
451                             config=sess_config) as sess, \
452         d.scope():
453      kernel = strategy_test_lib.create_variable_like_keras_layer(
454          'kernel', (1, 1), dtypes.float32,)
455
456      def loss_fn(x):
457        y = array_ops.reshape(
458            math_ops.matmul(x, kernel), []) - constant_op.constant(1.)
459        return y * y
460
461      # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for
462      # multiple graphs (b/111216820).
463      def grad_fn(x):
464        loss = loss_fn(x)
465        var_list = (
466            variables.trainable_variables() + ops.get_collection(
467                ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
468        grads = gradients.gradients(loss, var_list)
469        ret = list(zip(grads, var_list))
470        return ret
471
472      def update(v, g):
473        return v.assign_sub(0.05 * g, use_locking=True)
474
475      one = constant_op.constant([[1.]])
476
477      def step():
478        """Perform one optimization step."""
479        # Run forward & backward to get gradients, variables list.
480        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))
481        # Update the variables using the gradients and the update() function.
482        before_list = []
483        after_list = []
484        for g, v in g_v:
485          fetched = d.extended.read_var(v)
486          before_list.append(fetched)
487          with ops.control_dependencies([fetched]):
488            # TODO(yuefengz): support non-Mirrored variable as destinations.
489            g = d.extended.reduce_to(
490                reduce_util.ReduceOp.SUM, g, destinations=v)
491            with ops.control_dependencies(
492                d.extended.update(v, update, args=(g,), group=False)):
493              after_list.append(d.extended.read_var(v))
494        return before_list, after_list
495
496      before_out, after_out = step()
497
498      if (not task_type or
499          multi_worker_util.is_chief(
500              d.extended._cluster_spec, task_type, task_id)):
501        self.evaluate(variables.global_variables_initializer())
502
503      # Workers waiting for chief worker's initializing variables.
504      self._init_condition.acquire()
505      self._init_reached += 1
506      while self._init_reached != num_workers:
507        self._init_condition.wait()
508      self._init_condition.notify_all()
509      self._init_condition.release()
510
511      for i in range(10):
512        b, a = sess.run((before_out, after_out))
513        if i == 0:
514          before, = b
515        after, = a
516
517      error_before = abs(before - 1)
518      error_after = abs(after - 1)
519      # Error should go down
520      self.assertLess(error_after, error_before)
521
522  def _test_input_fn_iterator(self,
523                              task_type,
524                              task_id,
525                              num_gpus,
526                              input_fn,
527                              expected_values,
528                              test_reinitialize=True,
529                              ignore_order=False):
530    distribution, master_target, config = self._get_test_objects(
531        task_type, task_id, num_gpus)
532    devices = distribution.extended.worker_devices
533
534    with ops.Graph().as_default(), \
535         self.cached_session(config=config,
536                             target=master_target) as sess:
537      iterator = distribution.make_input_fn_iterator(input_fn)
538      sess.run(iterator.initializer)
539
540      for expected_value in expected_values:
541        next_element = iterator.get_next()
542        computed_value = sess.run([distribute_utils.select_replica(
543            r, next_element) for r in range(len(devices))])
544        if ignore_order:
545          self.assertCountEqual(expected_value, computed_value)
546        else:
547          self.assertEqual(expected_value, computed_value)
548
549      with self.assertRaises(errors.OutOfRangeError):
550        next_element = iterator.get_next()
551        sess.run([distribute_utils.select_replica(r, next_element)
552                  for r in range(len(devices))])
553
554      # After re-initializing the iterator, should be able to iterate again.
555      if test_reinitialize:
556        sess.run(iterator.initializer)
557
558        for expected_value in expected_values:
559          next_element = iterator.get_next()
560          computed_value = sess.run([distribute_utils.select_replica(
561              r, next_element) for r in range(len(devices))])
562          if ignore_order:
563            self.assertCountEqual(expected_value, computed_value)
564          else:
565            self.assertEqual(expected_value, computed_value)
566
567
568class ParameterServerStrategyTest(
569    ParameterServerStrategyTestBase,
570    strategy_test_lib.DistributionTestBase,
571    strategy_test_lib.TwoDeviceDistributionTestBase,
572    parameterized.TestCase):
573
574  @classmethod
575  def setUpClass(cls):
576    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
577        num_workers=3, num_ps=2)
578    cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0]
579
580  @combinations.generate(combinations.combine(mode=['graph']))
581  def test_num_replicas_in_sync(self):
582    strategy, _, _ = create_test_objects(num_gpus=2)
583    # All the devices on a given worker are in sync which in this case is the
584    # number of gpus on each worker.
585    self.assertEqual(2, strategy.num_replicas_in_sync)
586
587  @combinations.generate(combinations.combine(mode=['graph']))
588  def testDeviceAssignmentLocalCPU(self):
589    strategy, _, _ = create_test_objects(num_gpus=0)
590    self._test_device_assignment_local(
591        strategy, compute_device='CPU', variable_device='CPU', num_gpus=0)
592
593  @combinations.generate(combinations.combine(mode=['graph']))
594  def testDeviceAssignmentLocalOneGPU(self):
595    strategy, _, _ = create_test_objects(num_gpus=1)
596    self._test_device_assignment_local(
597        strategy, compute_device='GPU', variable_device='GPU', num_gpus=1)
598
599  @combinations.generate(combinations.combine(mode=['graph']))
600  def testDeviceAssignmentLocalTwoGPUs(self):
601    strategy, _, _ = create_test_objects(num_gpus=2)
602    self._test_device_assignment_local(
603        strategy, compute_device='GPU', variable_device='CPU', num_gpus=2)
604
605  @combinations.generate(
606      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
607  def testDeviceAssignmentDistributed(self, num_gpus):
608    self._test_device_assignment_distributed('worker', 1, num_gpus)
609
610  @combinations.generate(
611      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
612  def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus):
613    self._test_device_assignment_distributed_enable_partitioner(
614        'worker', 1, num_gpus)
615
616  @combinations.generate(combinations.combine(mode=['graph']))
617  def testSimpleBetweenGraph(self):
618    self._run_between_graph_clients(self._test_simple_increment,
619                                    self._cluster_spec, context.num_gpus())
620
621  @combinations.generate(
622      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
623  def testLocalSimpleIncrement(self, required_gpus):
624    self._test_simple_increment(None, 0, required_gpus)
625
626  @combinations.generate(
627      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
628  def testMinimizeLossGraphDistributed(self, required_gpus):
629    self._run_between_graph_clients(self._test_minimize_loss_graph,
630                                    self._cluster_spec, required_gpus)
631
632  @combinations.generate(
633      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
634  def testMinimizeLossGraphLocal(self, required_gpus):
635    self._test_minimize_loss_graph(None, None, required_gpus)
636
637  # TODO(priyag): Refactor this and other multi worker tests.
638  @combinations.generate(
639      combinations.combine(
640          mode=['graph'], required_gpus=[1, 2], use_dataset=[True, False]))
641  def testMakeInputFnIteratorDistributed(self, required_gpus, use_dataset):
642    if use_dataset:
643      fn = lambda: dataset_ops.Dataset.range(100)
644    else:
645      def fn():
646        dataset = dataset_ops.Dataset.range(100)
647        it = dataset_ops.make_one_shot_iterator(dataset)
648        return it.get_next
649
650    expected_values = [[i + j
651                        for j in range(required_gpus)]
652                       for i in range(0, 100, required_gpus)]
653
654    input_fn = self._input_fn_to_test_input_context(
655        fn,
656        expected_num_replicas_in_sync=required_gpus,
657        expected_num_input_pipelines=3,
658        expected_input_pipeline_id=1)  # because task_id = 1
659    self._test_input_fn_iterator(
660        'worker',
661        1,
662        required_gpus,
663        input_fn,
664        expected_values,
665        test_reinitialize=use_dataset,
666        ignore_order=not use_dataset)
667
668  @combinations.generate(
669      combinations.combine(
670          mode=['graph'], required_gpus=[1, 2], use_dataset=[True, False]))
671  def testMakeInputFnIteratorLocal(self, required_gpus, use_dataset):
672    if use_dataset:
673      fn = lambda: dataset_ops.Dataset.range(100)
674    else:
675
676      def fn():
677        dataset = dataset_ops.Dataset.range(100)
678        it = dataset_ops.make_one_shot_iterator(dataset)
679        return it.get_next
680
681    expected_values = [[i + j
682                        for j in range(required_gpus)]
683                       for i in range(0, 100, required_gpus)]
684
685    input_fn = self._input_fn_to_test_input_context(
686        fn,
687        expected_num_replicas_in_sync=required_gpus,
688        expected_num_input_pipelines=1,
689        expected_input_pipeline_id=0)  # only one worker and pipeline for local.
690    self._test_input_fn_iterator(
691        None,
692        None,
693        required_gpus,
694        input_fn,
695        expected_values,
696        test_reinitialize=use_dataset,
697        ignore_order=not use_dataset)
698
699  @combinations.generate(combinations.combine(mode=['graph']))
700  def testGlobalStepUpdate(self):
701    strategy, _, _ = create_test_objects()
702    self._test_global_step_update(strategy)
703
704  @combinations.generate(combinations.combine(mode=['graph']))
705  def testUpdateConfigProtoMultiWorker(self):
706    strategy, _, _ = create_test_objects(
707        cluster_spec=self._cluster_spec,
708        task_type='worker',
709        task_id=1,
710        num_gpus=2)
711
712    config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden'])
713
714    new_config = strategy.update_config_proto(config_proto)
715
716    # Verify device filters.
717    self.assertEqual(['/job:worker/task:1', '/job:ps'],
718                     new_config.device_filters)
719
720    # Verify isolate_session_state
721    self.assertFalse(new_config.isolate_session_state)
722
723  @combinations.generate(combinations.combine(mode=['graph']))
724  def testUpdateConfigProtoLocal(self):
725    strategy, _, _ = create_test_objects(num_gpus=2)
726
727    config_proto = config_pb2.ConfigProto()
728    new_config = strategy.update_config_proto(config_proto)
729
730    # Verify isolate_session_state
731    self.assertTrue(new_config.isolate_session_state)
732
733  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
734  def testInMultiWorkerMode(self):
735    strategy, _, _ = create_test_objects(
736        cluster_spec=self._cluster_spec,
737        task_type='worker',
738        task_id=1,
739        num_gpus=0)
740    self.assertTrue(strategy.extended._in_multi_worker_mode())
741
742  @combinations.generate(combinations.combine(mode=['eager']))
743  def testEagerCustomTrainingUnimplementedError(self):
744    cluster_spec = multi_worker_test_base.create_in_process_cluster(
745        num_workers=3, num_ps=2)
746    cluster_resolver = SimpleClusterResolver(
747        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
748        task_type='worker',
749        task_id=1,
750        num_accelerators={'GPU': 0})
751    strategy = parameter_server_strategy.ParameterServerStrategyV1(
752        cluster_resolver)
753    dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7., 8.])
754
755    def train_step(data):
756      return math_ops.square(data)
757
758    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
759                           strategy.experimental_distribute_dataset,
760                           dataset.batch(2))
761
762    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
763                           strategy.distribute_datasets_from_function,
764                           lambda _: dataset)
765
766    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
767                           strategy.scope)
768
769    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
770                           strategy.run, train_step)
771
772  @combinations.generate(combinations.combine(
773      mode=['graph'],
774      prefetch_to_device=[None, True]))
775  def test_prefetch_to_device_dataset(self, prefetch_to_device):
776    distribution, _, _ = create_test_objects(
777        cluster_spec=self._cluster_spec,
778        task_type='worker',
779        task_id=0,
780        num_gpus=2)
781    if prefetch_to_device is None:
782      input_options = None
783    else:
784      input_options = distribute_lib.InputOptions(
785          experimental_prefetch_to_device=prefetch_to_device)
786    dataset = dataset_ops.Dataset.range(100)
787    dataset = dataset.batch(distribution.num_replicas_in_sync)
788    dataset = distribution.experimental_distribute_dataset(
789        dataset, options=input_options)
790    if isinstance(dataset, input_lib.DistributedDatasetV1):
791      item = dataset.make_initializable_iterator().get_next()
792    else:
793      self.skipTest('unsupported test combination')
794    device_types = {
795        tf_device.DeviceSpec.from_string(tensor.device).device_type for
796        tensor in item.values}
797    self.assertAllEqual(list(device_types), ['GPU'])
798
799  @combinations.generate(combinations.combine(mode=['graph']))
800  def test_prefetch_to_host_dataset(self):
801    distribution, _, _ = create_test_objects(
802        cluster_spec=self._cluster_spec,
803        task_type='worker',
804        task_id=0,
805        num_gpus=2)
806    input_options = distribute_lib.InputOptions(
807        experimental_prefetch_to_device=False)
808    dataset = dataset_ops.Dataset.range(100)
809    dataset = dataset.batch(distribution.num_replicas_in_sync)
810    dataset = distribution.experimental_distribute_dataset(
811        dataset, options=input_options)
812    if isinstance(dataset, input_lib.DistributedDatasetV1):
813      item = dataset.make_initializable_iterator().get_next()
814    else:
815      self.skipTest('unsupported test combination')
816    device_types = {
817        tf_device.DeviceSpec.from_string(tensor.device).device_type for
818        tensor in item.values}
819    self.assertAllEqual(list(device_types), ['CPU'])
820
821
class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
                                           parameterized.TestCase):
  """ParameterServerStrategy tests whose cluster spec includes a chief."""

  @classmethod
  def setUpClass(cls):
    """Creates the in-process cluster (3 workers, 2 ps, with a chief)."""
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=2, has_chief=True)
    # The default target is the first chief address in the cluster spec.
    cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0]

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testSimpleBetweenGraph(self, required_gpus):
    """Runs `_test_simple_increment` through `_run_between_graph_clients`."""
    self._run_between_graph_clients(self._test_simple_increment,
                                    self._cluster_spec, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
  def testMinimizeLossGraph(self, num_gpus):
    """Runs `_test_minimize_loss_graph` through `_run_between_graph_clients`."""
    self._run_between_graph_clients(self._test_minimize_loss_graph,
                                    self._cluster_spec, num_gpus)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testGlobalStepIsWrappedOnTwoGPUs(self):
    """Checks the global step is an AggregatingVariable with num_gpus=2."""
    strategy, _, _ = create_test_objects(num_gpus=2)
    with ops.Graph().as_default(), strategy.scope():
      created_step = training_util.create_global_step()
      get_step = training_util.get_global_step()
      self.assertEqual(created_step, get_step,
                       msg=('created_step %s type %s vs. get_step %s type %s' %
                            (id(created_step), created_step.__class__.__name__,
                             id(get_step), get_step.__class__.__name__)))
      self.assertIs(ps_values.AggregatingVariable, type(created_step))
      self.assertIs(ps_values.AggregatingVariable, type(get_step))
      self.assertIs(strategy, created_step.distribute_strategy)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testGlobalStepIsNotWrappedOnOneGPU(self):
    """Checks the global step is a plain ResourceVariable with num_gpus=1."""
    strategy, _, _ = create_test_objects(num_gpus=1)
    with ops.Graph().as_default(), strategy.scope():
      created_step = training_util.create_global_step()
      get_step = training_util.get_global_step()
      self.assertEqual(created_step, get_step,
                       msg=('created_step %s type %s vs. get_step %s type %s' %
                            (id(created_step), created_step.__class__.__name__,
                             id(get_step), get_step.__class__.__name__)))
      self.assertIs(resource_variable_ops.ResourceVariable, type(created_step))
      self.assertIs(resource_variable_ops.ResourceVariable, type(get_step))
      # All variables have an _distribute_strategy parameter. Only variable
      # subclasses in distribution strategy expose it publicly, so the
      # unwrapped variable (not the strategy object) is what must lack a
      # public `distribute_strategy` attribute.
      self.assertFalse(hasattr(created_step, 'distribute_strategy'))
      self.assertIs(strategy, created_step._distribute_strategy)

  @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2))
  def testValueContainer(self):
    """Checks `value_container` returns an AggregatingVariable per replica."""
    strategy, _, _ = create_test_objects(num_gpus=2)
    with ops.Graph().as_default(), strategy.scope():

      def f():
        with backprop.GradientTape() as tape:
          v = variable_scope.get_variable('v', initializer=10.0)
          _ = v * v
        # Re-read the variable as recorded by the tape, then map it back to
        # its container through the strategy.
        v, = tape.watched_variables()
        w = strategy.extended.value_container(v)
        self.assertIs(ps_values.AggregatingVariable, type(w))

      strategy.extended.call_for_each_replica(f)
888
889
class CentralStorageStrategyTest(strategy_test_lib.DistributionTestBase,
                                 parameterized.TestCase):
  """Tests on strategies created without a cluster spec."""

  @combinations.generate(
      combinations.combine(mode=['graph', 'eager'], required_gpus=2))
  def testNumpyDataset(self):
    """Delegates to `_test_numpy_dataset` with a strategy built for 2 GPUs."""
    two_gpu_strategy, _, _ = create_test_objects(num_gpus=2)
    self._test_numpy_dataset(two_gpu_strategy)

  @combinations.generate(
      combinations.combine(mode=['graph', 'eager']))
  def testInMultiWorkerMode(self):
    """Checks that the strategy does not report multi-worker mode."""
    cpu_strategy, _, _ = create_test_objects(num_gpus=0)
    in_multi_worker = cpu_strategy.extended._in_multi_worker_mode()
    self.assertFalse(in_multi_worker)
903
904
# Invoke the test runner only when this file is executed as a script.
if __name__ == '__main__':
  test.main()
907