1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class CollectiveAllReduceStrategy implementing DistributionStrategy."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import threading
23import time
24import weakref
25
26from tensorflow.core.protobuf import rewriter_config_pb2
27from tensorflow.core.protobuf import tensorflow_server_pb2
28from tensorflow.python.distribute import collective_util
29from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
30from tensorflow.python.distribute import cross_device_utils
31from tensorflow.python.distribute import device_util
32from tensorflow.python.distribute import distribute_lib
33from tensorflow.python.distribute import distribute_utils
34from tensorflow.python.distribute import distribution_strategy_context as ds_context
35from tensorflow.python.distribute import input_lib
36from tensorflow.python.distribute import mirrored_strategy
37from tensorflow.python.distribute import multi_worker_util
38from tensorflow.python.distribute import numpy_dataset
39from tensorflow.python.distribute import reduce_util
40from tensorflow.python.distribute import values
41from tensorflow.python.distribute.cluster_resolver import ClusterResolver
42from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
43from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
44from tensorflow.python.eager import context
45from tensorflow.python.framework import errors
46from tensorflow.python.framework import ops
47from tensorflow.python.ops import array_ops
48from tensorflow.python.ops import collective_ops
49from tensorflow.python.platform import tf_logging as logging
50from tensorflow.python.training.tracking import base
51from tensorflow.python.util import deprecation
52from tensorflow.python.util.tf_export import tf_export
53
54
55# pylint: disable=line-too-long
@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[])
class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  """A distribution strategy for synchronous training on multiple workers.

  This strategy implements synchronous distributed training across multiple
  workers, each with potentially multiple GPUs. Similar to
  `tf.distribute.MirroredStrategy`, it replicates all variables and computations
  to each local device. The difference is that it uses a distributed collective
  implementation (e.g. all-reduce), so that multiple workers can work together.

  You need to launch your program on each worker and configure
  `cluster_resolver` correctly. For example, if you are using
  `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to
  have its corresponding `task_type` and `task_id` set in the `TF_CONFIG`
  environment variable. An example TF_CONFIG on worker-0 of a two worker cluster
  is:

  ```
  TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
  ```

  Your program runs on each worker as-is. Note that collectives require each
  worker to participate. Both `tf.distribute` and non-`tf.distribute` APIs may
  use collectives internally, e.g. checkpointing and saving, since reading a
  `tf.Variable` with `tf.VariableSynchronization.ON_READ` all-reduces the value.
  Therefore it's recommended to run exactly the same program on each worker.
  Dispatching based on `task_type` or `task_id` of the worker is error-prone.

  `cluster_resolver.num_accelerators()` determines the number of GPUs the
  strategy uses. If it's zero, the strategy uses the CPU. All workers need to
  use the same number of devices, otherwise the behavior is undefined.

  This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy`
  instead.

  After setting up TF_CONFIG, using this strategy is similar to using
  `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`.

  ```
  strategy = tf.distribute.MultiWorkerMirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential([
      tf.keras.layers.Dense(2, input_shape=(5,)),
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

  def dataset_fn(ctx):
    x = np.random.random((2, 5)).astype(np.float32)
    y = np.random.randint(2, size=(2, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.repeat().batch(1, drop_remainder=True)
  dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

  model.compile()
  model.fit(dist_dataset)
  ```

  You can also write your own training loop:

  ```
  @tf.function
  def train_step(iterator):

    def step_fn(inputs):
      features, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)

      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    strategy.run(step_fn, args=(next(iterator),))

  for _ in range(NUM_STEP):
    train_step(iterator)
  ```

  See
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
  for a detailed tutorial.

  __Saving__

  You need to save and checkpoint on all workers instead of just one. This is
  because variables with synchronization=ON_READ trigger aggregation during
  saving. It's recommended to save to a different path on each worker to avoid
  race conditions. Each worker saves the same thing. See
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading)
  tutorial for examples.

  __Known Issues__

  * `tf.distribute.cluster_resolver.TFConfigClusterResolver` does not return the
  correct number of accelerators. The strategy uses all available GPUs if
  `cluster_resolver` is `tf.distribute.cluster_resolver.TFConfigClusterResolver`
  or `None`.
  * In eager mode, the strategy needs to be created before calling any other
  TensorFlow API.

  """
  # pylint: enable=line-too-long

  # TODO(anjalisridhar): Update our guides with examples showing how we can use
  # the cluster_resolver argument.

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               cluster_resolver=None,
               communication_options=None):
    """Creates the strategy.

    Args:
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
      communication_options: optional
        `tf.distribute.experimental.CommunicationOptions`. This configures the
        default options for cross device communications. It can be overridden by
        options provided to the communication APIs like
        `tf.distribute.ReplicaContext.all_reduce`. See
        `tf.distribute.experimental.CommunicationOptions` for details. When
        `None`, a default-constructed options object is used.
    """
    options = (collective_util.Options()
               if communication_options is None else communication_options)
    super(CollectiveAllReduceStrategy, self).__init__(
        CollectiveAllReduceExtended(
            self,
            cluster_resolver=cluster_resolver,
            communication_options=options))

    # Record usage metrics for this strategy.
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "MultiWorkerMirroredStrategy")
    extended = self.extended
    replica_gauge = distribute_lib.distribution_strategy_replica_gauge
    # pylint: disable=protected-access
    replica_gauge.get_cell("num_workers").set(extended._num_workers)
    # NOTE(review): this records the per-worker GPU count, which is 0 on a
    # CPU-only worker even though that worker still runs one replica.
    replica_gauge.get_cell("num_replicas_per_worker").set(
        extended._num_gpus_per_worker)

  @classmethod
  def _from_local_devices(cls, devices, communication_options=None):
    """Creates a strategy restricted to the given local `devices`.

    The strategy is first constructed normally and then re-initialized for
    local training on `devices` with a fresh `TFConfigClusterResolver`.
    """
    strategy = cls(communication_options=communication_options)
    strategy.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return strategy

  @property
  def cluster_resolver(self):
    """Returns the cluster resolver associated with this strategy.

    As a multi-worker strategy, `tf.distribute.MultiWorkerMirroredStrategy`
    provides the associated `tf.distribute.cluster_resolver.ClusterResolver`.
    If the user provides one in `__init__`, that instance is returned; if the
    user does not, a default `TFConfigClusterResolver` is provided.
    """
    return self.extended._cluster_resolver  # pylint: disable=protected-access
217
218
class _CollectiveAllReduceStrategyExperimentalMeta(type):
  """Metaclass for the experimental `MultiWorkerMirroredStrategy` alias.

  It makes `isinstance(obj,
  tf.distribute.experimental.MultiWorkerMirroredStrategy)` true for every
  `tf.distribute.MultiWorkerMirroredStrategy` instance, because some libraries
  perform that check.
  """

  @classmethod
  def __instancecheck__(cls, instance):
    # The check delegates entirely to the stable class. `cls` is unused here,
    # and because of `@classmethod` it is bound to this metaclass.
    return isinstance(instance, CollectiveAllReduceStrategy)
227
228
@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
class _CollectiveAllReduceStrategyExperimental(
    CollectiveAllReduceStrategy,
    metaclass=_CollectiveAllReduceStrategyExperimentalMeta):

  # Reuse the stable API's documentation verbatim.
  __doc__ = CollectiveAllReduceStrategy.__doc__

  @deprecation.deprecated(
      None, "use distribute.MultiWorkerMirroredStrategy instead")
  def __init__(self,
               communication=collective_util.CommunicationImplementation.AUTO,
               cluster_resolver=None):
    """Creates the strategy.

    Args:
      communication: optional
        `tf.distribute.experimental.CommunicationImplementation`. This is a hint
        on the preferred collective communication implementation. Possible
        values include `AUTO`, `RING`, and `NCCL`. It is wrapped into a
        `tf.distribute.experimental.CommunicationOptions` for the base class.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
    """
    options = collective_util.Options(implementation=communication)
    super(_CollectiveAllReduceStrategyExperimental, self).__init__(
        cluster_resolver=cluster_resolver, communication_options=options)

  @classmethod
  def _from_local_devices(
      cls,
      devices,
      communication=collective_util.CommunicationImplementation.AUTO):
    """Creates a strategy restricted to the given local `devices`.

    The strategy is first constructed with `communication` and then
    re-initialized for local training on `devices` with a fresh
    `TFConfigClusterResolver`.
    """
    strategy = cls(communication=communication)
    strategy.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return strategy


# Expose the experimental alias under the stable class name.
_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__
269
270
@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])  # pylint: disable=missing-docstring
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):

  # Share the V2 class documentation.
  __doc__ = CollectiveAllReduceStrategy.__doc__

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               communication=collective_util.CommunicationImplementation.AUTO,
               cluster_resolver=None):
    """Initializes the object.

    Args:
      communication: optional
        `tf.distribute.experimental.CommunicationImplementation` hint for the
        preferred collective implementation.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`, a
        `TFConfigClusterResolver` is used.
    """
    options = collective_util.Options(implementation=communication)
    extended = CollectiveAllReduceExtended(
        self,
        cluster_resolver=cluster_resolver,
        communication_options=options)
    super(CollectiveAllReduceStrategyV1, self).__init__(extended)

    # Record usage metrics for this strategy.
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "MultiWorkerMirroredStrategy")
    ext = self.extended
    replica_gauge = distribute_lib.distribution_strategy_replica_gauge
    # pylint: disable=protected-access
    replica_gauge.get_cell("num_workers").set(ext._num_workers)
    replica_gauge.get_cell("num_gpu_per_worker").set(
        ext._num_gpus_per_worker)
297
298
299class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
300  """Implementation of CollectiveAllReduceStrategy."""
301
  # Whether to periodically check the health of the cluster. If any worker is not
303  # reachable, collectives are aborted and the user program should get a
304  # tf.errors.UnavailableError. It's required to restart in order to recover.
305  _enable_check_health = True
306  # Check health interval in seconds.
307  _check_health_interval = 30
  # Timeout in seconds for the first health check. The first health check needs
  # to wait for the cluster to come up, which may take longer.
310  _check_health_initial_timeout = 0
  # Number of retries before considering a peer down.
312  _check_health_retry_limit = 3
  # Timeout in seconds for each health check.
314  _check_health_timeout = 10
315
  def __init__(self, container_strategy, cluster_resolver,
               communication_options):
    """Initializes the extended implementation of the strategy.

    Args:
      container_strategy: the `Strategy` object that owns this extended object.
        Its `_collective_key_base` attribute is read to offset collective
        group keys.
      cluster_resolver: optional `ClusterResolver`. A falsy value is replaced
        by a new `TFConfigClusterResolver`.
      communication_options: a `collective_util.Options`
        (`tf.distribute.experimental.CommunicationOptions`) instance.

    Raises:
      ValueError: if `communication_options` is not a `collective_util.Options`,
        or if the resolved cluster resolver is not a `ClusterResolver`.
    """
    if not isinstance(communication_options, collective_util.Options):
      raise ValueError("communication_options must be an instance of "
                       "tf.distribute.experimental.CommunicationOptions")
    self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
    if not isinstance(self._cluster_resolver, ClusterResolver):
      raise ValueError("cluster_resolver must be an instance of "
                       "tf.distribute.cluster_resolver.ClusterResolver")
    # Calls StrategyExtendedV1.__init__ directly, skipping
    # MirroredExtended.__init__. Device setup is done instead by
    # `_initialize_strategy` below, which calls `_initialize_single_worker`.
    distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
    self._communication_options = communication_options
    self._collective_key_base = container_strategy._collective_key_base  # pylint: disable=protected-access
    # Picks local vs. multi-worker setup depending on the cluster spec.
    self._initialize_strategy(self._cluster_resolver)
    self._cfer_fn_cache = weakref.WeakKeyDictionary()
    self.experimental_enable_get_next_as_optional = True
    # Both initialization paths build a CollectiveAllReduce; sanity-check it.
    assert isinstance(self._cross_device_ops,
                      cross_device_ops_lib.CollectiveAllReduce)
333
334  def _initialize_strategy(self, cluster_resolver):
335    if cluster_resolver.cluster_spec().as_dict():
336      self._initialize_multi_worker(cluster_resolver)
337    else:
338      self._initialize_local(cluster_resolver)
339
340  def _initialize_local(self, cluster_resolver, devices=None):
341    """Initializes the object for local training."""
342    self._is_chief = True
343    self._num_workers = 1
344
345    if ops.executing_eagerly_outside_functions():
346      try:
347        context.context().configure_collective_ops(
348            scoped_allocator_enabled_ops=("CollectiveReduce",))
349      except RuntimeError:
350        logging.warning("Collective ops is not configured at program startup. "
351                        "Some performance features may not be enabled.")
352      self._collective_ops_configured = True
353
354    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
355    # some cases.
356    if isinstance(cluster_resolver, TFConfigClusterResolver):
357      num_gpus = context.num_gpus()
358    else:
359      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
360
361    if devices:
362      local_devices = devices
363    else:
364      if num_gpus:
365        local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
366      else:
367        local_devices = ("/device:CPU:0",)
368
369    self._worker_device = device_util.canonicalize("/device:CPU:0")
370    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
371
372    self._collective_keys = cross_device_utils.CollectiveKeys(
373        group_key_start=1 + self._collective_key_base)
374    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
375        devices=local_devices,
376        group_size=len(local_devices),
377        collective_keys=self._collective_keys)
378    # CrossDeviceOps for per host tensors.
379    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
380        devices=[self._worker_device],
381        group_size=self._num_workers,
382        collective_keys=self._collective_keys)
383    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
384        local_devices)
385
386    self._cluster_spec = None
387    self._task_type = None
388    self._task_id = None
389    self._id_in_cluster = 0
390
391    # This is a mark to tell whether we are running with standalone client or
392    # independent worker. Right now with standalone client, strategy object is
393    # created as local strategy and then turn into multi-worker strategy via
394    # configure call.
395    self._local_or_standalone_client_mode = True
396
397    # Save the num_gpus_per_worker and rpc_layer for configure method.
398    self._num_gpus_per_worker = num_gpus
399    self._rpc_layer = cluster_resolver.rpc_layer
400    self._warn_nccl_no_gpu()
401
402    logging.info(
403        "Single-worker MultiWorkerMirroredStrategy with local_devices "
404        "= %r, communication = %s", local_devices,
405        self._communication_options.implementation)
406
  def _initialize_multi_worker(self, cluster_resolver):
    """Initializes the object for multi-worker training.

    Reads the cluster spec, task type and task id from `cluster_resolver`.
    In eager mode (and outside local/standalone-client mode) it configures
    collective ops and starts this task's server. It then picks the local
    compute devices from the GPU count, builds the cross-device ops that reduce
    across all workers, and optionally starts the health-check thread.

    Args:
      cluster_resolver: a `ClusterResolver` describing the cluster and this
        task.

    Raises:
      ValueError: if the resolver's `task_type` or `task_id` is `None`, or if
        no `worker`, `chief` or `evaluator` tasks are found in the cluster spec.
    """
    cluster_spec = multi_worker_util.normalize_cluster_spec(
        cluster_resolver.cluster_spec())
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if task_type is None or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`.")
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    # This task's index among the workers; used as the input pipeline id.
    self._id_in_cluster = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)

    self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
    if not self._num_workers:
      raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
                       "in `cluster_spec`.")

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    # `_local_or_standalone_client_mode` is only set by `_initialize_local`, so
    # this branch runs for an independent worker created directly with a
    # cluster. Device filters restrict collectives to this task.
    if (ops.executing_eagerly_outside_functions() and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      context.context().configure_collective_ops(
          collective_leader=multi_worker_util.collective_leader(
              cluster_spec, task_type, task_id),
          scoped_allocator_enabled_ops=("CollectiveReduce",),
          device_filters=("/job:%s/task:%d" % (task_type, task_id),))
      self._collective_ops_configured = True

    # Starting a std server in eager mode and in independent worker mode.
    # `_std_server_started` ensures the server is only started once per object.
    if (context.executing_eagerly() and
        not getattr(self, "_std_server_started", False) and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      # Checking _local_or_standalone_client_mode as well because we should not
      # create the std server in standalone client mode.
      config_proto = copy.deepcopy(context.context().config)
      config_proto = self._update_config_proto(config_proto)

      # Resolvers without a `port` attribute get port 0. NOTE(review):
      # presumably this lets the server choose a port; confirm.
      if hasattr(cluster_resolver, "port"):
        port = cluster_resolver.port
      else:
        port = 0
      server_def = tensorflow_server_pb2.ServerDef(
          cluster=cluster_spec.as_cluster_def(),
          default_session_config=config_proto,
          job_name=task_type,
          task_index=task_id,
          protocol=cluster_resolver.rpc_layer or "grpc",
          port=port)
      context.context().enable_collective_ops(server_def)
      self._std_server_started = True
      # The `ensure_initialized` is needed before calling
      # `context.context().devices()`.
      context.context().ensure_initialized()
      logging.info(
          "Enabled multi-worker collective ops with available devices: %r",
          context.context().devices())

    # TODO(yuefengz): The `num_gpus` is only for this particular task. It
    # assumes all workers have the same number of GPUs. We should remove this
    # assumption by querying all tasks for their numbers of GPUs.
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # One compute device per local GPU, or the worker device itself if no GPU.
    if num_gpus:
      local_devices = tuple("%s/device:GPU:%d" % (self._worker_device, i)
                            for i in range(num_gpus))
    else:
      local_devices = (self._worker_device,)

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    # The all-reduce group spans every local device on every worker. This
    # relies on every worker having the same number of devices (see TODO above).
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices) * self._num_workers,
        collective_keys=self._collective_keys)
    # CrossDeviceOps for per host tensors.
    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=[self._worker_device],
        group_size=self._num_workers,
        collective_keys=self._collective_keys)
    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
        local_devices)

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = "/job:%s/task:%d" % (task_type, task_id)

    # Save the num_gpus_per_worker and rpc_layer for configure method.
    self._num_gpus_per_worker = num_gpus
    self._rpc_layer = cluster_resolver.rpc_layer
    self._warn_nccl_no_gpu()

    # `_start_check_health_thread` is defined elsewhere in this class.
    if self._enable_check_health:
      self._start_check_health_thread()

    logging.info(
        "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
        "task_id = %r, num_workers = %r, local_devices = %r, "
        "communication = %s", cluster_spec.as_dict(), task_type, task_id,
        self._num_workers, local_devices,
        self._communication_options.implementation)
519
  def __del__(self):
    """Stops the health-check thread when this object is finalized.

    Delegates to `_stop_check_health_thread`, defined elsewhere in this class.
    NOTE(review): `__del__` also runs when `__init__` raised early (e.g. on an
    invalid `communication_options`); confirm the helper tolerates a partially
    constructed object.
    """
    self._stop_check_health_thread()
522
523  def _input_workers_with_options(self, options=None):
524    host_device = device_util.get_host_for_device(self._worker_device)
525    if not options or options.experimental_prefetch_to_device:
526      return input_lib.InputWorkers([(host_device, self.worker_devices)])
527    else:
528      return input_lib.InputWorkers([(
529          host_device,
530          [device_util.get_host_for_device(worker) for worker in
531           self.worker_devices])])
532
533  @property
534  def _input_workers(self):
535    return self._input_workers_with_options()
536
537  def _get_variable_creator_initial_value(self,
538                                          replica_id,
539                                          device,
540                                          primary_var,
541                                          **kwargs):
542    if replica_id == 0:  # First replica on each worker.
543      assert device is not None
544      assert primary_var is None
545
546      def initial_value_fn():  # pylint: disable=g-missing-docstring
547        # Only the first device participates in the broadcast of initial values.
548        group_key = self._collective_keys.get_group_key([device])
549        group_size = self._num_workers
550        collective_instance_key = (
551            self._collective_keys.get_instance_key(group_key, device))
552
553        with ops.device(device):
554          initial_value = kwargs["initial_value"]
555          if callable(initial_value):
556            initial_value = initial_value()
557          if isinstance(initial_value, base.CheckpointInitialValue):
558            initial_value = initial_value.wrapped_value
559          assert not callable(initial_value)
560          initial_value = ops.convert_to_tensor(
561              initial_value, dtype=kwargs.get("dtype", None))
562
563          if self._num_workers > 1:
564            if self._is_chief:
565              bcast_send = collective_ops.broadcast_send(
566                  initial_value, initial_value.shape, initial_value.dtype,
567                  group_size, group_key, collective_instance_key)
568              with ops.control_dependencies([bcast_send]):
569                return array_ops.identity(initial_value)
570            else:
571              return collective_ops.broadcast_recv(initial_value.shape,
572                                                   initial_value.dtype,
573                                                   group_size, group_key,
574                                                   collective_instance_key)
575          return initial_value
576
577      return initial_value_fn
578    else:
579      return super(CollectiveAllReduceExtended,
580                   self)._get_variable_creator_initial_value(
581                       replica_id=replica_id,
582                       device=device,
583                       primary_var=primary_var,
584                       **kwargs)
585
586  def _make_input_context(self):
587    input_context = distribute_lib.InputContext(
588        num_input_pipelines=self._num_workers,
589        input_pipeline_id=self._id_in_cluster,
590        num_replicas_in_sync=self._num_replicas_in_sync)
591    return input_context
592
593  def _experimental_distribute_dataset(self, dataset, options):
594    if (options and options.experimental_replication_mode ==
595        distribute_lib.InputReplicationMode.PER_REPLICA):
596      raise NotImplementedError(
597          "InputReplicationMode.PER_REPLICA "
598          "is only supported in "
599          "`experimental_distribute_datasets_from_function`."
600      )
601    input_context = self._make_input_context()
602    return input_lib.get_distributed_dataset(
603        dataset,
604        self._input_workers_with_options(options),
605        self._container_strategy(),
606        num_replicas_in_sync=self._num_replicas_in_sync,
607        input_context=input_context)
608
609  def _distribute_datasets_from_function(self, dataset_fn, options):
610    if (options and options.experimental_replication_mode ==
611        distribute_lib.InputReplicationMode.PER_REPLICA):
612      raise NotImplementedError(
613          "InputReplicationMode.PER_REPLICA "
614          "is only supported in "
615          " `experimental_distribute_datasets_from_function` "
616          "of tf.distribute.MirroredStrategy")
617    input_context = self._make_input_context()
618    return input_lib.get_distributed_datasets_from_function(
619        dataset_fn=dataset_fn,
620        input_workers=self._input_workers_with_options(options),
621        input_contexts=[input_context],
622        strategy=self._container_strategy())
623
624  def _experimental_distribute_values_from_function(self, value_fn):
625    per_replica_values = []
626    num_local_replicas = len(self.worker_devices)
627    for local_replica_id in range(num_local_replicas):
628      replica_id = (self._id_in_cluster * num_local_replicas +
629                    local_replica_id)
630      value_context = distribute_lib.ValueContext(
631          replica_id, self._num_replicas_in_sync)
632      per_replica_values.append(value_fn(value_context))
633    return distribute_utils.regroup(per_replica_values, always_wrap=True)
634
635  def _make_dataset_iterator(self, dataset):
636    """Distributes the dataset to each local GPU."""
637    input_context = self._make_input_context()
638    return input_lib.DatasetIterator(
639        dataset,
640        self._input_workers,
641        self._container_strategy(),
642        num_replicas_in_sync=self._num_replicas_in_sync,
643        input_context=input_context)
644
645  def _make_input_fn_iterator(
646      self,
647      input_fn,
648      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
649    """Distributes the input function to each local GPU."""
650    input_context = self._make_input_context()
651    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
652                                           [input_context],
653                                           self._container_strategy())
654
655  def _configure(self,
656                 session_config=None,
657                 cluster_spec=None,
658                 task_type=None,
659                 task_id=None):
660    """Configures the object.
661
662    Args:
663      session_config: a `tf.compat.v1.ConfigProto`
664      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
665        cluster configurations.
666      task_type: the current task type, such as "worker".
667      task_id: the current task id.
668
669    Raises:
670      ValueError: if `task_type` is not in the `cluster_spec`.
671    """
672    if cluster_spec:
673      # Use the num_gpus_per_worker recorded in constructor since _configure
674      # doesn't take num_gpus.
675      cluster_resolver = SimpleClusterResolver(
676          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
677          task_type=task_type,
678          task_id=task_id,
679          num_accelerators={"GPU": self._num_gpus_per_worker},
680          rpc_layer=self._rpc_layer)
681      self._initialize_multi_worker(cluster_resolver)
682      assert isinstance(self._cross_device_ops,
683                        cross_device_ops_lib.CollectiveAllReduce)
684
685    if session_config:
686      session_config.CopyFrom(self._update_config_proto(session_config))
687
688  def _update_config_proto(self, config_proto):
689    updated_config = copy.deepcopy(config_proto)
690    # Enable the scoped allocator optimization for CollectiveOps.  This
691    # optimization converts many small all-reduces into fewer larger
692    # all-reduces.
693    rewrite_options = updated_config.graph_options.rewrite_options
694    rewrite_options.scoped_allocator_optimization = (
695        rewriter_config_pb2.RewriterConfig.ON)
696    # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op =
697    # ["CollectiveReduce"].  Since we can't assign to a repeated proto field, we
698    # clear and then append.
699    del rewrite_options.scoped_allocator_opts.enable_op[:]
700    rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
701
702    if (not ops.executing_eagerly_outside_functions() and
703        self._communication_options.implementation ==
704        collective_util.CommunicationImplementation.NCCL):
705      updated_config.experimental.collective_nccl = True
706
707    if not self._cluster_spec:
708      return updated_config
709
710    assert self._task_type
711    assert self._task_id is not None
712
713    # Collective group leader is needed for collective ops to coordinate
714    # workers.
715    updated_config.experimental.collective_group_leader = (
716        multi_worker_util.collective_leader(self._cluster_spec, self._task_type,
717                                            self._task_id))
718
719    # The device filters prevent communication between workers.
720    del updated_config.device_filters[:]
721    updated_config.device_filters.append(
722        "/job:%s/task:%d" % (self._task_type, self._task_id))
723
724    return updated_config
725
726  def _get_cross_device_ops(self, value):
727    # CollectiveAllReduce works on a predefined set of devices. In most cases
728    # they should be the compute devices, but certain use cases may reduce host
729    # tensors as well (e.g. early stopping). We infer the cross_device_ops to
730    # use based on the number of devices, since inputs don't always have device
731    # annotations. The compute devices one is preferred since we can potentially
732    # leverage NCCL.
733    if isinstance(value, values.DistributedValues):
734      num_devices = len(value._values)  # pylint: disable=protected-access
735    else:
736      num_devices = 1
737    if num_devices == len(self.worker_devices):
738      return self._cross_device_ops
739    else:
740      return self._host_cross_device_ops
741
742  def _gather_to_implementation(self, value, destinations, axis, options):
743    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
744        value,
745        destinations=destinations,
746        axis=axis,
747        options=options)
748
749  def _reduce_to(self, reduce_op, value, destinations, options):
750    if (isinstance(value, values.Mirrored) and
751        reduce_op == reduce_util.ReduceOp.MEAN):
752      return value
753    assert not isinstance(value, values.Mirrored)
754
755    if (isinstance(value, values.DistributedValues) and
756        len(self.worker_devices) == 1):
757      value = value.values[0]
758
759    # When there are multiple workers, we need to reduce across workers using
760    # collective ops.
761    if (not isinstance(value, values.DistributedValues) and
762        self._num_workers == 1):
763      # This function handles reducing values that are not PerReplica or
764      # Mirrored values. For example, the same value could be present on all
765      # replicas in which case `value` would be a single value or value could
766      # be 0.
767      return cross_device_ops_lib.reduce_non_distributed_value(
768          reduce_op, value, destinations, len(self.worker_devices))
769    return self._get_cross_device_ops(value).reduce(
770        reduce_op,
771        value,
772        destinations=destinations,
773        options=self._communication_options.merge(options))
774
775  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
776    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
777    # This implementation avoids using `merge_call` and just launches collective
778    # ops in one replica.
779    if options is None:
780      options = collective_util.Options()
781
782    if context.executing_eagerly():
783      # In eager mode, falls back to the default implemenation that uses
784      # `merge_call`. Replica functions are running sequentially in eager mode,
785      # and due to the blocking nature of collective ops, execution will hang if
786      # collective ops are to be launched sequentially.
787      return super()._replica_ctx_all_reduce(reduce_op, value, options)
788
789    replica_context = ds_context.get_replica_context()
790    assert replica_context, (
791        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
792        "replica context")
793    return self._cross_device_ops._all_reduce(  # pylint: disable=protected-access
794        reduce_op,
795        value,
796        replica_context._replica_id,  # pylint: disable=protected-access
797        options)
798
799  def _check_health(self):
800    while True:
801      if self._check_health_thread_should_stop.is_set():
802        return
803      for job in self._cluster_spec.jobs:
804        for task_id in range(self._cluster_spec.num_tasks(job)):
805          peer = "/job:{}/replica:0/task:{}".format(job, task_id)
806          attempts = 0
807          while True:
808            attempts += 1
809            try:
810              context.context().check_collective_ops_peer_health(
811                  peer, timeout_in_ms=self._check_health_timeout * 1000)
812              # If check_collective_ops_peer_health doesn't raise an Exception,
813              # the peer is healthy.
814              break
815            except (errors.UnavailableError, errors.FailedPreconditionError,
816                    errors.DeadlineExceededError) as e:
817              # TODO(b/151232436): Always raise UnavailableError when a peer
818              # fails. Now there could be many kinds of errors:
819              # - Unavailable: when the peer is not reachable, e.g. it's down.
820              # - FailedPrecondition: when the peer has restarted.
821              if attempts < self._check_health_retry_limit:
822                logging.warning("%s seems down, retrying %d/%d", peer, attempts,
823                                self._check_health_retry_limit)
824                continue
825              logging.error(
826                  "Cluster check alive failed, %s is down, "
827                  "aborting collectives: %s", peer, e)
828              context.context().abort_collective_ops(
829                  errors.UNAVAILABLE,
830                  "cluster check alive failed, {} is down".format(peer))
831              return
832            except Exception as e:  # pylint: disable=broad-except
833              logging.error("Unexpected exception in check alive: %s", e)
834              context.context().abort_collective_ops(
835                  errors.INTERNAL,
836                  "unexecpted exception in check alive: %s" % e)
837              return
838      time.sleep(self._check_health_interval)
839
840  def _start_check_health_thread(self):
841    if not context.executing_eagerly():
842      logging.info("Check health is only supported in eager.")
843      return
844    # Use a dummy all-reduce as a barrier to wait for all workers to be up,
845    # otherwise the check health may fail immediately.
846
847    # Use array_ops.identity to create the dummy tensor so that we have a new
848    # Tensor. If we use constant it may be a cached from on a /job:localhost
849    # device, which will cause some code that relies on tensor.device to error.
850    #
851    # TODO(b/151232436): change to an explicit barrier if we have it.
852    dummy_value = array_ops.identity([])
853    logging.info("Waiting for the cluster, timeout = %s",
854                 self._check_health_initial_timeout or "inf")
855    try:
856      self._host_cross_device_ops.reduce(
857          reduce_util.ReduceOp.SUM,
858          dummy_value,
859          dummy_value,
860          options=collective_util.Options(
861              timeout_seconds=self._check_health_initial_timeout,
862              implementation=collective_util.CommunicationImplementation.RING))
863      if context.is_async():
864        context.async_wait()
865    except errors.DeadlineExceededError:
866      raise RuntimeError(
867          "Timeout waiting for the cluster, timeout is %d seconds" %
868          self._check_health_initial_timeout)
869    logging.info("Cluster is ready.")
870    self._check_health_thread_should_stop = threading.Event()
871    # Start the thread as daemon to avoid it blocking the program from exiting.
872    # We try best to shutdown the thread but __del__ is not guaranteed to be
873    # called when program exists.
874    self._check_health_thread = threading.Thread(
875        target=self._check_health,
876        daemon=True)
877    self._check_health_thread.start()
878
879  def _stop_check_health_thread(self):
880    if getattr(self, "_check_health_thread", None):
881      logging.info("stopping check health thread")
882      self._check_health_thread_should_stop.set()
883      self._check_health_thread.join()
884      self._check_health_thread = None
885      logging.info("check health thread stopped")
886
887  def _warn_nccl_no_gpu(self):
888    if ((self._communication_options.implementation ==
889         collective_util.CommunicationImplementation.NCCL) and
890        self._num_gpus_per_worker == 0):
891      logging.warning("Enabled NCCL communication but no GPUs detected/"
892                      "specified.")
893
894  def _in_multi_worker_mode(self):
895    """Whether this strategy indicates working in multi-worker settings."""
896    return self._num_workers > 1
897
898  @property
899  def experimental_between_graph(self):
900    return True
901
902  @property
903  def experimental_should_init(self):
904    return True
905
906  @property
907  def should_checkpoint(self):
908    return self._is_chief
909
910  @property
911  def should_save_summary(self):
912    return self._is_chief
913
914  @property
915  def _num_replicas_in_sync(self):
916    return len(self.worker_devices) * self._num_workers
917
918  # TODO(priyag): Delete this once all strategies use global batch size.
919  @property
920  def _global_batch_size(self):
921    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.
922
923    `make_input_fn_iterator` assumes per-replica batching.
924
925    Returns:
926      Boolean.
927    """
928    return True
929
930  def _get_replica_id_in_sync_group(self, replica_id):
931    return self._id_in_cluster * len(self.worker_devices) + replica_id
932
933  def _get_local_replica_id(self, replica_id_in_sync_group):
934    return (replica_id_in_sync_group -
935            self._id_in_cluster * len(self.worker_devices))
936