1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=line-too-long
16"""Library for running a computation across multiple devices.
17
18The intent of this library is that you can write an algorithm in a stylized way
19and it will be usable with a variety of different `tf.distribute.Strategy`
20implementations. Each descendant will implement a different strategy for
21distributing the algorithm across multiple devices/machines.  Furthermore, these
22changes can be hidden inside the specific layers and other library classes that
23need special treatment to run in a distributed setting, so that most users'
24model definition code can run unchanged. The `tf.distribute.Strategy` API works
25the same way with eager and graph execution.
26
27*Guides*
28
29* [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training)
30* [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb)
31
32*Tutorials*
33
34* [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/)
35
36  The tutorials cover how to use `tf.distribute.Strategy` to do distributed
37  training with native Keras APIs, custom training loops,
38  and Estimator APIs. They also cover how to save/load model when using
39  `tf.distribute.Strategy`.
40
41*Glossary*
42
43* _Data parallelism_ is where we run multiple copies of the model
44  on different slices of the input data. This is in contrast to
45  _model parallelism_ where we divide up a single copy of a model
46  across multiple devices.
47  Note: we only support data parallelism for now, but
48  hope to add support for model parallelism in the future.
49* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that
50  TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple
51  devices on a single machine, or be connected to devices on multiple
52  machines. Devices used to run computations are called _worker devices_.
53  Devices used to store variables are _parameter devices_. For some strategies,
54  such as `tf.distribute.MirroredStrategy`, the worker and parameter devices
55  will be the same (see mirrored variables below). For others they will be
56  different. For example, `tf.distribute.experimental.CentralStorageStrategy`
57  puts the variables on a single device (which may be a worker device or may be
58  the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the
59  variables on separate machines called _parameter servers_ (see below).
60* A _replica_ is one copy of the model, running on one slice of the
61  input data. Right now each replica is executed on its own
62  worker device, but once we add support for model parallelism
63  a replica may span multiple worker devices.
64* A _host_ is the CPU device on a machine with worker devices, typically
65  used for running input pipelines.
66* A _worker_ is defined to be the physical machine(s) containing the physical
67  devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A
68  worker may contain one or more replicas, but contains at least one
69  replica. Typically one worker will correspond to one machine, but in the case
70  of very large models with model parallelism, one worker may span multiple
71  machines. We typically run one input pipeline per worker, feeding all the
72  replicas on that worker.
73* _Synchronous_, or more commonly _sync_, training is where the updates from
74  each replica are aggregated together before updating the model variables. This
75  is in contrast to _asynchronous_, or _async_ training, where each replica
76  updates the model variables independently. You may also have replicas
77  partitioned into groups which are in sync within each group but async between
78  groups.
79* _Parameter servers_: These are machines that hold a single copy of
80  parameters/variables, used by some strategies (right now just
81  `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want
82  to operate on a variable retrieve it at the beginning of a step and send an
83  update to be applied at the end of the step. These can in principle support
84  either sync or async training, but right now we only have support for async
85  training with parameter servers. Compare to
86  `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables
87  on a single device on the same machine (and does sync training), and
88  `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices
89  (see below).
90
91* _Replica context_ vs. _Cross-replica context_ vs _Update context_
92
93  A _replica context_ applies
94  when you execute the computation function that was called with `strategy.run`.
95  Conceptually, you're in replica context when executing the computation
96  function that is being replicated.
97
98  An _update context_ is entered in a `tf.distribute.StrategyExtended.update`
99  call.
100
101  An _cross-replica context_ is entered when you enter a `strategy.scope`. This
102  is useful for calling `tf.distribute.Strategy` methods which operate across
103  the replicas (like `reduce_to()`). By default you start in a _replica context_
104  (the "default single _replica context_") and then some methods can switch you
105  back and forth.
106
107* _Distributed value_: Distributed value is represented by the base class
108  `tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful
109  to represent values on multiple devices, and it contains a map from replica id
110  to values. Two representative kinds of `tf.distribute.DistributedValues` are
111  "PerReplica" and "Mirrored" values.
112
113  "PerReplica" values exist on the worker
114  devices, with a different value for each replica. They are produced by
115  iterating through a distributed dataset returned by
116  `tf.distribute.Strategy.experimental_distribute_dataset` and
117  `tf.distribute.Strategy.distribute_datasets_from_function`. They
118  are also the typical result returned by
119  `tf.distribute.Strategy.run`.
120
121  "Mirrored" values are like "PerReplica" values, except we know that the value
122  on all replicas are the same. We can safely read a "Mirrored" value in a
123  cross-replica context by using the value on any replica.
124
125* _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple
126  replicas, like `strategy.run(fn, args=[w])` with an
127  argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will
128  have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc.
129  `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on
130  device `d0`, `fn(w1)` on device `d1`, etc.  It then merges the return
131  values from `fn()`, which leads to one common object if the returned values
132  are the same object from every replica, or a `DistributedValues` object
133  otherwise.
134
135* _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating
136  multiple values into one value, like "sum" or "mean". If a strategy is doing
137  sync training, we will perform a reduction on the gradients to a parameter
138  from all replicas before applying the update. _All-reduce_ is an algorithm for
139  performing a reduction on values from multiple devices and making the result
140  available on all of those devices.
141
142* _Mirrored variables_: These are variables that are created on multiple
143  devices, where we keep the variables in sync by applying the same
144  updates to every copy. Mirrored variables are created with
145  `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`.
146  Normally they are only used in synchronous training.
147
148* _SyncOnRead variables_
149
150  _SyncOnRead variables_ are created by
151  `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and
152  they are created on multiple devices. In replica context, each
153  component variable on the local replica can perform reads and writes without
154  synchronization with each other. When the
155  _SyncOnRead variable_ is read in cross-replica context, the values from
156  component variables are aggregated and returned.
157
158  _SyncOnRead variables_ bring a lot of custom configuration difficulty to the
159  underlying logic, so we do not encourage users to instantiate and use
160  _SyncOnRead variable_ on their own. We have mainly used _SyncOnRead
161  variables_ for use cases such as batch norm and metrics. For performance
162  reasons, we often don't need to keep these statistics in sync every step and
163  they can be accumulated on each replica independently. The only time we want
164  to sync them is reporting or checkpointing, which typically happens in
165  cross-replica context. _SyncOnRead variables_ are also often used by advanced
166  users who want to control when variable values are aggregated. For example,
167  users sometimes want to maintain gradients independently on each replica for a
168  couple of steps without aggregation.
169
170* _Distribute-aware layers_
171
172  Layers are generally called in a replica context, except when defining a
173  Keras functional model. `tf.distribute.in_cross_replica_context` will let you
174  determine which case you are in. If in a replica context,
175  the `tf.distribute.get_replica_context` function will return the default
176  replica context outside a strategy scope, `None` within a strategy scope, and
177  a `tf.distribute.ReplicaContext` object inside a strategy scope and within a
178  `tf.distribute.Strategy.run` function. The `ReplicaContext` object has an
179  `all_reduce` method for aggregating across all replicas.
180
181
182Note that we provide a default version of `tf.distribute.Strategy` that is
183used when no other strategy is in scope, that provides the same API with
184reasonable default behavior.
185"""
186# pylint: enable=line-too-long
187
188from __future__ import absolute_import
189from __future__ import division
190from __future__ import print_function
191
192import collections
193import copy
194import enum  # pylint: disable=g-bad-import-order
195import threading
196import weakref
197
198import six
199
200from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
201from tensorflow.python.autograph.impl import api as autograph
202from tensorflow.python.data.ops import dataset_ops
203from tensorflow.python.distribute import collective_util
204from tensorflow.python.distribute import device_util
205from tensorflow.python.distribute import distribution_strategy_context
206from tensorflow.python.distribute import numpy_dataset
207from tensorflow.python.distribute import reduce_util
208from tensorflow.python.eager import context as eager_context
209from tensorflow.python.eager import def_function
210from tensorflow.python.eager import monitoring
211from tensorflow.python.framework import constant_op
212from tensorflow.python.framework import dtypes
213from tensorflow.python.framework import ops
214from tensorflow.python.framework import tensor_shape
215from tensorflow.python.framework import tensor_util
216from tensorflow.python.ops import array_ops
217from tensorflow.python.ops import control_flow_ops
218from tensorflow.python.ops import custom_gradient
219from tensorflow.python.ops import math_ops
220from tensorflow.python.ops import resource_variable_ops
221from tensorflow.python.ops import summary_ops_v2
222from tensorflow.python.ops import variable_scope
223from tensorflow.python.ops.losses import losses_impl
224from tensorflow.python.platform import tf_logging
225from tensorflow.python.training.tracking import base as trackable
226from tensorflow.python.util import deprecation
227from tensorflow.python.util import nest
228from tensorflow.python.util import tf_contextlib
229from tensorflow.python.util.deprecation import deprecated
230from tensorflow.python.util.tf_export import tf_export
231from tensorflow.tools.docs import doc_controls
232
233
234# ------------------------------------------------------------------------------
235# Context tracking whether in a strategy.update() or .update_non_slot() call.
236
237
238_update_replica_id = threading.local()
239
240
241def get_update_replica_id():
242  """Get the current device if in a `tf.distribute.Strategy.update()` call."""
243  try:
244    return _update_replica_id.current
245  except AttributeError:
246    return None
247
248
249class UpdateContext(object):
250  """Context manager when you are in `update()` or `update_non_slot()`."""
251
252  __slots__ = ["_replica_id", "_old_replica_id"]
253
254  def __init__(self, replica_id):
255    self._replica_id = replica_id
256    self._old_replica_id = None
257
258  def __enter__(self):
259    self._old_replica_id = get_update_replica_id()
260    _update_replica_id.current = self._replica_id
261
262  def __exit__(self, exception_type, exception_value, traceback):
263    del exception_type, exception_value, traceback
264    _update_replica_id.current = self._old_replica_id
265
266
267# ------------------------------------------------------------------------------
268# Public utility functions.
269
270
271@tf_export(v1=["distribute.get_loss_reduction"])
272def get_loss_reduction():
273  """`tf.distribute.ReduceOp` corresponding to the last loss reduction.
274
275  This is used to decide whether loss should be scaled in optimizer (used only
276  for estimator + v1 optimizer use case).
277
278  Returns:
279    `tf.distribute.ReduceOp` corresponding to the last loss reduction for
280    estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise.
281  """
282  if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator:  # pylint: disable=protected-access
283    # If we are not in Estimator context then return 'SUM'. We do not need to
284    # scale loss in the optimizer.
285    return reduce_util.ReduceOp.SUM
286  last_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
287  if (last_reduction == losses_impl.Reduction.SUM or
288      last_reduction == "sum"):  # Check for tf.keras.losses.Reduction.SUM
289    return reduce_util.ReduceOp.SUM
290  return reduce_util.ReduceOp.MEAN
291
292
293# ------------------------------------------------------------------------------
294# Internal API for validating the current thread mode
295
296
297def _require_cross_replica_or_default_context_extended(extended,
298                                                       error_message=None):
299  """Verify in cross-replica context."""
300  context = _get_per_thread_mode()
301  cross_replica = context.cross_replica_context
302  if cross_replica is not None and cross_replica.extended is extended:
303    return
304  if context is _get_default_replica_mode():
305    return
306  strategy = extended._container_strategy()  # pylint: disable=protected-access
307  # We have an error to report, figure out the right message.
308  if context.strategy is not strategy:
309    _wrong_strategy_scope(strategy, context)
310  assert cross_replica is None
311  if not error_message:
312    error_message = ("Method requires being in cross-replica context, use "
313                     "get_replica_context().merge_call()")
314  raise RuntimeError(error_message)
315
316
317def _wrong_strategy_scope(strategy, context):
318  # Figure out the right error message.
319  if not distribution_strategy_context.has_strategy():
320    raise RuntimeError(
321        'Need to be inside "with strategy.scope()" for %s' %
322        (strategy,))
323  else:
324    raise RuntimeError(
325        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
326        (context.strategy, strategy))
327
328
329def require_replica_context(replica_ctx):
330  """Verify in `replica_ctx` replica context."""
331  context = _get_per_thread_mode()
332  if context.replica_context is replica_ctx: return
333  # We have an error to report, figure out the right message.
334  if context.replica_context is None:
335    raise RuntimeError("Need to be inside `call_for_each_replica()`")
336  if context.strategy is replica_ctx.strategy:
337    # Two different ReplicaContexts with the same tf.distribute.Strategy.
338    raise RuntimeError("Mismatching ReplicaContext.")
339  raise RuntimeError(
340      "Mismatching tf.distribute.Strategy objects: %s is not %s." %
341      (context.strategy, replica_ctx.strategy))
342
343
344def _require_strategy_scope_strategy(strategy):
345  """Verify in a `strategy.scope()` in this thread."""
346  context = _get_per_thread_mode()
347  if context.strategy is strategy: return
348  _wrong_strategy_scope(strategy, context)
349
350
351def _require_strategy_scope_extended(extended):
352  """Verify in a `distribution_strategy.scope()` in this thread."""
353  context = _get_per_thread_mode()
354  if context.strategy.extended is extended: return
355  # Report error.
356  strategy = extended._container_strategy()  # pylint: disable=protected-access
357  _wrong_strategy_scope(strategy, context)
358
359
360# ------------------------------------------------------------------------------
361# Internal context managers used to implement the DistributionStrategy
362# base class
363
364
365class _CurrentDistributionContext(object):
366  """Context manager setting the current `tf.distribute.Strategy`.
367
368  Also: overrides the variable creator and optionally the current device.
369  """
370
371  def __init__(self,
372               strategy,
373               var_creator_scope,
374               var_scope=None,
375               default_device=None):
376    self._context = distribution_strategy_context._CrossReplicaThreadMode(  # pylint: disable=protected-access
377        strategy)
378    self._var_creator_scope = var_creator_scope
379    self._var_scope = var_scope
380    if default_device:
381      self._device_scope = ops.device(default_device)
382    else:
383      self._device_scope = None
384    self._same_scope_again_count = 0
385
386  def __enter__(self):
387    # Allow this scope to be entered if this strategy is already in scope.
388    if distribution_strategy_context.has_strategy():
389      _require_cross_replica_or_default_context_extended(
390          self._context.strategy.extended)
391      self._same_scope_again_count += 1
392    else:
393      _push_per_thread_mode(self._context)
394      if self._var_scope:
395        self._var_scope.__enter__()
396      self._var_creator_scope.__enter__()
397      if self._device_scope:
398        self._device_scope.__enter__()
399    return self._context.strategy
400
401  def __exit__(self, exception_type, exception_value, traceback):
402    if self._same_scope_again_count > 0:
403      self._same_scope_again_count -= 1
404      return
405    if self._device_scope:
406      try:
407        self._device_scope.__exit__(exception_type, exception_value, traceback)
408      except RuntimeError as e:
409        six.raise_from(
410            RuntimeError("Device scope nesting error: move call to "
411                         "tf.distribute.set_strategy() out of `with` scope."),
412            e)
413
414    try:
415      self._var_creator_scope.__exit__(
416          exception_type, exception_value, traceback)
417    except RuntimeError as e:
418      six.raise_from(
419          RuntimeError("Variable creator scope nesting error: move call to "
420                       "tf.distribute.set_strategy() out of `with` scope."),
421          e)
422
423    if self._var_scope:
424      try:
425        self._var_scope.__exit__(exception_type, exception_value, traceback)
426      except RuntimeError as e:
427        six.raise_from(
428            RuntimeError("Variable scope nesting error: move call to "
429                         "tf.distribute.set_strategy() out of `with` scope."),
430            e)
431    _pop_per_thread_mode()
432
433
434# TODO(yuefengz): add more replication modes.
435@tf_export("distribute.InputReplicationMode")
436class InputReplicationMode(enum.Enum):
437  """Replication mode for input function.
438
439  * `PER_WORKER`: The input function will be called on each worker
440    independently, creating as many input pipelines as number of workers.
441    Replicas will dequeue from the local Dataset on their worker.
442    `tf.distribute.Strategy` doesn't manage any state sharing between such
443    separate input pipelines.
444  * `PER_REPLICA`: The input function will be called on each replica separately.
445    `tf.distribute.Strategy` doesn't manage any state sharing between such
446    separate input pipelines.
447  """
448  PER_WORKER = "PER_WORKER"
449  PER_REPLICA = "PER_REPLICA"
450
451
452@tf_export("distribute.InputContext")
453class InputContext(object):
454  """A class wrapping information needed by an input function.
455
456  This is a context class that is passed to the user's input function and
457  contains information about the compute replicas and input pipelines. The
458  number of compute replicas (in sync training) helps compute the local batch
459  size from the desired global batch size for each replica. The input pipeline
460  information can be used to return a different subset of the input in each
461  replica (for e.g. shard the input pipeline, use a different input
462  source etc).
463  """
464
465  __slots__ = [
466      "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync"
467  ]
468
469  def __init__(self,
470               num_input_pipelines=1,
471               input_pipeline_id=0,
472               num_replicas_in_sync=1):
473    """Initializes an InputContext object.
474
475    Args:
476      num_input_pipelines: the number of input pipelines in a cluster.
477      input_pipeline_id: the current input pipeline id, should be an int in
478        [0,`num_input_pipelines`).
479      num_replicas_in_sync: the number of replicas that are in sync.
480    """
481    self._num_input_pipelines = num_input_pipelines
482    self._input_pipeline_id = input_pipeline_id
483    self._num_replicas_in_sync = num_replicas_in_sync
484
485  @property
486  def num_replicas_in_sync(self):
487    """Returns the number of compute replicas in sync."""
488    return self._num_replicas_in_sync
489
490  @property
491  def input_pipeline_id(self):
492    """Returns the input pipeline ID."""
493    return self._input_pipeline_id
494
495  @property
496  def num_input_pipelines(self):
497    """Returns the number of input pipelines."""
498    return self._num_input_pipelines
499
500  def get_per_replica_batch_size(self, global_batch_size):
501    """Returns the per-replica batch size.
502
503    Args:
504      global_batch_size: the global batch size which should be divisible by
505        `num_replicas_in_sync`.
506
507    Returns:
508      the per-replica batch size.
509
510    Raises:
511      ValueError: if `global_batch_size` not divisible by
512        `num_replicas_in_sync`.
513    """
514    if global_batch_size % self._num_replicas_in_sync != 0:
515      raise ValueError("The `global_batch_size` %r is not divisible by "
516                       "`num_replicas_in_sync` %r " %
517                       (global_batch_size, self._num_replicas_in_sync))
518    return global_batch_size // self._num_replicas_in_sync
519
520  def __str__(self):
521    return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
522        self.input_pipeline_id, self.num_input_pipelines)
523
524
525@tf_export("distribute.experimental.ValueContext", v1=[])
526class ValueContext(object):
527  """A class wrapping information needed by a distribute function.
528
529  This is a context class that is passed to the `value_fn` in
530  `strategy.experimental_distribute_values_from_function` and contains
531  information about the compute replicas. The `num_replicas_in_sync` and
532  `replica_id` can be used to customize the value on each replica.
533
534  Example usage:
535
536  1. Directly constructed.
537
538  >>> def value_fn(context):
539  ...   return context.replica_id_in_sync_group/context.num_replicas_in_sync
540  >>> context = tf.distribute.experimental.ValueContext(
541  ...   replica_id_in_sync_group=2, num_replicas_in_sync=4)
542  >>> per_replica_value = value_fn(context)
543  >>> per_replica_value
544  0.5
545
546  2. Passed in by `experimental_distribute_values_from_function`.
547
548  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
549  >>> def value_fn(value_context):
550  ...   return value_context.num_replicas_in_sync
551  >>> distributed_values = (
552  ...      strategy.experimental_distribute_values_from_function(
553  ...        value_fn))
554  >>> local_result = strategy.experimental_local_results(distributed_values)
555  >>> local_result
556  (2, 2)
557
558  """
559
560  __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"]
561
562  def __init__(self,
563               replica_id_in_sync_group=0,
564               num_replicas_in_sync=1):
565    """Initializes an ValueContext object.
566
567    Args:
568      replica_id_in_sync_group: the current replica_id, should be an int in
569        [0,`num_replicas_in_sync`).
570      num_replicas_in_sync: the number of replicas that are in sync.
571    """
572    self._replica_id_in_sync_group = replica_id_in_sync_group
573    self._num_replicas_in_sync = num_replicas_in_sync
574
575  @property
576  def num_replicas_in_sync(self):
577    """Returns the number of compute replicas in sync."""
578    return self._num_replicas_in_sync
579
580  @property
581  def replica_id_in_sync_group(self):
582    """Returns the replica ID."""
583    return self._replica_id_in_sync_group
584
585  def __str__(self):
586    return (("tf.distribute.ValueContext(replica id {}, "
587             " total replicas in sync: ""{})")
588            .format(self.replica_id_in_sync_group, self.num_replicas_in_sync))
589
590
591@tf_export("distribute.RunOptions")
592class RunOptions(
593    collections.namedtuple("RunOptions", [
594        "experimental_enable_dynamic_batch_size",
595        "experimental_bucketizing_dynamic_shape",
596    ])):
597  """Run options for `strategy.run`.
598
599  This can be used to hold some strategy specific configs.
600
601  Attributes:
602    experimental_enable_dynamic_batch_size: Boolean. Only applies to
603      TPUStrategy. Default to True. If True, TPUStrategy will enable dynamic
604      padder to support dynamic batch size for the inputs. Otherwise only static
605      shape inputs are allowed.
606    experimental_bucketizing_dynamic_shape: Boolean. Only applies to
607      TPUStrategy. Default to False. If True, TPUStrategy will automatic
608      bucketize inputs passed into `run` if the input shape is
609      dynamic. This is a performance optimization to reduce XLA recompilation,
610      which should not have impact on correctness.
611  """
612
613  def __new__(cls,
614              experimental_enable_dynamic_batch_size=True,
615              experimental_bucketizing_dynamic_shape=False):
616    return super(RunOptions,
617                 cls).__new__(cls, experimental_enable_dynamic_batch_size,
618                              experimental_bucketizing_dynamic_shape)
619
620
621@tf_export("distribute.InputOptions", v1=[])
622class InputOptions(
623    collections.namedtuple("InputOptions", [
624        "experimental_prefetch_to_device",
625        "experimental_replication_mode",
626        "experimental_place_dataset_on_device",
627    ])):
628  """Run options for `experimental_distribute_dataset(s_from_function)`.
629
630  This can be used to hold some strategy specific configs.
631
632  ```python
633  # Setup TPUStrategy
634  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
635  tf.config.experimental_connect_to_cluster(resolver)
636  tf.tpu.experimental.initialize_tpu_system(resolver)
637  strategy = tf.distribute.TPUStrategy(resolver)
638
639  dataset = tf.data.Dataset.range(16)
640  distributed_dataset_on_host = (
641      strategy.experimental_distribute_dataset(
642          dataset,
643          tf.distribute.InputOptions(
644              experimental_replication_mode=
645              experimental_replication_mode.PER_WORKER,
646              experimental_place_dataset_on_device=False)))
647  ```
648
649  Attributes:
650    experimental_prefetch_to_device: Boolean. Defaults to True. If True, dataset
651      elements will be prefetched to accelerator device memory. When False,
652      dataset elements are prefetched to host device memory. Must be False when
653      using TPUEmbedding API. experimental_prefetch_to_device can only be used
654      with experimental_replication_mode=PER_WORKER
655    experimental_replication_mode: Replication mode for the input function.
656      Currently, the InputReplicationMode.PER_REPLICA is only supported with
657      tf.distribute.MirroredStrategy.
658      experimental_distribute_datasets_from_function.
659      The default value is InputReplicationMode.PER_WORKER.
660    experimental_place_dataset_on_device: Boolean. Default to False. When True,
661      dataset will be placed on the device, otherwise it will remain on the
662      host. experimental_place_dataset_on_device=True can only be used with
663      experimental_replication_mode=PER_REPLICA
664  """
665
666  def __new__(cls,
667              experimental_prefetch_to_device=True,
668              experimental_replication_mode=InputReplicationMode.PER_WORKER,
669              experimental_place_dataset_on_device=False):
670    return super(InputOptions,
671                 cls).__new__(cls, experimental_prefetch_to_device,
672                              experimental_replication_mode,
673                              experimental_place_dataset_on_device)
674
675# ------------------------------------------------------------------------------
676# Base classes for all distribution strategies.
677
678
679# Base class for v1 Strategy and v2 Strategy classes. For API's specific to
680# v1/v2 Strategy, add to implementing classes of StrategyBase.
681# pylint: disable=line-too-long
682class StrategyBase(object):
683  """A state & compute distribution policy on a list of devices.
684
685  See [the guide](https://www.tensorflow.org/guide/distributed_training)
686  for overview and examples. See `tf.distribute.StrategyExtended` and
687  [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute)
688  for a glossary of concepts mentioned on this page such as "per-replica",
689  _replica_, and _reduce_.
690
691  In short:
692
693  * To use it with Keras `compile`/`fit`,
694    [please
695    read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras).
696  * You may pass descendant of `tf.distribute.Strategy` to
697    `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator`
698    should distribute its computation. See
699    [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support).
700  * Otherwise, use `tf.distribute.Strategy.scope` to specify that a
701    strategy should be used when building an executing your model.
702    (This puts you in the "cross-replica context" for this strategy, which
703    means the strategy is put in control of things like variable placement.)
704  * If you are writing a custom training loop, you will need to call a few more
705    methods,
706    [see the
707    guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops):
708
709      * Start by creating a `tf.data.Dataset` normally.
710      * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert
711        a `tf.data.Dataset` to something that produces "per-replica" values.
712        If you want to manually specify how the dataset should be partitioned
713        across replicas, use
714        `tf.distribute.Strategy.distribute_datasets_from_function`
715        instead.
716      * Use `tf.distribute.Strategy.run` to run a function
717        once per replica, taking values that may be "per-replica" (e.g.
718        from a `tf.distribute.DistributedDataset` object) and returning
719        "per-replica" values.
720        This function is executed in "replica context", which means each
721        operation is performed separately on each replica.
722      * Finally use a method (such as `tf.distribute.Strategy.reduce`) to
723        convert the resulting "per-replica" values into ordinary `Tensor`s.
724
725  A custom training loop can be as simple as:
726
727  ```
728  with my_strategy.scope():
729    @tf.function
730    def distribute_train_epoch(dataset):
731      def replica_fn(input):
732        # process input and return result
733        return result
734
735      total_result = 0
736      for x in dataset:
737        per_replica_result = my_strategy.run(replica_fn, args=(x,))
738        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
739                                           per_replica_result, axis=None)
740      return total_result
741
742    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
743    for _ in range(EPOCHS):
744      train_result = distribute_train_epoch(dist_dataset)
745  ```
746
747  This takes an ordinary `dataset` and `replica_fn` and runs it
748  distributed using a particular `tf.distribute.Strategy` named
749  `my_strategy` above. Any variables created in `replica_fn` are created
750  using `my_strategy`'s policy, and library functions called by
751  `replica_fn` can use the `get_replica_context()` API to implement
752  distributed-specific behavior.
753
754  You can use the `reduce` API to aggregate results across replicas and use
755  this as a return value from one iteration over a
756  `tf.distribute.DistributedDataset`. Or
757  you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to
758  accumulate metrics across steps in a given epoch.
759
760  See the
761  [custom training loop
762  tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training)
763  for a more detailed example.
764
765  Note: `tf.distribute.Strategy` currently does not support TensorFlow's
766  partitioned variables (where a single variable is split across multiple
767  devices) at this time.
768  """
769  # pylint: enable=line-too-long
770
771  # TODO(josh11b): Partitioned computations, state; sharding
772  # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling
773
774  def __init__(self, extended):
775    self._extended = extended
776
777    # Flag that is used to indicate whether distribution strategy is used with
778    # Estimator. This is required for backward compatibility of loss scaling
779    # when using v1 optimizer with estimator.
780    self._scale_loss_for_estimator = False
781
782    if not hasattr(extended, "_retrace_functions_for_each_device"):
783      # pylint: disable=protected-access
784      # `extended._retrace_functions_for_each_device` dictates
785      # whether the same function will be retraced when it is called on
786      # different devices.
787      try:
788        extended._retrace_functions_for_each_device = (
789            len(extended.worker_devices) > 1)
790        distribution_strategy_replica_gauge.get_cell("num_replicas").set(
791            self.num_replicas_in_sync)
792      except:  # pylint: disable=bare-except
793        # Default for the case where extended.worker_devices can't return
794        # a sensible value.
795        extended._retrace_functions_for_each_device = True
796
797    # Below are the dicts of axis(int) -> `tf.function`.
798    self._mean_reduce_helper_fns = {}
799    self._reduce_sum_fns = {}
800
801    # Whether this strategy is designed to work with `ClusterCoordinator`.
802    self._should_use_with_coordinator = False
803
804  @property
805  def extended(self):
806    """`tf.distribute.StrategyExtended` with additional methods."""
807    return self._extended
808
809  @tf_contextlib.contextmanager
810  def _scale_loss_for_estimator_enabled(self):
811    """Scope which sets a flag used for scaling losses in optimizer.
812
813    Yields:
814      `_scale_loss_for_estimator_enabled` is a context manager with a
815      side effect, but doesn't return a value.
816    """
817    self._scale_loss_for_estimator = True
818    try:
819      yield
820    finally:
821      self._scale_loss_for_estimator = False
822
823  # pylint: disable=line-too-long
824  def scope(self):
825    """Context manager to make the strategy current and distribute variables.
826
827    This method returns a context manager, and is used as follows:
828
829    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
830    >>> # Variable created inside scope:
831    >>> with strategy.scope():
832    ...   mirrored_variable = tf.Variable(1.)
833    >>> mirrored_variable
834    MirroredVariable:{
835      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
836      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
837    }
838    >>> # Variable created outside scope:
839    >>> regular_variable = tf.Variable(1.)
840    >>> regular_variable
841    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
842
843    _What happens when Strategy.scope is entered?_
844
845    * `strategy` is installed in the global context as the "current" strategy.
846      Inside this scope, `tf.distribute.get_strategy()` will now return this
847      strategy. Outside this scope, it returns the default no-op strategy.
848    * Entering the scope also enters the "cross-replica context". See
849      `tf.distribute.StrategyExtended` for an explanation on cross-replica and
850      replica contexts.
851    * Variable creation inside `scope` is intercepted by the strategy. Each
852      strategy defines how it wants to affect the variable creation. Sync
853      strategies like `MirroredStrategy`, `TPUStrategy` and
854      `MultiWorkerMiroredStrategy` create variables replicated on each replica,
855      whereas `ParameterServerStrategy` creates variables on the parameter
856      servers. This is done using a custom `tf.variable_creator_scope`.
857    * In some strategies, a default device scope may also be entered: in
858      `MultiWorkerMiroredStrategy`, a default device scope of "/CPU:0" is
859      entered on each worker.
860
861    Note: Entering a scope does not automatically distribute a computation, except
862      in the case of high level training framework like keras `model.fit`. If
863      you're not using `model.fit`, you
864      need to use `strategy.run` API to explicitly distribute that computation.
865      See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).
866
867
868    _What should be in scope and what should be outside?_
869
870    There are a number of requirements on what needs to happen inside the scope.
871    However, in places where we have information about which strategy is in use,
872    we often enter the scope for the user, so they don't have to do it
873    explicitly (i.e. calling those either inside or outside the scope is OK).
874
875    * Anything that creates variables that should be distributed variables
876      must be called in a `strategy.scope`. This can be accomplished either by
877      directly calling the variable creating function within the scope context,
878      or by relying on another API like `strategy.run` or `keras.Model.fit` to
879      automatically enter it for you. Any variable that is created outside scope
880      will not be distributed and may have performance implications. Some common
881      objects that create variables in TF are Models, Optimizers, Metrics. Such
882      objects should always be initiliazized in the scope, and any functions
883      that may lazily create variables (e.g., `Model.__call__()`, tracing a
884      `tf.function`, etc.) should similarly be called within scope. Another
885      source of variable creation can be a checkpoint restore - when variables
886      are created lazily. Note that any variable created inside a strategy
887      captures the strategy information. So reading and writing to these
888      variables outside the `strategy.scope` can also work seamlessly, without
889      the user having to enter the scope.
890    * Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which
891      require to be in a strategy's scope, enter the scope automatically, which
892      means when using those APIs you don't need to explicitly enter the scope
893      yourself.
894    * When a `tf.keras.Model` is created inside a `strategy.scope`, the Model
895      object captures the scope information. When high level training framework
896      methods such as `model.compile`, `model.fit`, etc. are then called, the
897      captured scope will be automatically entered, and the associated strategy
898      will be used to distribute the training etc. See a detailed example in
899      [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras).
900      WARNING: Simply calling `model(..)` does not automatically enter the
901      captured scope -- only high level training framework APIs support this
902      behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict`
903      and `model.save` can all be called inside or outside the scope.
904    * The following can be either inside or outside the scope:
905        * Creating the input datasets
906        * Defining `tf.function`s that represent your training step
907        * Saving APIs such as `tf.saved_model.save`. Loading creates variables,
908          so that should go inside the scope if you want to train the model in a
909          distributed way.
910        * Checkpoint saving. As mentioned above - `checkpoint.restore` may
911          sometimes need to be inside scope if it creates variables.
912
913    Returns:
914      A context manager.
915    """
916    return self._extended._scope(self)  # pylint: disable=protected-access
917  # pylint: enable=line-too-long
918
919  @doc_controls.do_not_doc_inheritable  # DEPRECATED, moving to `extended`
920  def colocate_vars_with(self, colocate_with_variable):
921    """DEPRECATED: use extended.colocate_vars_with() instead."""
922    return self._extended.colocate_vars_with(colocate_with_variable)
923
924  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
925  def make_dataset_iterator(self, dataset):
926    """DEPRECATED TF 1.x ONLY."""
927    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access
928
929  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
930  def make_input_fn_iterator(self,
931                             input_fn,
932                             replication_mode=InputReplicationMode.PER_WORKER):
933    """DEPRECATED TF 1.x ONLY."""
934    if replication_mode != InputReplicationMode.PER_WORKER:
935      raise ValueError(
936          "Input replication mode not supported: %r" % replication_mode)
937    with self.scope():
938      return self.extended._make_input_fn_iterator(  # pylint: disable=protected-access
939          input_fn, replication_mode=replication_mode)
940
941  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
942  def experimental_run(self, fn, input_iterator=None):
943    """DEPRECATED TF 1.x ONLY."""
944    with self.scope():
945      args = (input_iterator.get_next(),) if input_iterator is not None else ()
946    return self.run(fn, args=args)
947
948  def experimental_distribute_dataset(self, dataset, options=None):
949    # pylint: disable=line-too-long
950    """Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`.
951
952    The returned `tf.distribute.DistributedDataset` can be iterated over
953    similar to regular datasets.
954    NOTE: The user cannot add any more transformations to a
955    `tf.distribute.DistributedDataset`. You can only create an iterator or
956    examine the `tf.TypeSpec` of the data generated by it. See API docs of
957    `tf.distribute.DistributedDataset` to learn more.
958
959    The following is an example:
960
961    >>> global_batch_size = 2
962    >>> # Passing the devices is optional.
963    ... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
964    >>> # Create a dataset
965    ... dataset = tf.data.Dataset.range(4).batch(global_batch_size)
966    >>> # Distribute that dataset
967    ... dist_dataset = strategy.experimental_distribute_dataset(dataset)
968    >>> @tf.function
969    ... def replica_fn(input):
970    ...   return input*2
971    >>> result = []
972    >>> # Iterate over the `tf.distribute.DistributedDataset`
973    ... for x in dist_dataset:
974    ...   # process dataset elements
975    ...   result.append(strategy.run(replica_fn, args=(x,)))
976    >>> print(result)
977    [PerReplica:{
978      0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
979      1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
980    }, PerReplica:{
981      0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
982      1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
983    }]
984
985
986    Three key actions happening under the hood of this method are batching,
987    sharding, and prefetching.
988
989    In the code snippet above, `dataset` is batched by `global_batch_size`, and
990    calling `experimental_distribute_dataset` on it rebatches `dataset` to a
991    new batch size that is equal to the global batch size divided by the number
992    of replicas in sync. We iterate through it using a Pythonic for loop.
993    `x` is a `tf.distribute.DistributedValues` containing data for all replicas,
994    and each replica gets data of the new batch size.
995    `tf.distribute.Strategy.run` will take care of feeding the right per-replica
996    data in `x` to the right `replica_fn` executed on each replica.
997
998    Sharding contains autosharding across multiple workers and within every
999    worker. First, in multi-worker distributed training (i.e. when you use
1000    `tf.distribute.experimental.MultiWorkerMirroredStrategy`
1001    or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of
1002    workers means that each worker is assigned a subset of the entire dataset
1003    (if the right `tf.data.experimental.AutoShardPolicy` is set). This is to
1004    ensure that at each step, a global batch size of non-overlapping dataset
1005    elements will be processed by each worker. Autosharding has a couple of
1006    different options that can be specified using
1007    `tf.data.experimental.DistributeOptions`. Then, sharding within each worker
1008    means the method will split the data among all the worker devices (if more
1009    than one a present). This will happen regardless of multi-worker
1010    autosharding.
1011
1012    Note: for autosharding across multiple workers, the default mode is
1013    `tf.data.experimental.AutoShardPolicy.AUTO`. This mode
1014    will attempt to shard the input dataset by files if the dataset is
1015    being created out of reader datasets (e.g. `tf.data.TFRecordDataset`,
1016    `tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data,
1017    where each of the workers will read the entire dataset and only process the
1018    shard assigned to it. However, if you have less than one input file per
1019    worker, we suggest that you disable dataset autosharding across workers by
1020    setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be
1021    `tf.data.experimental.AutoShardPolicy.OFF`.
1022
1023    By default, this method adds a prefetch transformation at the end of the
1024    user provided `tf.data.Dataset` instance. The argument to the prefetch
1025    transformation which is `buffer_size` is equal to the number of replicas in
1026    sync.
1027
1028    If the above batch splitting and dataset sharding logic is undesirable,
1029    please use
1030    `tf.distribute.Strategy.distribute_datasets_from_function`
1031    instead, which does not do any automatic batching or sharding for you.
1032
1033    Note: If you are using TPUStrategy, the order in which the data is processed
1034    by the workers when using
1035    `tf.distribute.Strategy.experimental_distribute_dataset` or
1036    `tf.distribute.Strategy.distribute_datasets_from_function` is
1037    not guaranteed. This is typically required if you are using
1038    `tf.distribute` to scale prediction. You can however insert an index for
1039    each element in the batch and order outputs accordingly. Refer to [this
1040    snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
1041    for an example of how to order outputs.
1042
1043    Note: Stateful dataset transformations are currently not supported with
1044    `tf.distribute.experimental_distribute_dataset` or
1045    `tf.distribute.distribute_datasets_from_function`. Any stateful
1046    ops that the dataset may have are currently ignored. For example, if your
1047    dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
1048    then you have a dataset graph that depends on state (i.e the random seed) on
1049    the local machine where the python process is being executed.
1050
1051    For a tutorial on more usage and properties of this method, refer to the
1052    [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset).
1053    If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
1054
1055    Args:
1056      dataset: `tf.data.Dataset` that will be sharded across all replicas using
1057        the rules stated above.
1058      options: `tf.distribute.InputOptions` used to control options on how this
1059        dataset is distributed.
1060
1061    Returns:
1062      A `tf.distribute.DistributedDataset`.
1063    """
1064    # pylint: enable=line-too-long
1065    return self._extended._experimental_distribute_dataset(dataset, options)  # pylint: disable=protected-access
1066
1067  def distribute_datasets_from_function(self, dataset_fn, options=None):
1068    # pylint: disable=line-too-long
1069    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
1070
1071    The argument `dataset_fn` that users pass in is an input function that has a
1072    `tf.distribute.InputContext` argument and returns a `tf.data.Dataset`
1073    instance. It is expected that the returned dataset from `dataset_fn` is
1074    already batched by per-replica batch size (i.e. global batch size divided by
1075    the number of replicas in sync) and sharded.
1076    `tf.distribute.Strategy.distribute_datasets_from_function` does
1077    not batch or shard the `tf.data.Dataset` instance
1078    returned from the input function. `dataset_fn` will be called on the CPU
1079    device of each of the workers and each generates a dataset where every
1080    replica on that worker will dequeue one batch of inputs (i.e. if a worker
1081    has two replicas, two batches will be dequeued from the `Dataset` every
1082    step).
1083
1084    This method can be used for several purposes. First, it allows you to
1085    specify your own batching and sharding logic. (In contrast,
1086    `tf.distribute.experimental_distribute_dataset` does batching and sharding
1087    for you.) For example, where
1088    `experimental_distribute_dataset` is unable to shard the input files, this
1089    method might be used to manually shard the dataset (avoiding the slow
1090    fallback behavior in `experimental_distribute_dataset`). In cases where the
1091    dataset is infinite, this sharding can be done by creating dataset replicas
1092    that differ only in their random seed.
1093
1094    The `dataset_fn` should take an `tf.distribute.InputContext` instance where
1095    information about batching and input replication can be accessed.
1096
1097    You can use `element_spec` property of the
1098    `tf.distribute.DistributedDataset` returned by this API to query the
1099    `tf.TypeSpec` of the elements returned by the iterator. This can be used to
1100    set the `input_signature` property of a `tf.function`. Follow
1101    `tf.distribute.DistributedDataset.element_spec` to see an example.
1102
1103    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
1104    per-replica batch size, unlike `experimental_distribute_dataset`, which uses
1105    the global batch size. This may be computed using
1106    `input_context.get_per_replica_batch_size`.
1107
1108    Note: If you are using TPUStrategy, the order in which the data is processed
1109    by the workers when using
1110    `tf.distribute.Strategy.experimental_distribute_dataset` or
1111    `tf.distribute.Strategy.distribute_datasets_from_function` is
1112    not guaranteed. This is typically required if you are using
1113    `tf.distribute` to scale prediction. You can however insert an index for
1114    each element in the batch and order outputs accordingly. Refer to [this
1115    snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
1116    for an example of how to order outputs.
1117
1118    Note: Stateful dataset transformations are currently not supported with
1119    `tf.distribute.experimental_distribute_dataset` or
1120    `tf.distribute.distribute_datasets_from_function`. Any stateful
1121    ops that the dataset may have are currently ignored. For example, if your
1122    dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
1123    then you have a dataset graph that depends on state (i.e the random seed) on
1124    the local machine where the python process is being executed.
1125
1126    For a tutorial on more usage and properties of this method, refer to the
1127    [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function)).
1128    If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
1129
1130    Args:
1131      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
1132        returning a `tf.data.Dataset`.
1133      options: `tf.distribute.InputOptions` used to control options on how this
1134        dataset is distributed.
1135
1136    Returns:
1137      A `tf.distribute.DistributedDataset`.
1138    """
1139    # pylint: enable=line-too-long
1140    return self._extended._distribute_datasets_from_function(  # pylint: disable=protected-access
1141        dataset_fn, options)
1142
1143  # TODO(b/162776748): Remove deprecated symbol.
1144  @doc_controls.do_not_doc_inheritable
1145  @deprecation.deprecated(None, "rename to distribute_datasets_from_function")
1146  def experimental_distribute_datasets_from_function(self,
1147                                                     dataset_fn,
1148                                                     options=None):
1149    return self.distribute_datasets_from_function(dataset_fn, options)
1150
1151  def run(self, fn, args=(), kwargs=None, options=None):
1152    """Invokes `fn` on each replica, with the given arguments.
1153
1154    This method is the primary way to distribute your computation with a
1155    tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs`
1156    have `tf.distribute.DistributedValues`, such as those produced by a
1157    `tf.distribute.DistributedDataset` from
1158    `tf.distribute.Strategy.experimental_distribute_dataset` or
1159    `tf.distribute.Strategy.distribute_datasets_from_function`,
1160    when `fn` is executed on a particular replica, it will be executed with the
1161    component of `tf.distribute.DistributedValues` that correspond to that
1162    replica.
1163
1164    `fn` is invoked under a replica context. `fn` may call
1165    `tf.distribute.get_replica_context()` to access members such as
1166    `all_reduce`. Please see the module-level docstring of tf.distribute for the
1167    concept of replica context.
1168
1169    All arguments in `args` or `kwargs` can be a nested structure of tensors,
1170    e.g. a list of tensors, in which case `args` and `kwargs` will be passed to
1171    the `fn` invoked on each replica. Or `args` or `kwargs` can be
1172    `tf.distribute.DistributedValues` containing tensors or composite tensors,
1173    i.e. `tf.compat.v1.TensorInfo.CompositeTensor`, in which case each `fn` call
1174    will get the component of a `tf.distribute.DistributedValues` corresponding
1175    to its replica. Note that arbitrary Python values that are not of the types
1176    above are not supported.
1177
1178    IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
1179    whether eager execution is enabled, `fn` may be called one or more times. If
1180    `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
1181    called inside a `tf.function` (eager execution is disabled inside a
1182    `tf.function` by default), `fn` is called once per replica to generate a
1183    Tensorflow graph, which will then be reused for execution with new inputs.
1184    Otherwise, if eager execution is enabled, `fn` will be called once per
1185    replica every step just like regular python code.
1186
1187    Example usage:
1188
1189    1. Constant tensor input.
1190
1191    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1192    >>> tensor_input = tf.constant(3.0)
1193    >>> @tf.function
1194    ... def replica_fn(input):
1195    ...   return input*2.0
1196    >>> result = strategy.run(replica_fn, args=(tensor_input,))
1197    >>> result
1198    PerReplica:{
1199      0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
1200      1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
1201    }
1202
1203    2. DistributedValues input.
1204
1205    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1206    >>> @tf.function
1207    ... def run():
1208    ...   def value_fn(value_context):
1209    ...     return value_context.num_replicas_in_sync
1210    ...   distributed_values = (
1211    ...     strategy.experimental_distribute_values_from_function(
1212    ...       value_fn))
1213    ...   def replica_fn2(input):
1214    ...     return input*2
1215    ...   return strategy.run(replica_fn2, args=(distributed_values,))
1216    >>> result = run()
1217    >>> result
1218    <tf.Tensor: shape=(), dtype=int32, numpy=4>
1219
1220    3. Use `tf.distribute.ReplicaContext` to allreduce values.
1221
1222    >>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
1223    >>> @tf.function
1224    ... def run():
1225    ...    def value_fn(value_context):
1226    ...      return tf.constant(value_context.replica_id_in_sync_group)
1227    ...    distributed_values = (
1228    ...        strategy.experimental_distribute_values_from_function(
1229    ...            value_fn))
1230    ...    def replica_fn(input):
1231    ...      return tf.distribute.get_replica_context().all_reduce("sum", input)
1232    ...    return strategy.run(replica_fn, args=(distributed_values,))
1233    >>> result = run()
1234    >>> result
1235    PerReplica:{
1236      0: <tf.Tensor: shape=(), dtype=int32, numpy=1>,
1237      1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
1238    }
1239
1240    Args:
1241      fn: The function to run on each replica.
1242      args: Optional positional arguments to `fn`. Its element can be a tensor,
1243        a nested structure of tensors or a `tf.distribute.DistributedValues`.
1244      kwargs: Optional keyword arguments to `fn`. Its element can be a tensor,
1245        a nested structure of tensors or a `tf.distribute.DistributedValues`.
1246      options: An optional instance of `tf.distribute.RunOptions` specifying
1247        the options to run `fn`.
1248
1249    Returns:
1250      Merged return value of `fn` across replicas. The structure of the return
1251      value is the same as the return value from `fn`. Each element in the
1252      structure can either be `tf.distribute.DistributedValues`, `Tensor`
1253      objects, or `Tensor`s (for example, if running on a single replica).
1254    """
1255    del options
1256
1257    if not isinstance(args, (list, tuple)):
1258      raise ValueError(
1259          "positional args must be a list or tuple, got {}".format(type(args)))
1260
1261    with self.scope():
1262      # tf.distribute supports Eager functions, so AutoGraph should not be
1263      # applied when the caller is also in Eager mode.
1264      fn = autograph.tf_convert(
1265          fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
1266      return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
1267
1268  def reduce(self, reduce_op, value, axis):
1269    """Reduce `value` across replicas and return result on current device.
1270
1271    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1272    >>> def step_fn():
1273    ...   i = tf.distribute.get_replica_context().replica_id_in_sync_group
1274    ...   return tf.identity(i)
1275    >>>
1276    >>> per_replica_result = strategy.run(step_fn)
1277    >>> total = strategy.reduce("SUM", per_replica_result, axis=None)
1278    >>> total
1279    <tf.Tensor: shape=(), dtype=int32, numpy=1>
1280
1281    To see how this would look with multiple replicas, consider the same
1282    example with MirroredStrategy with 2 GPUs:
1283
1284    ```python
1285    strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
1286    def step_fn():
1287      i = tf.distribute.get_replica_context().replica_id_in_sync_group
1288      return tf.identity(i)
1289
1290    per_replica_result = strategy.run(step_fn)
1291    # Check devices on which per replica result is:
1292    strategy.experimental_local_results(per_replica_result)[0].device
1293    # /job:localhost/replica:0/task:0/device:GPU:0
1294    strategy.experimental_local_results(per_replica_result)[1].device
1295    # /job:localhost/replica:0/task:0/device:GPU:1
1296
1297    total = strategy.reduce("SUM", per_replica_result, axis=None)
1298    # Check device on which reduced result is:
1299    total.device
1300    # /job:localhost/replica:0/task:0/device:CPU:0
1301
1302    ```
1303
1304    This API is typically used for aggregating the results returned from
1305    different replicas, for reporting etc. For example, loss computed from
1306    different replicas can be averaged using this API before printing.
1307
1308    Note: The result is copied to the "current" device - which would typically
1309    be the CPU of the worker on which the program is running. For `TPUStrategy`,
1310    it is the first TPU host. For multi client `MultiWorkerMirroredStrategy`,
1311    this is CPU of each worker.
1312
1313    There are a number of different tf.distribute APIs for reducing values
1314    across replicas:
1315    * `tf.distribute.ReplicaContext.all_reduce`: This differs from
1316    `Strategy.reduce` in that it is for replica context and does
1317    not copy the results to the host device. `all_reduce` should be typically
1318    used for reductions inside the training step such as gradients.
1319    * `tf.distribute.StrategyExtended.reduce_to` and
1320    `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more
1321    advanced versions of `Strategy.reduce` as they allow customizing the
1322    destination of the result. They are also called in cross replica context.
1323
1324    _What should axis be?_
1325
1326    Given a per-replica value returned by `run`, say a
1327    per-example loss, the batch will be divided across all the replicas.  This
1328    function allows you to aggregate across replicas and optionally also across
1329    batch elements by specifying the axis parameter accordingly.
1330
1331    For example, if you have a global batch size of 8 and 2
1332    replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and
1333    `[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will
1334    aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`.
1335    This is useful when each replica is computing a scalar or some other value
1336    that doesn't have a "batch" dimension (like a gradient or loss).
1337    ```
1338    strategy.reduce("sum", per_replica_result, axis=None)
1339    ```
1340
1341    Sometimes, you will want to aggregate across both the global batch _and_
1342    all replicas. You can get this behavior by specifying the batch
1343    dimension as the `axis`, typically `axis=0`. In this case it would return a
1344    scalar `0+1+2+3+4+5+6+7`.
1345    ```
1346    strategy.reduce("sum", per_replica_result, axis=0)
1347    ```
1348
1349    If there is a last partial batch, you will need to specify an axis so
1350    that the resulting shape is consistent across replicas. So if the last
1351    batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you
1352    would get a shape mismatch unless you specify `axis=0`. If you specify
1353    `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
1354    denominator of 6. Contrast this with computing `reduce_mean` to get a
1355    scalar value on each replica and this function to average those means,
1356    which will weigh some values `1/8` and others `1/4`.
1357
1358    Args:
1359      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
1360        be combined. Allows using string representation of the enum such as
1361        "SUM", "MEAN".
1362      value: a `tf.distribute.DistributedValues` instance, e.g. returned by
1363        `Strategy.run`, to be combined into a single tensor. It can also be a
1364        regular tensor when used with `OneDeviceStrategy` or default strategy.
1365      axis: specifies the dimension to reduce along within each
1366        replica's tensor. Should typically be set to the batch dimension, or
1367        `None` to only reduce across replicas (e.g. if the tensor has no batch
1368        dimension).
1369
1370    Returns:
1371      A `Tensor`.
1372    """
1373    # TODO(josh11b): support `value` being a nest.
1374    _require_cross_replica_or_default_context_extended(self._extended)
1375    if isinstance(reduce_op, six.string_types):
1376      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
1377    if axis is None:
1378      return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
1379    if reduce_op == reduce_util.ReduceOp.SUM:
1380
1381      def reduce_sum(v):
1382        return math_ops.reduce_sum(v, axis=axis)
1383
1384      if eager_context.executing_eagerly():
1385        # As some strategies (e.g. TPUStrategy) doesn't support pure eager
1386        # execution, wrap the `reduce_sum_fn` with a `tf.function` so it can be
1387        # run from eager mode. Cache the tf.function by `axis` to avoid the
1388        # same function to be traced again.
1389        if axis not in self._reduce_sum_fns:
1390
1391          def reduce_sum_fn(v):
1392            return self.run(reduce_sum, args=(v,))
1393
1394          self._reduce_sum_fns[axis] = def_function.function(reduce_sum_fn)
1395        value = self._reduce_sum_fns[axis](value)
1396      else:
1397        value = self.run(reduce_sum, args=(value,))
1398
1399      return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
1400    if reduce_op != reduce_util.ReduceOp.MEAN:
1401      raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, "
1402                      "not: %r" % reduce_op)
1403    # TODO(josh11b): Support list/tuple and tensor axis values.
1404    if not isinstance(axis, six.integer_types):
1405      raise TypeError("Expected `axis` to be an integer not: %r" % axis)
1406
1407    def mean_reduce_helper(v, axis=axis):
1408      """Computes the numerator and denominator on each replica."""
1409      numer = math_ops.reduce_sum(v, axis=axis)
1410      if v.shape.rank is not None:
1411        # Note(joshl): We support axis < 0 to be consistent with the
1412        # tf.math.reduce_* operations.
1413        if axis < 0:
1414          if axis + v.shape.rank < 0:
1415            raise ValueError(
1416                "`axis` = %r out of range for `value` with rank %d" %
1417                (axis, v.shape.rank))
1418          axis += v.shape.rank
1419        elif axis >= v.shape.rank:
1420          raise ValueError(
1421              "`axis` = %r out of range for `value` with rank %d" %
1422              (axis, v.shape.rank))
1423        # TF v2 returns `None` for unknown dimensions and an integer for
1424        # known dimension, whereas TF v1 returns tensor_shape.Dimension(None)
1425        # or tensor_shape.Dimension(integer). `dimension_value` hides this
1426        # difference, always returning `None` or an integer.
1427        dim = tensor_shape.dimension_value(v.shape[axis])
1428        if dim is not None:
1429          # By returning a python value in the static shape case, we can
1430          # maybe get a fast path for reducing the denominator.
1431          # TODO(b/151871486): Remove array_ops.identity after we fallback to
1432          # simple reduction if inputs are all on CPU.
1433          return numer, array_ops.identity(
1434              constant_op.constant(dim, dtype=dtypes.int64))
1435      elif axis < 0:
1436        axis = axis + array_ops.rank(v)
1437      # TODO(b/151871486): Remove array_ops.identity after we fallback to simple
1438      # reduction if inputs are all on CPU.
1439      denom = array_ops.identity(
1440          array_ops.shape_v2(v, out_type=dtypes.int64)[axis])
1441      # TODO(josh11b): Should we cast denom to v.dtype here instead of after the
1442      # reduce is complete?
1443      return numer, denom
1444
1445    if eager_context.executing_eagerly():
1446      # As some strategies (e.g. TPUStrategy) doesn't support pure eager
1447      # execution, wrap the `mean_reduce_helper` with a `tf.function` so it can
1448      # be run from eager mode. Cache the tf.function by `axis` to avoid the
1449      # same function to be traced again.
1450      if axis not in self._mean_reduce_helper_fns:
1451
1452        def mean_reduce_fn(v):
1453          return self.run(mean_reduce_helper, args=(v,))
1454
1455        self._mean_reduce_helper_fns[axis] = def_function.function(
1456            mean_reduce_fn)
1457      numer, denom = self._mean_reduce_helper_fns[axis](value)
1458    else:
1459      numer, denom = self.run(mean_reduce_helper, args=(value,))
1460
1461    # TODO(josh11b): Should batch reduce here instead of doing two.
1462    numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer)  # pylint: disable=protected-access
1463    denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom)  # pylint: disable=protected-access
1464    denom = math_ops.cast(denom, numer.dtype)
1465    return math_ops.truediv(numer, denom)
1466
1467  @doc_controls.do_not_doc_inheritable  # DEPRECATED
1468  def unwrap(self, value):
1469    """Returns the list of all local per-replica values contained in `value`.
1470
1471    DEPRECATED: Please use `experimental_local_results` instead.
1472
1473    Note: This only returns values on the workers initiated by this client.
1474    When using a `tf.distribute.Strategy` like
1475    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
1476    will be its own client, and this function will only return values
1477    computed on that worker.
1478
1479    Args:
1480      value: A value returned by `experimental_run()`,
1481        `extended.call_for_each_replica()`, or a variable created in `scope`.
1482
1483    Returns:
1484      A tuple of values contained in `value`. If `value` represents a single
1485      value, this returns `(value,).`
1486    """
1487    return self._extended._local_results(value)  # pylint: disable=protected-access
1488
1489  def experimental_local_results(self, value):
1490    """Returns the list of all local per-replica values contained in `value`.
1491
1492    Note: This only returns values on the worker initiated by this client.
1493    When using a `tf.distribute.Strategy` like
1494    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
1495    will be its own client, and this function will only return values
1496    computed on that worker.
1497
1498    Args:
1499      value: A value returned by `experimental_run()`, `run()`,
1500        `extended.call_for_each_replica()`, or a variable created in `scope`.
1501
1502    Returns:
1503      A tuple of values contained in `value`. If `value` represents a single
1504      value, this returns `(value,).`
1505    """
1506    return self._extended._local_results(value)  # pylint: disable=protected-access
1507
1508  @doc_controls.do_not_doc_inheritable  # DEPRECATED: TF v1.x only
1509  def group(self, value, name=None):
1510    """Shortcut for `tf.group(self.experimental_local_results(value))`."""
1511    return self._extended._group(value, name)  # pylint: disable=protected-access
1512
1513  @property
1514  def num_replicas_in_sync(self):
1515    """Returns number of replicas over which gradients are aggregated."""
1516    return self._extended._num_replicas_in_sync  # pylint: disable=protected-access
1517
1518  @doc_controls.do_not_doc_inheritable  # DEPRECATED: see doc string
1519  def configure(self,
1520                session_config=None,
1521                cluster_spec=None,
1522                task_type=None,
1523                task_id=None):
1524    # pylint: disable=g-doc-return-or-yield,g-doc-args
1525    """DEPRECATED: use `update_config_proto` instead.
1526
1527    Configures the strategy class.
1528
1529    DEPRECATED: This method's functionality has been split into the strategy
1530    constructor and `update_config_proto`. In the future, we will allow passing
1531    cluster and config_proto to the constructor to configure the strategy. And
1532    `update_config_proto` can be used to update the config_proto based on the
1533    specific strategy.
1534    """
1535    return self._extended._configure(  # pylint: disable=protected-access
1536        session_config, cluster_spec, task_type, task_id)
1537
1538  @doc_controls.do_not_generate_docs  # DEPRECATED
1539  def update_config_proto(self, config_proto):
1540    """DEPRECATED TF 1.x ONLY."""
1541    return self._extended._update_config_proto(config_proto)  # pylint: disable=protected-access
1542
1543  def __deepcopy__(self, memo):
1544    # First do a regular deepcopy of `self`.
1545    cls = self.__class__
1546    result = cls.__new__(cls)
1547    memo[id(self)] = result
1548    for k, v in self.__dict__.items():
1549      setattr(result, k, copy.deepcopy(v, memo))
1550    # One little fix-up: we want `result._extended` to reference `result`
1551    # instead of `self`.
1552    result._extended._container_strategy_weakref = weakref.ref(result)  # pylint: disable=protected-access
1553    return result
1554
1555  def __copy__(self):
1556    raise RuntimeError("Must only deepcopy DistributionStrategy.")
1557
1558  @property
1559  def cluster_resolver(self):
1560    """Returns the cluster resolver associated with this strategy.
1561
1562    In general, when using a multi-worker `tf.distribute` strategy such as
1563    `tf.distribute.experimental.MultiWorkerMirroredStrategy` or
1564    `tf.distribute.TPUStrategy()`, there is a
1565    `tf.distribute.cluster_resolver.ClusterResolver` associated with the
1566    strategy used, and such an instance is returned by this property.
1567
1568    Strategies that intend to have an associated
1569    `tf.distribute.cluster_resolver.ClusterResolver` must set the
1570    relevant attribute, or override this property; otherwise, `None` is returned
1571    by default. Those strategies should also provide information regarding what
1572    is returned by this property.
1573
1574    Single-worker strategies usually do not have a
1575    `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this
1576    property will return `None`.
1577
1578    The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the
1579    user needs to access information such as the cluster spec, task type or task
1580    id. For example,
1581
1582    ```python
1583
1584    os.environ['TF_CONFIG'] = json.dumps({
1585      'cluster': {
1586          'worker': ["localhost:12345", "localhost:23456"],
1587          'ps': ["localhost:34567"]
1588      },
1589      'task': {'type': 'worker', 'index': 0}
1590    })
1591
1592    # This implicitly uses TF_CONFIG for the cluster and current task info.
1593    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
1594
1595    ...
1596
1597    if strategy.cluster_resolver.task_type == 'worker':
1598      # Perform something that's only applicable on workers. Since we set this
1599      # as a worker above, this block will run on this particular instance.
1600    elif strategy.cluster_resolver.task_type == 'ps':
1601      # Perform something that's only applicable on parameter servers. Since we
1602      # set this as a worker above, this block will not run on this particular
1603      # instance.
1604    ```
1605
1606    For more information, please see
1607    `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring.
1608
1609    Returns:
1610      The cluster resolver associated with this strategy. Returns `None` if a
1611      cluster resolver is not applicable or available in this strategy.
1612    """
1613    if hasattr(self.extended, "_cluster_resolver"):
1614      return self.extended._cluster_resolver  # pylint: disable=protected-access
1615    return None
1616
1617
1618@tf_export("distribute.Strategy", v1=[])  # pylint: disable=g-missing-docstring
1619class Strategy(StrategyBase):
1620
1621  __doc__ = StrategyBase.__doc__
1622
1623  def experimental_distribute_values_from_function(self, value_fn):
1624    """Generates `tf.distribute.DistributedValues` from `value_fn`.
1625
1626    This function is to generate `tf.distribute.DistributedValues` to pass
1627    into `run`, `reduce`, or other methods that take
1628    distributed values when not using datasets.
1629
1630    Args:
1631      value_fn: The function to run to generate values. It is called for
1632        each replica with `tf.distribute.ValueContext` as the sole argument. It
1633        must return a Tensor or a type that can be converted to a Tensor.
1634    Returns:
1635      A `tf.distribute.DistributedValues` containing a value for each replica.
1636
1637    Example usage:
1638
1639    1. Return constant value per replica:
1640
1641    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1642    >>> def value_fn(ctx):
1643    ...   return tf.constant(1.)
1644    >>> distributed_values = (
1645    ...      strategy.experimental_distribute_values_from_function(
1646    ...        value_fn))
1647    >>> local_result = strategy.experimental_local_results(distributed_values)
1648    >>> local_result
1649    (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
1650     <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
1651
1652    2. Distribute values in array based on replica_id:
1653
1654    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1655    >>> array_value = np.array([3., 2., 1.])
1656    >>> def value_fn(ctx):
1657    ...   return array_value[ctx.replica_id_in_sync_group]
1658    >>> distributed_values = (
1659    ...      strategy.experimental_distribute_values_from_function(
1660    ...        value_fn))
1661    >>> local_result = strategy.experimental_local_results(distributed_values)
1662    >>> local_result
1663    (3.0, 2.0)
1664
1665    3. Specify values using num_replicas_in_sync:
1666
1667    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1668    >>> def value_fn(ctx):
1669    ...   return ctx.num_replicas_in_sync
1670    >>> distributed_values = (
1671    ...      strategy.experimental_distribute_values_from_function(
1672    ...        value_fn))
1673    >>> local_result = strategy.experimental_local_results(distributed_values)
1674    >>> local_result
1675    (2, 2)
1676
1677    4. Place values on devices and distribute:
1678
1679    ```
1680    strategy = tf.distribute.TPUStrategy()
1681    worker_devices = strategy.extended.worker_devices
1682    multiple_values = []
1683    for i in range(strategy.num_replicas_in_sync):
1684      with tf.device(worker_devices[i]):
1685        multiple_values.append(tf.constant(1.0))
1686
1687    def value_fn(ctx):
1688      return multiple_values[ctx.replica_id_in_sync_group]
1689
1690    distributed_values = strategy.
1691      experimental_distribute_values_from_function(
1692      value_fn)
1693    ```
1694
1695    """
1696    return self._extended._experimental_distribute_values_from_function(  # pylint: disable=protected-access
1697        value_fn)
1698
1699  def gather(self, value, axis):
1700    # pylint: disable=line-too-long, protected-access
1701    """Gather `value` across replicas along `axis` to the current device.
1702
1703    Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
1704    object `value`, this API gathers and concatenates `value` across replicas
1705    along the `axis`-th dimension. The result is copied to the "current" device,
1706    which would typically be the CPU of the worker on which the program is
1707    running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For
1708    multi-client `tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of
1709    each worker.
1710
1711    This API can only be called in the cross-replica context. For a counterpart
1712    in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
1713
1714    Note: For all strategies except `tf.distribute.TPUStrategy`, the input
1715    `value` on different replicas must have the same rank, and their shapes must
1716    be the same in all dimensions except the `axis`-th dimension. In other
1717    words, their shapes cannot be different in a dimension `d` where `d` does
1718    not equal to the `axis` argument. For example, given a
1719    `tf.distribute.DistributedValues` with component tensors of shape
1720    `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
1721    `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
1722    `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`,
1723    all tensors must have exactly the same rank and same shape.
1724
1725    Note: Given a `tf.distribute.DistributedValues` `value`, its component
1726    tensors must have a non-zero rank. Otherwise, consider using
1727    `tf.expand_dims` before gathering them.
1728
1729    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1730    >>> # A DistributedValues with component tensor of shape (2, 1) on each replica
1731    ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]])))
1732    >>> @tf.function
1733    ... def run():
1734    ...   return strategy.gather(distributed_values, axis=0)
1735    >>> run()
1736    <tf.Tensor: shape=(4, 1), dtype=int32, numpy=
1737    array([[1],
1738           [2],
1739           [1],
1740           [2]], dtype=int32)>
1741
1742
1743    Consider the following example for more combinations:
1744
1745    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
1746    >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
1747    >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor))
1748    >>> @tf.function
1749    ... def run(axis):
1750    ...   return strategy.gather(distributed_values, axis=axis)
1751    >>> axis=0
1752    >>> run(axis)
1753    <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
1754    array([[[0, 1, 2],
1755            [3, 4, 5]],
1756           [[0, 1, 2],
1757            [3, 4, 5]],
1758           [[0, 1, 2],
1759            [3, 4, 5]],
1760           [[0, 1, 2],
1761            [3, 4, 5]]], dtype=int32)>
1762    >>> axis=1
1763    >>> run(axis)
1764    <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
1765    array([[[0, 1, 2],
1766            [3, 4, 5],
1767            [0, 1, 2],
1768            [3, 4, 5],
1769            [0, 1, 2],
1770            [3, 4, 5],
1771            [0, 1, 2],
1772            [3, 4, 5]]], dtype=int32)>
1773    >>> axis=2
1774    >>> run(axis)
1775    <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
1776    array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
1777            [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
1778
1779
1780    Args:
1781      value: a `tf.distribute.DistributedValues` instance, e.g. returned by
1782        `Strategy.run`, to be combined into a single tensor. It can also be a
1783        regular tensor when used with `tf.distribute.OneDeviceStrategy` or the
1784        default strategy. The tensors that constitute the DistributedValues
1785        can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`.
1786      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
1787        range [0, rank(value)).
1788
1789    Returns:
1790       A `Tensor` that's the concatenation of `value` across replicas along
1791       `axis` dimension.
1792    """
1793    # pylint: enable=line-too-long
1794    error_message = ("tf.distribute.Strategy.gather method requires "
1795                     "cross-replica context, use "
1796                     "get_replica_context().all_gather() instead")
1797    _require_cross_replica_or_default_context_extended(self._extended,
1798                                                       error_message)
1799    dst = device_util.current(
1800    ) or self._extended._default_device or "/device:CPU:0"
1801    if isinstance(value, ops.IndexedSlices):
1802      raise NotImplementedError("gather does not support IndexedSlices")
1803    return self._extended._local_results(
1804        self._extended._gather_to(value, dst, axis))[0]
1805
1806
1807# TF v1.x version has additional deprecated APIs
1808@tf_export(v1=["distribute.Strategy"])
1809class StrategyV1(StrategyBase):
1810  """A list of devices with a state & compute distribution policy.
1811
1812  See [the guide](https://www.tensorflow.org/guide/distribute_strategy)
1813  for overview and examples.
1814
1815  Note: Not all `tf.distribute.Strategy` implementations currently support
1816  TensorFlow's partitioned variables (where a single variable is split across
1817  multiple devices) at this time.
1818  """
1819
1820  def make_dataset_iterator(self, dataset):
1821    """Makes an iterator for input provided via `dataset`.
1822
1823    DEPRECATED: This method is not available in TF 2.x.
1824
1825    Data from the given dataset will be distributed evenly across all the
1826    compute replicas. We will assume that the input dataset is batched by the
1827    global batch size. With this assumption, we will make a best effort to
1828    divide each batch across all the replicas (one or more workers).
1829    If this effort fails, an error will be thrown, and the user should instead
1830    use `make_input_fn_iterator` which provides more control to the user, and
1831    does not try to divide a batch across replicas.
1832
1833    The user could also use `make_input_fn_iterator` if they want to
1834    customize which input is fed to which replica/worker etc.
1835
1836    Args:
1837      dataset: `tf.data.Dataset` that will be distributed evenly across all
1838        replicas.
1839
1840    Returns:
1841      An `tf.distribute.InputIterator` which returns inputs for each step of the
1842      computation.  User should call `initialize` on the returned iterator.
1843    """
1844    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access
1845
1846  def make_input_fn_iterator(self,  # pylint: disable=useless-super-delegation
1847                             input_fn,
1848                             replication_mode=InputReplicationMode.PER_WORKER):
1849    """Returns an iterator split across replicas created from an input function.
1850
1851    DEPRECATED: This method is not available in TF 2.x.
1852
1853    The `input_fn` should take an `tf.distribute.InputContext` object where
1854    information about batching and input sharding can be accessed:
1855
1856    ```
1857    def input_fn(input_context):
1858      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
1859      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
1860      return d.shard(input_context.num_input_pipelines,
1861                     input_context.input_pipeline_id)
1862    with strategy.scope():
1863      iterator = strategy.make_input_fn_iterator(input_fn)
1864      replica_results = strategy.experimental_run(replica_fn, iterator)
1865    ```
1866
1867    The `tf.data.Dataset` returned by `input_fn` should have a per-replica
1868    batch size, which may be computed using
1869    `input_context.get_per_replica_batch_size`.
1870
1871    Args:
1872      input_fn: A function taking a `tf.distribute.InputContext` object and
1873        returning a `tf.data.Dataset`.
1874      replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
1875        Only `PER_WORKER` is supported currently, which means there will be
1876        a single call to `input_fn` per worker. Replicas will dequeue from the
1877        local `tf.data.Dataset` on their worker.
1878
1879    Returns:
1880      An iterator object that should first be `.initialize()`-ed. It may then
1881      either be passed to `strategy.experimental_run()` or you can
1882      `iterator.get_next()` to get the next value to pass to
1883      `strategy.extended.call_for_each_replica()`.
1884    """
1885    return super(StrategyV1, self).make_input_fn_iterator(
1886        input_fn, replication_mode)
1887
1888  def experimental_make_numpy_dataset(self, numpy_input, session=None):
1889    """Makes a tf.data.Dataset for input provided via a numpy array.
1890
1891    This avoids adding `numpy_input` as a large constant in the graph,
1892    and copies the data to the machine or machines that will be processing
1893    the input.
1894
1895    Note that you will likely need to use
1896    tf.distribute.Strategy.experimental_distribute_dataset
1897    with the returned dataset to further distribute it with the strategy.
1898
1899    Example:
1900    ```
1901    numpy_input = np.ones([10], dtype=np.float32)
1902    dataset = strategy.experimental_make_numpy_dataset(numpy_input)
1903    dist_dataset = strategy.experimental_distribute_dataset(dataset)
1904    ```
1905
1906    Args:
1907      numpy_input: A nest of NumPy input arrays that will be converted into a
1908      dataset. Note that lists of Numpy arrays are stacked, as that is normal
1909      `tf.data.Dataset` behavior.
1910      session: (TensorFlow v1.x graph execution only) A session used for
1911        initialization.
1912
1913    Returns:
1914      A `tf.data.Dataset` representing `numpy_input`.
1915    """
1916    return self.extended.experimental_make_numpy_dataset(
1917        numpy_input, session=session)
1918
1919  def experimental_run(self, fn, input_iterator=None):  # pylint: disable=useless-super-delegation
1920    """Runs ops in `fn` on each replica, with inputs from `input_iterator`.
1921
1922    DEPRECATED: This method is not available in TF 2.x. Please switch
1923    to using `run` instead.
1924
1925    When eager execution is enabled, executes ops specified by `fn` on each
1926    replica. Otherwise, builds a graph to execute the ops on each replica.
1927
1928    Each replica will take a single, different input from the inputs provided by
1929    one `get_next` call on the input iterator.
1930
1931    `fn` may call `tf.distribute.get_replica_context()` to access members such
1932    as `replica_id_in_sync_group`.
1933
1934    IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
1935    used, and whether eager execution is enabled, `fn` may be called one or more
1936    times (once for each replica).
1937
1938    Args:
1939      fn: The function to run. The inputs to the function must match the outputs
1940        of `input_iterator.get_next()`. The output must be a `tf.nest` of
1941        `Tensor`s.
1942      input_iterator: (Optional) input iterator from which the inputs are taken.
1943
1944    Returns:
1945      Merged return value of `fn` across replicas. The structure of the return
1946      value is the same as the return value from `fn`. Each element in the
1947      structure can either be `PerReplica` (if the values are unsynchronized),
1948      `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
1949      single replica).
1950    """
1951    return super(StrategyV1, self).experimental_run(
1952        fn, input_iterator)
1953
1954  def reduce(self, reduce_op, value, axis=None):
1955    return super(StrategyV1, self).reduce(reduce_op, value, axis)
1956
1957  reduce.__doc__ = StrategyBase.reduce.__doc__
1958
1959  def update_config_proto(self, config_proto):
1960    """Returns a copy of `config_proto` modified for use with this strategy.
1961
1962    DEPRECATED: This method is not available in TF 2.x.
1963
1964    The updated config has something needed to run a strategy, e.g.
1965    configuration to run collective ops, or device filters to improve
1966    distributed training performance.
1967
1968    Args:
1969      config_proto: a `tf.ConfigProto` object.
1970
1971    Returns:
1972      The updated copy of the `config_proto`.
1973    """
1974    return self._extended._update_config_proto(config_proto)  # pylint: disable=protected-access
1975
1976
1977# NOTE(josh11b): For any strategy that needs to support tf.compat.v1,
1978# instead descend from StrategyExtendedV1.
1979@tf_export("distribute.StrategyExtended", v1=[])
1980class StrategyExtendedV2(object):
1981  """Additional APIs for algorithms that need to be distribution-aware.
1982
1983  Note: For most usage of `tf.distribute.Strategy`, there should be no need to
1984  call these methods, since TensorFlow libraries (such as optimizers) already
1985  call these methods when needed on your behalf.
1986
1987
1988  Some common use cases of functions on this page:
1989
1990  * _Locality_
1991
1992  `tf.distribute.DistributedValues` can have the same _locality_ as a
1993  _distributed variable_, which leads to a mirrored value residing on the same
1994  devices as the variable (as opposed to the compute devices). Such values may
1995  be passed to a call to `tf.distribute.StrategyExtended.update` to update the
1996  value of a variable. You may use
1997  `tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the
1998  same locality as another variable. You may convert a "PerReplica" value to a
1999  variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or
2000  `tf.distribute.StrategyExtended.batch_reduce_to`.
2001
2002  * _How to update a distributed variable_
2003
2004  A distributed variable is variables created on multiple devices. As discussed
2005  in the [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute),
2006  mirrored variable and SyncOnRead variable are two examples. The standard
2007  pattern for updating distributed variables is to:
2008
2009  1. In your function passed to `tf.distribute.Strategy.run`,
2010     compute a list of (update, variable) pairs. For example, the update might
2011     be a gradient of the loss with respect to the variable.
2012  2. Switch to cross-replica mode by calling
2013     `tf.distribute.get_replica_context().merge_call()` with the updates and
2014     variables as arguments.
2015  3. Call
2016     `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)`
2017     (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to`
2018     (for a list of variables) to sum the updates.
2019  4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update
2020     its value.
2021
2022  Steps 2 through 4 are done automatically by class
2023  `tf.keras.optimizers.Optimizer` if you call its
2024  `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context.
2025
2026  In fact, a higher-level solution to update a distributed variable is by
2027  calling `assign` on the variable as you would do to a regular `tf.Variable`.
2028  You can call the method in both _replica context_ and _cross-replica context_.
2029  For a _mirrored variable_, calling `assign` in _replica context_ requires you
2030  to specify the `aggregation` type in the variable constructor. In that case,
2031  the context switching and sync described in steps 2 through 4 are handled for
2032  you. If you call `assign` on _mirrored variable_ in _cross-replica context_,
2033  you can only assign a single value or assign values from another mirrored
2034  variable or a mirrored `tf.distribute.DistributedValues`. For a _SyncOnRead
2035  variable_, in _replica context_, you can simply call `assign` on it and no
2036  aggregation happens under the hood. In _cross-replica context_, you can only
2037  assign a single value to a SyncOnRead variable. One example case is restoring
2038  from a checkpoint: if the `aggregation` type of the variable is
2039  `tf.VariableAggregation.SUM`, it is assumed that replica values were added
2040  before checkpointing, so at the time of restoring, the value is divided by
2041  the number of replicas and then assigned to each replica; if the `aggregation`
2042  type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica
2043  directly.
2044
2045  """
2046
2047  def __init__(self, container_strategy):
2048    self._container_strategy_weakref = weakref.ref(container_strategy)
2049    self._default_device = None
2050    # This property is used to determine if we should set drop_remainder=True
2051    # when creating Datasets from numpy array inputs.
2052    self._require_static_shapes = False
2053
2054  def _container_strategy(self):
2055    """Get the containing `tf.distribute.Strategy`.
2056
2057    This should not generally be needed except when creating a new
2058    `ReplicaContext` and to validate that the caller is in the correct
2059    `scope()`.
2060
2061    Returns:
2062      The `tf.distribute.Strategy` such that `strategy.extended` is `self`.
2063    """
2064    container_strategy = self._container_strategy_weakref()
2065    assert container_strategy is not None
2066    return container_strategy
2067
2068  def _scope(self, strategy):
2069    """Implementation of tf.distribute.Strategy.scope()."""
2070
2071    def creator_with_resource_vars(next_creator, **kwargs):
2072      """Variable creator to use in `_CurrentDistributionContext`."""
2073      _require_strategy_scope_extended(self)
2074      kwargs["use_resource"] = True
2075      kwargs["distribute_strategy"] = strategy
2076
2077      # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
2078      # dereferencing a `Tensor` that is without a `name`. We still need to
2079      # propagate the metadata it's holding.
2080      if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
2081        checkpoint_restore_uid = kwargs[
2082            "initial_value"].checkpoint_position.restore_uid
2083        kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
2084      elif isinstance(kwargs["initial_value"],
2085                      trackable.CheckpointInitialValueCallable):
2086        checkpoint_restore_uid = kwargs[
2087            "initial_value"].checkpoint_position.restore_uid
2088      else:
2089        checkpoint_restore_uid = None
2090
2091      created = self._create_variable(next_creator, **kwargs)
2092
2093      if checkpoint_restore_uid is not None:
2094        # pylint: disable=protected-access
2095        # Let the checkpointing infrastructure know that the variable was
2096        # already restored so it doesn't waste memory loading the value again.
2097        # In this case of CheckpointInitialValueCallable this may already be
2098        # done by the final variable creator, but it doesn't hurt to do it
2099        # again.
2100        created._maybe_initialize_trackable()
2101        created._update_uid = checkpoint_restore_uid
2102        # pylint: enable=protected-access
2103      return created
2104
2105    def distributed_getter(getter, *args, **kwargs):
2106      if not self._allow_variable_partition():
2107        if kwargs.pop("partitioner", None) is not None:
2108          tf_logging.log_first_n(
2109              tf_logging.WARN, "Partitioned variables are disabled when using "
2110              "current tf.distribute.Strategy.", 1)
2111      return getter(*args, **kwargs)
2112
2113    return _CurrentDistributionContext(
2114        strategy,
2115        variable_scope.variable_creator_scope(creator_with_resource_vars),
2116        variable_scope.variable_scope(
2117            variable_scope.get_variable_scope(),
2118            custom_getter=distributed_getter), self._default_device)
2119
2120  def _allow_variable_partition(self):
2121    return False
2122
2123  def _create_variable(self, next_creator, **kwargs):
2124    # Note: should support "colocate_with" argument.
2125    raise NotImplementedError("must be implemented in descendants")
2126
2127  def variable_created_in_scope(self, v):
2128    """Tests whether `v` was created while this strategy scope was active.
2129
2130    Variables created inside the strategy scope are "owned" by it:
2131
2132    >>> strategy = tf.distribute.MirroredStrategy()
2133    >>> with strategy.scope():
2134    ...   v = tf.Variable(1.)
2135    >>> strategy.extended.variable_created_in_scope(v)
2136    True
2137
2138    Variables created outside the strategy are not owned by it:
2139
2140    >>> strategy = tf.distribute.MirroredStrategy()
2141    >>> v = tf.Variable(1.)
2142    >>> strategy.extended.variable_created_in_scope(v)
2143    False
2144
2145    Args:
2146      v: A `tf.Variable` instance.
2147
2148    Returns:
2149      True if `v` was created inside the scope, False if not.
2150    """
2151    return v._distribute_strategy == self._container_strategy_weakref()  # pylint: disable=protected-access
2152
2153  def colocate_vars_with(self, colocate_with_variable):
2154    """Scope that controls which devices variables will be created on.
2155
2156    No operations should be added to the graph inside this scope, it
2157    should only be used when creating variables (some implementations
2158    work by changing variable creation, others work by using a
2159    tf.compat.v1.colocate_with() scope).
2160
2161    This may only be used inside `self.scope()`.
2162
2163    Example usage:
2164
2165    ```
2166    with strategy.scope():
2167      var1 = tf.Variable(...)
2168      with strategy.extended.colocate_vars_with(var1):
2169        # var2 and var3 will be created on the same device(s) as var1
2170        var2 = tf.Variable(...)
2171        var3 = tf.Variable(...)
2172
2173      def fn(v1, v2, v3):
2174        # operates on v1 from var1, v2 from var2, and v3 from var3
2175
2176      # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
2177      # too.
2178      strategy.extended.update(var1, fn, args=(var2, var3))
2179    ```
2180
2181    Args:
2182      colocate_with_variable: A variable created in this strategy's `scope()`.
2183        Variables created while in the returned context manager will be on the
2184        same set of devices as `colocate_with_variable`.
2185
2186    Returns:
2187      A context manager.
2188    """
2189
2190    def create_colocated_variable(next_creator, **kwargs):
2191      _require_strategy_scope_extended(self)
2192      kwargs["use_resource"] = True
2193      kwargs["colocate_with"] = colocate_with_variable
2194      return next_creator(**kwargs)
2195
2196    _require_strategy_scope_extended(self)
2197    self._validate_colocate_with_variable(colocate_with_variable)
2198    return variable_scope.variable_creator_scope(create_colocated_variable)
2199
2200  def _validate_colocate_with_variable(self, colocate_with_variable):
2201    """Validate `colocate_with_variable` argument to `colocate_vars_with`."""
2202    pass
2203
2204  def _make_dataset_iterator(self, dataset):
2205    raise NotImplementedError("must be implemented in descendants")
2206
2207  def _make_input_fn_iterator(self, input_fn, replication_mode):
2208    raise NotImplementedError("must be implemented in descendants")
2209
2210  def _experimental_distribute_dataset(self, dataset, options):
2211    raise NotImplementedError("must be implemented in descendants")
2212
2213  def _distribute_datasets_from_function(self, dataset_fn, options):
2214    raise NotImplementedError("must be implemented in descendants")
2215
2216  def _experimental_distribute_values_from_function(self, value_fn):
2217    raise NotImplementedError("must be implemented in descendants")
2218
2219  def _reduce(self, reduce_op, value):
2220    # Default implementation until we have an implementation for each strategy.
2221    dst = device_util.current() or self._default_device or "/device:CPU:0"
2222    return self._local_results(self.reduce_to(reduce_op, value, dst))[0]
2223
2224  def reduce_to(self, reduce_op, value, destinations, options=None):
2225    """Combine (via e.g. sum or mean) values across replicas.
2226
2227    `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed
2228    variables. It supports both dense values and `tf.IndexedSlices`.
2229
2230    This API currently can only be called in cross-replica context. Other
2231    variants to reduce values across replicas are:
2232    * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of
2233      this API.
2234    * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
2235      in replica context. It supports both batched and non-batched all-reduce.
2236    * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
2237      to the host in cross-replica context.
2238
2239    `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can
2240    also pass in a `Tensor`, and the destinations will be the device of that
2241    tensor. For all-reduce, pass the same to `value` and `destinations`.
2242
2243    It can be used in `tf.distribute.ReplicaContext.merge_call` to write code
2244    that works for all `tf.distribute.Strategy`.
2245
2246    >>> @tf.function
2247    ... def step_fn(var):
2248    ...
2249    ...   def merge_fn(strategy, value, var):
2250    ...     # All-reduce the value. Note that `value` here is a
2251    ...     # `tf.distribute.DistributedValues`.
2252    ...     reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM,
2253    ...         value, destinations=var)
2254    ...     strategy.extended.update(var, lambda var, value: var.assign(value),
2255    ...         args=(reduced,))
2256    ...
2257    ...   value = tf.identity(1.)
2258    ...   tf.distribute.get_replica_context().merge_call(merge_fn,
2259    ...     args=(value, var))
2260    >>>
2261    >>> def run(strategy):
2262    ...   with strategy.scope():
2263    ...     v = tf.Variable(0.)
2264    ...     strategy.run(step_fn, args=(v,))
2265    ...     return v
2266    >>>
2267    >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
2268    MirroredVariable:{
2269      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
2270      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
2271    }
2272    >>> run(tf.distribute.experimental.CentralStorageStrategy(
2273    ...     compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
2274    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
2275    >>> run(tf.distribute.OneDeviceStrategy("GPU:0"))
2276    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
2277
2278    Args:
2279      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
2280        be combined. Allows using string representation of the enum such as
2281        "SUM", "MEAN".
2282      value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object.
2283      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
2284        `tf.Tensor` alike object, or a device string. It specifies the devices
2285        to reduce to. To perform an all-reduce, pass the same to `value` and
2286        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
2287        to the devices of that variable, and this method doesn't update the
2288        variable.
2289      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2290        perform collective operations. This overrides the default options if the
2291        `tf.distribute.Strategy` takes one in the constructor. See
2292        `tf.distribute.experimental.CommunicationOptions` for details of the
2293        options.
2294
2295    Returns:
2296      A tensor or value reduced to `destinations`.
2297    """
2298    if options is None:
2299      options = collective_util.Options()
2300    _require_cross_replica_or_default_context_extended(self)
2301    assert not isinstance(destinations, (list, tuple))
2302    assert not isinstance(reduce_op, variable_scope.VariableAggregation)
2303    if isinstance(reduce_op, six.string_types):
2304      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
2305    assert (reduce_op == reduce_util.ReduceOp.SUM or
2306            reduce_op == reduce_util.ReduceOp.MEAN)
2307    return self._reduce_to(reduce_op, value, destinations, options)
2308
2309  def _reduce_to(self, reduce_op, value, destinations, options):
2310    raise NotImplementedError("must be implemented in descendants")
2311
2312  def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None):
2313    """Combine multiple `reduce_to` calls into one for faster execution.
2314
2315    Similar to `reduce_to`, but accepts a list of (value, destinations) pairs.
2316    It's more efficient than reduce each value separately.
2317
2318    This API currently can only be called in cross-replica context. Other
2319    variants to reduce values across replicas are:
2320    * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of
2321      this API.
2322    * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
2323      in replica context. It supports both batched and non-batched all-reduce.
2324    * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
2325      to the host in cross-replica context.
2326
2327    See `reduce_to` for more information.
2328
2329    >>> @tf.function
2330    ... def step_fn(var):
2331    ...
2332    ...   def merge_fn(strategy, value, var):
2333    ...     # All-reduce the value. Note that `value` here is a
2334    ...     # `tf.distribute.DistributedValues`.
2335    ...     reduced = strategy.extended.batch_reduce_to(
2336    ...         tf.distribute.ReduceOp.SUM, [(value, var)])[0]
2337    ...     strategy.extended.update(var, lambda var, value: var.assign(value),
2338    ...         args=(reduced,))
2339    ...
2340    ...   value = tf.identity(1.)
2341    ...   tf.distribute.get_replica_context().merge_call(merge_fn,
2342    ...     args=(value, var))
2343    >>>
2344    >>> def run(strategy):
2345    ...   with strategy.scope():
2346    ...     v = tf.Variable(0.)
2347    ...     strategy.run(step_fn, args=(v,))
2348    ...     return v
2349    >>>
2350    >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
2351    MirroredVariable:{
2352      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
2353      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
2354    }
2355    >>> run(tf.distribute.experimental.CentralStorageStrategy(
2356    ...     compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
2357    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
2358    >>> run(tf.distribute.OneDeviceStrategy("GPU:0"))
2359    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
2360
2361    Args:
2362      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
2363        be combined. Allows using string representation of the enum such as
2364        "SUM", "MEAN".
2365      value_destination_pairs: a sequence of (value, destinations) pairs. See
2366        `tf.distribute.Strategy.reduce_to` for descriptions.
2367      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2368        perform collective operations. This overrides the default options if the
2369        `tf.distribute.Strategy` takes one in the constructor. See
2370        `tf.distribute.experimental.CommunicationOptions` for details of the
2371        options.
2372
2373    Returns:
2374      A list of reduced values, one per pair in `value_destination_pairs`.
2375    """
2376    if options is None:
2377      options = collective_util.Options()
2378    _require_cross_replica_or_default_context_extended(self)
2379    assert not isinstance(reduce_op, variable_scope.VariableAggregation)
2380    if isinstance(reduce_op, six.string_types):
2381      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
2382    return self._batch_reduce_to(reduce_op, value_destination_pairs, options)
2383
2384  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
2385    return [
2386        self.reduce_to(reduce_op, t, destinations=v, options=options)
2387        for t, v in value_destination_pairs
2388    ]
2389
2390  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
2391    """All-reduce `value` across all replicas so that all get the final result.
2392
2393    If `value` is a nested structure of tensors, all-reduces of these tensors
2394    will be batched when possible. `options` can be set to hint the batching
2395    behavior.
2396
2397    This API must be called in a replica context.
2398
2399    Args:
2400      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
2401        be combined. Allows using string representation of the enum such as
2402        "SUM", "MEAN".
2403      value: Value to be reduced. A tensor or a nested structure of tensors.
2404      options: A `tf.distribute.experimental.CommunicationOptions`. Options to
2405        perform collective operations. This overrides the default options if the
2406        `tf.distribute.Strategy` takes one in the constructor.
2407
2408    Returns:
2409      A tensor or a nested strucutre of tensors with the reduced values. The
2410      structure is the same as `value`.
2411    """
2412    if options is None:
2413      options = collective_util.Options()
2414    replica_context = distribution_strategy_context.get_replica_context()
2415    assert replica_context, (
2416        "`StrategyExtended._replica_ctx_all_reduce` must be called in"
2417        " a replica context")
2418
2419    def merge_fn(_, flat_value):
2420      return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value],
2421                                  options)
2422
2423    reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),))
2424    return nest.pack_sequence_as(value, reduced)
2425
2426  def _gather_to(self, value, destinations, axis, options=None):
2427    """Gather `value` across replicas along axis-th dimension to `destinations`.
2428
2429    `gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like
2430    object, along `axis`-th dimension. It supports only dense tensors but NOT
2431    sparse tensor. This API can only be called in cross-replica context.
2432
2433    Args:
2434      value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object.
2435      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
2436        `tf.Tensor` alike object, or a device string. It specifies the devices
2437        to reduce to. To perform an all-gather, pass the same to `value` and
2438        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
2439        to the devices of that variable, and this method doesn't update the
2440        variable.
2441      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
2442        range [0, rank(value)).
2443      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2444        perform collective operations. This overrides the default options if the
2445        `tf.distribute.Strategy` takes one in the constructor. See
2446        `tf.distribute.experimental.CommunicationOptions` for details of the
2447        options.
2448
2449    Returns:
2450      A tensor or value gathered to `destinations`.
2451    """
2452    _require_cross_replica_or_default_context_extended(self)
2453    assert not isinstance(destinations, (list, tuple))
2454    if options is None:
2455      options = collective_util.Options()
2456    return self._gather_to_implementation(value, destinations, axis, options)
2457
2458  def _gather_to_implementation(self, value, destinations, axis, options):
2459    raise NotImplementedError("_gather_to must be implemented in descendants")
2460
2461  def _batch_gather_to(self, value_destination_pairs, axis, options=None):
2462    _require_cross_replica_or_default_context_extended(self)
2463    if options is None:
2464      options = collective_util.Options()
2465    return [
2466        self._gather_to(t, destinations=v, axis=axis, options=options)
2467        for t, v in value_destination_pairs
2468    ]
2469
2470  def update(self, var, fn, args=(), kwargs=None, group=True):
2471    """Run `fn` to update `var` using inputs mirrored to the same devices.
2472
2473    `tf.distribute.StrategyExtended.update` takes a distributed variable `var`
2474    to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It
2475    applies `fn` to each component variable of `var` and passes corresponding
2476    values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain
2477    per-replica values. If they contain mirrored values, they will be unwrapped
2478    before calling `fn`. For example, `fn` can be `assign_add` and `args` can be
2479    a mirrored DistributedValues where each component contains the value to be
2480    added to this mirrored variable `var`. Calling `update` will call
2481    `assign_add` on each component variable of `var` with the corresponding
2482    tensor value on that device.
2483
2484    Example usage:
2485
2486    ```python
2487    strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2
2488    devices
2489    with strategy.scope():
2490      v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
2491    def update_fn(v):
2492      return v.assign(1.0)
2493    result = strategy.extended.update(v, update_fn)
2494    # result is
2495    # Mirrored:{
2496    #  0: tf.Tensor(1.0, shape=(), dtype=float32),
2497    #  1: tf.Tensor(1.0, shape=(), dtype=float32)
2498    # }
2499    ```
2500
2501    If `var` is mirrored across multiple devices, then this method implements
2502    logic as following:
2503
2504    ```python
2505    results = {}
2506    for device, v in var:
2507      with tf.device(device):
2508        # args and kwargs will be unwrapped if they are mirrored.
2509        results[device] = fn(v, *args, **kwargs)
2510    return merged(results)
2511    ```
2512
2513    Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with
2514    `var`.
2515
2516    Args:
2517      var: Variable, possibly mirrored to multiple devices, to operate on.
2518      fn: Function to call. Should take the variable as the first argument.
2519      args: Tuple or list. Additional positional arguments to pass to `fn()`.
2520      kwargs: Dict with keyword arguments to pass to `fn()`.
2521      group: Boolean. Defaults to True. If False, the return value will be
2522        unwrapped.
2523
2524    Returns:
2525      By default, the merged return value of `fn` across all replicas.  The
2526      merged result has dependencies to make sure that if it is evaluated at
2527      all, the side effects (updates) will happen on every replica. If instead
2528      "group=False" is specified, this function will return a nest of lists
2529      where each list has an element per replica, and the caller is responsible
2530      for ensuring all elements are executed.
2531    """
2532    _require_cross_replica_or_default_context_extended(self)
2533    if kwargs is None:
2534      kwargs = {}
2535    fn = autograph.tf_convert(
2536        fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
2537    with self._container_strategy().scope():
2538      return self._update(var, fn, args, kwargs, group)
2539
2540  def _update(self, var, fn, args, kwargs, group):
2541    raise NotImplementedError("must be implemented in descendants")
2542
2543  def _local_results(self, distributed_value):
2544    raise NotImplementedError("must be implemented in descendants")
2545
2546  def value_container(self, value):
2547    """Returns the container that this per-replica `value` belongs to.
2548
2549    Args:
2550      value: A value returned by `run()` or a variable created in `scope()`.
2551
2552    Returns:
2553      A container that `value` belongs to.
2554      If value does not belong to any container (including the case of
2555      container having been destroyed), returns the value itself.
2556      `value in experimental_local_results(value_container(value))` will
2557      always be true.
2558    """
2559    raise NotImplementedError("must be implemented in descendants")
2560
2561  def _group(self, value, name=None):
2562    """Implementation of `group`."""
2563    value = nest.flatten(self._local_results(value))
2564
2565    if len(value) != 1 or name is not None:
2566      return control_flow_ops.group(value, name=name)
2567    # Special handling for the common case of one op.
2568    v, = value
2569    if hasattr(v, "op"):
2570      v = v.op
2571    return v
2572
2573  @property
2574  def experimental_require_static_shapes(self):
2575    """Returns `True` if static shape is required; `False` otherwise."""
2576    return self._require_static_shapes
2577
2578  @property
2579  def _num_replicas_in_sync(self):
2580    """Returns number of replicas over which gradients are aggregated."""
2581    raise NotImplementedError("must be implemented in descendants")
2582
2583  @property
2584  def worker_devices(self):
2585    """Returns the tuple of all devices used to for compute replica execution.
2586    """
2587    # TODO(josh11b): More docstring
2588    raise NotImplementedError("must be implemented in descendants")
2589
2590  @property
2591  def parameter_devices(self):
2592    """Returns the tuple of all devices used to place variables."""
2593    # TODO(josh11b): More docstring
2594    raise NotImplementedError("must be implemented in descendants")
2595
2596  def _configure(self,
2597                 session_config=None,
2598                 cluster_spec=None,
2599                 task_type=None,
2600                 task_id=None):
2601    """Configures the strategy class."""
2602    del session_config, cluster_spec, task_type, task_id
2603
2604  def _update_config_proto(self, config_proto):
2605    return copy.deepcopy(config_proto)
2606
2607  def _in_multi_worker_mode(self):
2608    """Whether this strategy indicates working in multi-worker settings.
2609
2610    Multi-worker training refers to the setup where the training is
2611    distributed across multiple workers, as opposed to the case where
2612    only a local process performs the training. This function is
2613    used by higher-level APIs such as Keras' `model.fit()` to infer
2614    for example whether or not a distribute coordinator should be run,
2615    and thus TensorFlow servers should be started for communication
2616    with other servers in the cluster, or whether or not saving/restoring
2617    checkpoints is relevant for preemption fault tolerance.
2618
2619    Subclasses should override this to provide whether the strategy is
2620    currently in multi-worker setup.
2621
2622    Experimental. Signature and implementation are subject to change.
2623    """
2624    raise NotImplementedError("must be implemented in descendants")
2625
2626
2627@tf_export(v1=["distribute.StrategyExtended"])  # pylint: disable=missing-docstring
2628class StrategyExtendedV1(StrategyExtendedV2):
2629
2630  __doc__ = StrategyExtendedV2.__doc__
2631
2632  def experimental_make_numpy_dataset(self, numpy_input, session=None):
2633    """Makes a dataset for input provided via a numpy array.
2634
2635    This avoids adding `numpy_input` as a large constant in the graph,
2636    and copies the data to the machine or machines that will be processing
2637    the input.
2638
2639    Args:
2640      numpy_input: A nest of NumPy input arrays that will be distributed evenly
2641        across all replicas. Note that lists of Numpy arrays are stacked, as
2642        that is normal `tf.data.Dataset` behavior.
2643      session: (TensorFlow v1.x graph execution only) A session used for
2644        initialization.
2645
2646    Returns:
2647      A `tf.data.Dataset` representing `numpy_input`.
2648    """
2649    _require_cross_replica_or_default_context_extended(self)
2650    return self._experimental_make_numpy_dataset(numpy_input, session=session)
2651
2652  def _experimental_make_numpy_dataset(self, numpy_input, session):
2653    raise NotImplementedError("must be implemented in descendants")
2654
2655  def broadcast_to(self, tensor, destinations):
2656    """Mirror a tensor on one device to all worker devices.
2657
2658    Args:
2659      tensor: A Tensor value to broadcast.
2660      destinations: A mirrored variable or device string specifying the
2661        destination devices to copy `tensor` to.
2662
2663    Returns:
2664      A value mirrored to `destinations` devices.
2665    """
2666    assert destinations is not None  # from old strategy.broadcast()
2667    # TODO(josh11b): More docstring
2668    _require_cross_replica_or_default_context_extended(self)
2669    assert not isinstance(destinations, (list, tuple))
2670    return self._broadcast_to(tensor, destinations)
2671
2672  def _broadcast_to(self, tensor, destinations):
2673    raise NotImplementedError("must be implemented in descendants")
2674
2675  def experimental_run_steps_on_iterator(self,
2676                                         fn,
2677                                         iterator,
2678                                         iterations=1,
2679                                         initial_loop_values=None):
2680    """DEPRECATED: please use `run` instead.
2681
2682    Run `fn` with input from `iterator` for `iterations` times.
2683
2684    This method can be used to run a step function for training a number of
2685    times using input from a dataset.
2686
2687    Args:
2688      fn: function to run using this distribution strategy. The function must
2689        have the following signature: `def fn(context, inputs)`. `context` is an
2690          instance of `MultiStepContext` that will be passed when `fn` is run.
2691          `context` can be used to specify the outputs to be returned from `fn`
2692          by calling `context.set_last_step_output`. It can also be used to
2693          capture non tensor outputs by `context.set_non_tensor_output`. See
2694          `MultiStepContext` documentation for more information. `inputs` will
2695          have same type/structure as `iterator.get_next()`. Typically, `fn`
2696          will use `call_for_each_replica` method of the strategy to distribute
2697          the computation over multiple replicas.
2698      iterator: Iterator of a dataset that represents the input for `fn`. The
2699        caller is responsible for initializing the iterator as needed.
2700      iterations: (Optional) Number of iterations that `fn` should be run.
2701        Defaults to 1.
2702      initial_loop_values: (Optional) Initial values to be passed into the
2703        loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove
2704          initial_loop_values argument when we have a mechanism to infer the
2705          outputs of `fn`.
2706
2707    Returns:
2708      Returns the `MultiStepContext` object which has the following properties,
2709      among other things:
2710        - run_op: An op that runs `fn` `iterations` times.
2711        - last_step_outputs: A dictionary containing tensors set using
2712        `context.set_last_step_output`. Evaluating this returns the value of
2713        the tensors after the last iteration.
2714        - non_tensor_outputs: A dictionary containing anything that was set by
2715          `fn` by calling `context.set_non_tensor_output`.
2716    """
2717    _require_cross_replica_or_default_context_extended(self)
2718    with self._container_strategy().scope():
2719      return self._experimental_run_steps_on_iterator(fn, iterator, iterations,
2720                                                      initial_loop_values)
2721
2722  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
2723                                          initial_loop_values):
2724    raise NotImplementedError("must be implemented in descendants")
2725
2726  def call_for_each_replica(self, fn, args=(), kwargs=None):
2727    """Run `fn` once per replica.
2728
2729    `fn` may call `tf.get_replica_context()` to access methods such as
2730    `replica_id_in_sync_group` and `merge_call()`.
2731
2732    `merge_call()` is used to communicate between the replicas and
2733    re-enter the cross-replica context. All replicas pause their execution
2734    having encountered a `merge_call()` call. After that the
2735    `merge_fn`-function is executed. Its results are then unwrapped and
2736    given back to each replica call. After that execution resumes until
2737    `fn` is complete or encounters another `merge_call()`.  Example:
2738
2739    ```python
2740    # Called once in "cross-replica" context.
2741    def merge_fn(distribution, three_plus_replica_id):
2742      # sum the values across replicas
2743      return sum(distribution.experimental_local_results(three_plus_replica_id))
2744
2745    # Called once per replica in `distribution`, in a "replica" context.
2746    def fn(three):
2747      replica_ctx = tf.get_replica_context()
2748      v = three + replica_ctx.replica_id_in_sync_group
2749      # Computes the sum of the `v` values across all replicas.
2750      s = replica_ctx.merge_call(merge_fn, args=(v,))
2751      return s + v
2752
2753    with distribution.scope():
2754      # in "cross-replica" context
2755      ...
2756      merged_results = distribution.run(fn, args=[3])
2757      # merged_results has the values from every replica execution of `fn`.
2758      # This statement prints a list:
2759      print(distribution.experimental_local_results(merged_results))
2760    ```
2761
2762    Args:
2763      fn: function to run (will be run once per replica).
2764      args: Tuple or list with positional arguments for `fn`.
2765      kwargs: Dict with keyword arguments for `fn`.
2766
2767    Returns:
2768      Merged return value of `fn` across all replicas.
2769    """
2770    _require_cross_replica_or_default_context_extended(self)
2771    if kwargs is None:
2772      kwargs = {}
2773    with self._container_strategy().scope():
2774      return self._call_for_each_replica(fn, args, kwargs)
2775
2776  def _call_for_each_replica(self, fn, args, kwargs):
2777    raise NotImplementedError("must be implemented in descendants")
2778
2779  def read_var(self, v):
2780    """Reads the value of a variable.
2781
2782    Returns the aggregate value of a replica-local variable, or the
2783    (read-only) value of any other variable.
2784
2785    Args:
2786      v: A variable allocated within the scope of this `tf.distribute.Strategy`.
2787
2788    Returns:
2789      A tensor representing the value of `v`, aggregated across replicas if
2790      necessary.
2791    """
2792    raise NotImplementedError("must be implemented in descendants")
2793
2794  def update_non_slot(
2795      self, colocate_with, fn, args=(), kwargs=None, group=True):
2796    """Runs `fn(*args, **kwargs)` on `colocate_with` devices.
2797
2798    Used to update non-slot variables.
2799
2800    DEPRECATED: TF 1.x ONLY.
2801
2802    Args:
2803      colocate_with: Devices returned by `non_slot_devices()`.
2804      fn: Function to execute.
2805      args: Tuple or list. Positional arguments to pass to `fn()`.
2806      kwargs: Dict with keyword arguments to pass to `fn()`.
2807      group: Boolean. Defaults to True. If False, the return value will be
2808        unwrapped.
2809
2810    Returns:
2811      Return value of `fn`, possibly merged across devices.
2812    """
2813    _require_cross_replica_or_default_context_extended(self)
2814    if kwargs is None:
2815      kwargs = {}
2816    fn = autograph.tf_convert(
2817        fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
2818    with self._container_strategy().scope():
2819      return self._update_non_slot(colocate_with, fn, args, kwargs, group)
2820
2821  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
2822    raise NotImplementedError("must be implemented in descendants")
2823
2824  def non_slot_devices(self, var_list):
2825    """Device(s) for non-slot variables.
2826
2827    DEPRECATED: TF 1.x ONLY.
2828
2829    This method returns non-slot devices where non-slot variables are placed.
2830    Users can create non-slot variables on these devices by using a block:
2831
2832    ```python
2833    with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)):
2834      ...
2835    ```
2836
2837    Args:
2838      var_list: The list of variables being optimized, needed with the
2839        default `tf.distribute.Strategy`.
2840    Returns:
2841      A sequence of devices for non-slot variables.
2842    """
2843    raise NotImplementedError("must be implemented in descendants")
2844
2845  @property
2846  def experimental_between_graph(self):
2847    """Whether the strategy uses between-graph replication or not.
2848
2849      This is expected to return a constant value that will not be changed
2850      throughout its life cycle.
2851    """
2852    raise NotImplementedError("must be implemented in descendants")
2853
2854  @property
2855  def experimental_should_init(self):
2856    """Whether initialization is needed."""
2857    raise NotImplementedError("must be implemented in descendants")
2858
2859  @property
2860  def should_checkpoint(self):
2861    """Whether checkpointing is needed."""
2862    raise NotImplementedError("must be implemented in descendants")
2863
2864  @property
2865  def should_save_summary(self):
2866    """Whether saving summaries is needed."""
2867    raise NotImplementedError("must be implemented in descendants")
2868
2869
2870# A note about the difference between the context managers
2871# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
2872# (defined above) used by `tf.distribute.Strategy.scope()`:
2873#
2874# * a ReplicaContext is only present during a `run()`
2875#   call (except during a `merge_run` call) and in such a scope it
2876#   will be returned by calls to `get_replica_context()`.  Implementers of new
2877#   Strategy descendants will frequently also need to
2878#   define a descendant of ReplicaContext, and are responsible for
2879#   entering and exiting this context.
2880#
2881# * Strategy.scope() sets up a variable_creator scope that
2882#   changes variable creation calls (e.g. to make mirrored
2883#   variables). This is intended as an outer scope that users enter once
2884#   around their model creation and graph definition. There is no
2885#   anticipated need to define descendants of _CurrentDistributionContext.
2886#   It sets the current Strategy for purposes of
2887#   `get_strategy()` and `has_strategy()`
2888#   and switches the thread mode to a "cross-replica context".
2889class ReplicaContextBase(object):
2890  """A class with a collection of APIs that can be called in a replica context.
2891
2892  You can use `tf.distribute.get_replica_context` to get an instance of
2893  `ReplicaContext`, which can only be called inside the function passed to
2894  `tf.distribute.Strategy.run`.
2895
2896  >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
2897  >>> def func():
2898  ...   replica_context = tf.distribute.get_replica_context()
2899  ...   return replica_context.replica_id_in_sync_group
2900  >>> strategy.run(func)
2901  PerReplica:{
2902    0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
2903    1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
2904  }
2905  """
2906
2907  def __init__(self, strategy, replica_id_in_sync_group):
2908    """Creates a ReplicaContext.
2909
2910    Args:
2911      strategy: A `tf.distribute.Strategy`.
2912      replica_id_in_sync_group: An integer, a `Tensor` or None. Prefer an
2913        integer whenever possible to avoid issues with nested `tf.function`. It
2914        accepts a `Tensor` only to be compatible with `tpu.replicate`.
2915    """
2916    self._strategy = strategy
2917    self._thread_context = distribution_strategy_context._InReplicaThreadMode(  # pylint: disable=protected-access
2918        self)
2919    if not (replica_id_in_sync_group is None or
2920            tensor_util.is_tf_type(replica_id_in_sync_group) or
2921            isinstance(replica_id_in_sync_group, int)):
2922      raise ValueError(
2923          "replica_id_in_sync_group can only be an integer, a Tensor or None.")
2924    self._replica_id_in_sync_group = replica_id_in_sync_group
2925    # We need this check because TPUContext extends from ReplicaContext and
2926    # does not pass a strategy object since it is used by TPUEstimator.
2927    if strategy:
2928      self._local_replica_id = strategy.extended._get_local_replica_id(
2929          replica_id_in_sync_group)
2930    self._summary_recording_distribution_strategy = None
2931
2932  @doc_controls.do_not_generate_docs
2933  def __enter__(self):
2934    _push_per_thread_mode(self._thread_context)
2935
2936    def replica_id_is_zero():
2937      return math_ops.equal(self.replica_id_in_sync_group,
2938                            constant_op.constant(0))
2939
2940    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
2941    self._summary_recording_distribution_strategy = (
2942        summary_state.is_recording_distribution_strategy)
2943    summary_state.is_recording_distribution_strategy = replica_id_is_zero
2944
2945  @doc_controls.do_not_generate_docs
2946  def __exit__(self, exception_type, exception_value, traceback):
2947    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
2948    summary_state.is_recording_distribution_strategy = (
2949        self._summary_recording_distribution_strategy)
2950    _pop_per_thread_mode()
2951
2952  def merge_call(self, merge_fn, args=(), kwargs=None):
2953    """Merge args across replicas and run `merge_fn` in a cross-replica context.
2954
2955    This allows communication and coordination when there are multiple calls
2956    to the step_fn triggered by a call to `strategy.run(step_fn, ...)`.
2957
2958    See `tf.distribute.Strategy.run` for an explanation.
2959
2960    If not inside a distributed scope, this is equivalent to:
2961
2962    ```
2963    strategy = tf.distribute.get_strategy()
2964    with cross-replica-context(strategy):
2965      return merge_fn(strategy, *args, **kwargs)
2966    ```
2967
2968    Args:
2969      merge_fn: Function that joins arguments from threads that are given as
2970        PerReplica. It accepts `tf.distribute.Strategy` object as
2971        the first argument.
2972      args: List or tuple with positional per-thread arguments for `merge_fn`.
2973      kwargs: Dict with keyword per-thread arguments for `merge_fn`.
2974
2975    Returns:
2976      The return value of `merge_fn`, except for `PerReplica` values which are
2977      unpacked.
2978    """
2979    require_replica_context(self)
2980    if kwargs is None:
2981      kwargs = {}
2982
2983    merge_fn = autograph.tf_convert(
2984        merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
2985    return self._merge_call(merge_fn, args, kwargs)
2986
2987  def _merge_call(self, merge_fn, args, kwargs):
2988    """Default implementation for single replica."""
2989    _push_per_thread_mode(  # thread-local, so not needed with multiple threads
2990        distribution_strategy_context._CrossReplicaThreadMode(self._strategy))  # pylint: disable=protected-access
2991    try:
2992      return merge_fn(self._strategy, *args, **kwargs)
2993    finally:
2994      _pop_per_thread_mode()
2995
2996  @property
2997  def num_replicas_in_sync(self):
2998    """Returns number of replicas that are kept in sync."""
2999    return self._strategy.num_replicas_in_sync
3000
3001  @property
3002  def replica_id_in_sync_group(self):
3003    """Returns the id of the replica.
3004
3005    This identifies the replica among all replicas that are kept in sync. The
3006    value of the replica id can range from 0 to
3007    `tf.distribute.ReplicaContext.num_replicas_in_sync` - 1.
3008
3009    NOTE: This is not guaranteed to be the same ID as the XLA replica ID use
3010    for low-level operations such as collective_permute.
3011
3012    Returns:
3013      a `Tensor`.
3014    """
3015    # It's important to prefer making the Tensor at call time whenever possible.
3016    # Keeping Tensors in global states doesn't work well with nested
3017    # tf.function, since it's possible that the tensor is generated in one func
3018    # graph, and gets captured by another, which will result in a subtle "An op
3019    # outside of the function building code is being passed a Graph tensor"
3020    # error. Making the tensor at call time to ensure it is the same graph where
3021    # it's used. However to be compatible with tpu.replicate(),
3022    # self._replica_id_in_sync_group can also be a Tensor.
3023    if tensor_util.is_tf_type(self._replica_id_in_sync_group):
3024      return self._replica_id_in_sync_group
3025    return constant_op.constant(
3026        self._replica_id_in_sync_group,
3027        dtypes.int32,
3028        name="replica_id_in_sync_group")
3029
3030  @property
3031  def _replica_id(self):
3032    """This is the local replica id in a given sync group."""
3033    return self._local_replica_id
3034
3035  @property
3036  def strategy(self):
3037    """The current `tf.distribute.Strategy` object."""
3038    return self._strategy
3039
3040  @property
3041  @deprecation.deprecated(None, "Please avoid relying on devices property.")
3042  def devices(self):
3043    """Returns the devices this replica is to be executed on, as a tuple of strings.
3044
3045    NOTE: For `tf.distribute.MirroredStrategy` and
3046    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a
3047    nested
3048    list of device strings, e.g, [["GPU:0"]].
3049    """
3050    require_replica_context(self)
3051    return (device_util.current(),)
3052
3053  def all_reduce(self, reduce_op, value, options=None):
3054    """All-reduces `value` across all replicas.
3055
3056    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3057    >>> def step_fn():
3058    ...   ctx = tf.distribute.get_replica_context()
3059    ...   value = tf.identity(1.)
3060    ...   return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
3061    >>> strategy.experimental_local_results(strategy.run(step_fn))
3062    (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
3063     <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
3064
3065    It supports batched operations. You can pass a list of values and it
3066    attempts to batch them when possible. You can also specify `options`
3067    to indicate the desired batching behavior, e.g. batch the values into
3068    multiple packs so that they can better overlap with computations.
3069
3070    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3071    >>> def step_fn():
3072    ...   ctx = tf.distribute.get_replica_context()
3073    ...   value1 = tf.identity(1.)
3074    ...   value2 = tf.identity(2.)
3075    ...   return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2])
3076    >>> strategy.experimental_local_results(strategy.run(step_fn))
3077    ([PerReplica:{
3078      0: <tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
3079      1: <tf.Tensor: shape=(), dtype=float32, numpy=2.0>
3080    }, PerReplica:{
3081      0: <tf.Tensor: shape=(), dtype=float32, numpy=4.0>,
3082      1: <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
3083    }],)
3084
3085    Note that all replicas need to participate in the all-reduce, otherwise this
3086    operation hangs. Note that if there're multiple all-reduces, they need to
3087    execute in the same order on all replicas. Dispatching all-reduce based on
3088    conditions is usually error-prone.
3089
3090    This API currently can only be called in the replica context. Other
3091    variants to reduce values across replicas are:
3092    * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API
3093      in the cross-replica context.
3094    * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and
3095      all-reduce API in the cross-replica context.
3096    * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
3097      to the host in cross-replica context.
3098
3099    Args:
3100      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
3101        be combined. Allows using string representation of the enum such as
3102        "SUM", "MEAN".
3103      value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts.
3104        The structure and the shapes of the `tf.Tensor` need to be same on all
3105        replicas.
3106      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
3107        perform collective operations. This overrides the default options if the
3108        `tf.distribute.Strategy` takes one in the constructor. See
3109        `tf.distribute.experimental.CommunicationOptions` for details of the
3110        options.
3111
3112    Returns:
3113       A nested structure of `tf.Tensor` with the reduced values. The structure
3114       is the same as `value`.
3115    """
3116    if isinstance(reduce_op, six.string_types):
3117      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
3118    if options is None:
3119      options = collective_util.Options()
3120
3121    def batch_all_reduce(strategy, *value_flat):
3122      return strategy.extended.batch_reduce_to(
3123          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
3124          options)
3125
3126    if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
3127      # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
3128      @custom_gradient.custom_gradient
3129      def grad_wrapper(*xs):
3130        ys = self.merge_call(batch_all_reduce, args=xs)
3131        # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
3132        return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
3133      return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
3134    else:
3135      # TODO(cjfj): Implement gradients for other reductions.
3136      reduced = nest.pack_sequence_as(
3137          value, self.merge_call(batch_all_reduce, args=nest.flatten(value)))
3138      return nest.map_structure(array_ops.prevent_gradient, reduced)
3139
3140  # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
3141  # all-reduce. It would return a function returning the result of reducing `t`
3142  # across all replicas. The caller would wait to call this function until they
3143  # needed the reduce result, allowing an efficient implementation:
3144  # * With eager execution, the reduction could be performed asynchronously
3145  #   in the background, not blocking until the result was needed.
3146  # * When constructing a graph, it could batch up all reduction requests up
3147  #   to that point that the first result is needed. Most likely this can be
3148  #   implemented in terms of `merge_call()` and `batch_reduce_to()`.
3149
3150
3151@tf_export("distribute.ReplicaContext", v1=[])
3152class ReplicaContext(ReplicaContextBase):
3153
3154  __doc__ = ReplicaContextBase.__doc__
3155
3156  def all_gather(self, value, axis, options=None):
3157    """All-gathers `value` across all replicas along `axis`.
3158
3159    Note: An `all_gather` method can only be called in replica context. For
3160    a cross-replica context counterpart, see `tf.distribute.Strategy.gather`.
3161    All replicas need to participate in the all-gather, otherwise this
3162    operation hangs. So if `all_gather` is called in any replica, it must be
3163    called in all replicas.
3164
3165    Note: If there are multiple `all_gather` calls, they need to be executed in
3166    the same order on all replicas. Dispatching `all_gather` based on conditions
3167    is usually error-prone.
3168
3169    For all strategies except `tf.distribute.TPUStrategy`, the input
3170    `value` on different replicas must have the same rank, and their shapes must
3171    be the same in all dimensions except the `axis`-th dimension. In other
3172    words, their shapes cannot be different in a dimension `d` where `d` does
3173    not equal to the `axis` argument. For example, given a
3174    `tf.distribute.DistributedValues` with component tensors of shape
3175    `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
3176    `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)`
3177    or `all_gather(..., axis=2, ...)`. However, with
3178    `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and
3179    same shape.
3180
3181    Note: The input `value` must have a non-zero rank. Otherwise, consider using
3182    `tf.expand_dims` before gathering them.
3183
3184    You can pass in a single tensor to all-gather:
3185
3186    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3187    >>> @tf.function
3188    ... def gather_value():
3189    ...   ctx = tf.distribute.get_replica_context()
3190    ...   local_value = tf.constant([1, 2, 3])
3191    ...   return ctx.all_gather(local_value, axis=0)
3192    >>> result = strategy.run(gather_value)
3193    >>> result
3194    PerReplica:{
3195      0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3196      1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
3197    }
3198    >>> strategy.experimental_local_results(result)
3199    (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
3200    dtype=int32)>,
3201    <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
3202    dtype=int32)>)
3203
3204
3205    You can also pass in a nested structure of tensors to all-gather, say, a
3206    list:
3207
3208    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3209    >>> @tf.function
3210    ... def gather_nest():
3211    ...   ctx = tf.distribute.get_replica_context()
3212    ...   value_1 = tf.constant([1, 2, 3])
3213    ...   value_2 = tf.constant([[1, 2], [3, 4]])
3214    ...   # all_gather a nest of `tf.distribute.DistributedValues`
3215    ...   return ctx.all_gather([value_1, value_2], axis=0)
3216    >>> result = strategy.run(gather_nest)
3217    >>> result
3218    [PerReplica:{
3219      0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3220      1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
3221    }, PerReplica:{
3222      0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3223    array([[1, 2],
3224           [3, 4],
3225           [1, 2],
3226           [3, 4]], dtype=int32)>,
3227      1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3228    array([[1, 2],
3229           [3, 4],
3230           [1, 2],
3231           [3, 4]], dtype=int32)>
3232    }]
3233    >>> strategy.experimental_local_results(result)
3234    ([PerReplica:{
3235      0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3236      1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
3237    }, PerReplica:{
3238      0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3239    array([[1, 2],
3240           [3, 4],
3241           [1, 2],
3242           [3, 4]], dtype=int32)>,
3243      1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3244    array([[1, 2],
3245           [3, 4],
3246           [1, 2],
3247           [3, 4]], dtype=int32)>
3248    }],)
3249
3250
3251    What if you are all-gathering tensors with different shapes on different
3252    replicas? Consider the following example with two replicas, where you have
3253    `value` as a nested structure consisting of two items to all-gather, `a` and
3254    `b`.
3255
3256      On Replica 0, `value` is `{'a': [0], 'b': [[0, 1]]}`.
3257
3258      On Replica 1, `value` is `{'a': [1], 'b': [[2, 3], [4, 5]]}`.
3259
3260      Result for `all_gather` with `axis`=0 (on each of the replicas) is:
3261
3262      ```{'a': [1, 2], 'b': [[0, 1], [2, 3], [4, 5]]}```
3263
3264    Args:
3265      value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts,
3266        or a `tf.distribute.DistributedValues` instance. The structure of the
3267        `tf.Tensor` need to be same on all replicas. The underlying tensor
3268        constructs can only be dense tensors with non-zero rank, NOT
3269        `tf.IndexedSlices`.
3270      axis: 0-D int32 Tensor. Dimension along which to gather.
3271      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
3272        perform collective operations. This overrides the default options if the
3273        `tf.distribute.Strategy` takes one in the constructor. See
3274        `tf.distribute.experimental.CommunicationOptions` for details of the
3275        options.
3276
3277    Returns:
3278       A nested structure of `tf.Tensor` with the gathered values. The structure
3279       is the same as `value`.
3280    """
3281    for v in nest.flatten(value):
3282      if isinstance(v, ops.IndexedSlices):
3283        raise NotImplementedError("all_gather does not support IndexedSlices")
3284
3285    if options is None:
3286      options = collective_util.Options()
3287
3288    def batch_all_gather(strategy, *value_flat):
3289      return strategy.extended._batch_gather_to(  # pylint: disable=protected-access
3290          [(v, _batch_reduce_destination(v)) for v in value_flat], axis,
3291          options)
3292
3293    @custom_gradient.custom_gradient
3294    def grad_wrapper(*xs):
3295      ys = self.merge_call(batch_all_gather, args=xs)
3296
3297      def grad(*dy_s):
3298        grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s)
3299        new_grads = []
3300        for i, grad in enumerate(grads):
3301          input_shape = array_ops.shape(xs[i])
3302          axis_dim = array_ops.reshape(input_shape[axis], [1])
3303          with ops.control_dependencies([array_ops.identity(grads)]):
3304            d = self.all_gather(axis_dim, axis=0)
3305            begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group])
3306            end_dim = begin_dim + array_ops.shape(xs[i])[axis]
3307            new_grad = array_ops.gather(
3308                grad, axis=axis, indices=math_ops.range(begin_dim, end_dim))
3309            new_grads.append(new_grad)
3310        return new_grads
3311
3312      return ys, grad
3313
3314    return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
3315
3316
3317@tf_export(v1=["distribute.ReplicaContext"])
3318class ReplicaContextV1(ReplicaContextBase):
3319  __doc__ = ReplicaContextBase.__doc__
3320
3321
3322def _batch_reduce_destination(x):
3323  """Returns the destinations for batch all-reduce."""
3324  if isinstance(x, ops.Tensor):
3325    # If this is a one device strategy.
3326    return x.device
3327  else:
3328    return x
3329
3330
3331# ------------------------------------------------------------------------------
3332
3333
3334_creating_default_strategy_singleton = False
3335
3336
3337class _DefaultDistributionStrategyV1(StrategyV1):
3338  """Default `tf.distribute.Strategy` if none is explicitly selected."""
3339
3340  def __init__(self):
3341    if not _creating_default_strategy_singleton:
3342      raise RuntimeError("Should only create a single instance of "
3343                         "_DefaultDistributionStrategy")
3344    super(_DefaultDistributionStrategyV1,
3345          self).__init__(_DefaultDistributionExtended(self))
3346
3347  def __deepcopy__(self, memo):
3348    del memo
3349    raise RuntimeError("Should only create a single instance of "
3350                       "_DefaultDistributionStrategy")
3351
3352
3353class _DefaultDistributionStrategy(Strategy):
3354  """Default `tf.distribute.Strategy` if none is explicitly selected."""
3355
3356  def __init__(self):
3357    if not _creating_default_strategy_singleton:
3358      raise RuntimeError("Should only create a single instance of "
3359                         "_DefaultDistributionStrategy")
3360    super(_DefaultDistributionStrategy, self).__init__(
3361        _DefaultDistributionExtended(self))
3362
3363  def __deepcopy__(self, memo):
3364    del memo
3365    raise RuntimeError("Should only create a single instance of "
3366                       "_DefaultDistributionStrategy")
3367
3368
3369class _DefaultDistributionContext(object):
3370  """Context manager setting the default `tf.distribute.Strategy`."""
3371
3372  __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"]
3373
3374  def __init__(self, strategy):
3375
3376    def creator(next_creator, **kwargs):
3377      _require_strategy_scope_strategy(strategy)
3378      return next_creator(**kwargs)
3379
3380    self._var_creator_scope = variable_scope.variable_creator_scope(creator)
3381    self._strategy = strategy
3382    self._nested_count = 0
3383
3384  def __enter__(self):
3385    # Allow this scope to be entered if this strategy is already in scope.
3386    if distribution_strategy_context.has_strategy():
3387      raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
3388    if self._nested_count == 0:
3389      self._var_creator_scope.__enter__()
3390    self._nested_count += 1
3391    return self._strategy
3392
3393  def __exit__(self, exception_type, exception_value, traceback):
3394    self._nested_count -= 1
3395    if self._nested_count == 0:
3396      try:
3397        self._var_creator_scope.__exit__(
3398            exception_type, exception_value, traceback)
3399      except RuntimeError as e:
3400        six.raise_from(
3401            RuntimeError("Variable creator scope nesting error: move call to "
3402                         "tf.distribute.set_strategy() out of `with` scope."),
3403            e)
3404
3405
3406class _DefaultDistributionExtended(StrategyExtendedV1):
3407  """Implementation of _DefaultDistributionStrategy."""
3408
3409  def __init__(self, container_strategy):
3410    super(_DefaultDistributionExtended, self).__init__(container_strategy)
3411    self._retrace_functions_for_each_device = False
3412
3413  def _scope(self, strategy):
3414    """Context manager setting a variable creator and `self` as current."""
3415    return _DefaultDistributionContext(strategy)
3416
3417  def colocate_vars_with(self, colocate_with_variable):
3418    """Does not require `self.scope`."""
3419    _require_strategy_scope_extended(self)
3420    return ops.colocate_with(colocate_with_variable)
3421
3422  def variable_created_in_scope(self, v):
3423    return v._distribute_strategy is None  # pylint: disable=protected-access
3424
3425  def _experimental_distribute_dataset(self, dataset, options):
3426    return dataset
3427
3428  def _distribute_datasets_from_function(self, dataset_fn, options):
3429    return dataset_fn(InputContext())
3430
3431  def _experimental_distribute_values_from_function(self, value_fn):
3432    return value_fn(ValueContext())
3433
3434  def _make_dataset_iterator(self, dataset):
3435    return _DefaultDistributionExtended.DefaultInputIterator(dataset)
3436
3437  def _make_input_fn_iterator(self,
3438                              input_fn,
3439                              replication_mode=InputReplicationMode.PER_WORKER):
3440    dataset = input_fn(InputContext())
3441    return _DefaultDistributionExtended.DefaultInputIterator(dataset)
3442
3443  def _experimental_make_numpy_dataset(self, numpy_input, session):
3444    numpy_flat = nest.flatten(numpy_input)
3445    vars_flat = tuple(
3446        variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
3447                                trainable=False, use_resource=True)
3448        for i in numpy_flat
3449    )
3450    for v, i in zip(vars_flat, numpy_flat):
3451      numpy_dataset.init_var_from_numpy(v, i, session)
3452    vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
3453    return dataset_ops.Dataset.from_tensor_slices(vars_nested)
3454
3455  def _broadcast_to(self, tensor, destinations):
3456    if destinations is None:
3457      return tensor
3458    else:
3459      raise NotImplementedError("TODO")
3460
3461  def _call_for_each_replica(self, fn, args, kwargs):
3462    with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0):
3463      return fn(*args, **kwargs)
3464
3465  def _reduce_to(self, reduce_op, value, destinations, options):
3466    # TODO(josh11b): Use destinations?
3467    del reduce_op, destinations, options
3468    return value
3469
3470  def _gather_to_implementation(self, value, destinations, axis, options):
3471    del destinations, axis, options
3472    return value
3473
3474  def _update(self, var, fn, args, kwargs, group):
3475    # The implementations of _update() and _update_non_slot() are identical
3476    # except _update() passes `var` as the first argument to `fn()`.
3477    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
3478
3479  def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group):
3480    # TODO(josh11b): Figure out what we should be passing to UpdateContext()
3481    # once that value is used for something.
3482    with UpdateContext(colocate_with):
3483      result = fn(*args, **kwargs)
3484      if should_group:
3485        return result
3486      else:
3487        return nest.map_structure(self._local_results, result)
3488
3489  def read_var(self, replica_local_var):
3490    return array_ops.identity(replica_local_var)
3491
3492  def _local_results(self, distributed_value):
3493    return (distributed_value,)
3494
3495  def value_container(self, value):
3496    return value
3497
3498  @property
3499  def _num_replicas_in_sync(self):
3500    return 1
3501
3502  @property
3503  def worker_devices(self):
3504    raise RuntimeError("worker_devices() method unsupported by default "
3505                       "tf.distribute.Strategy.")
3506
3507  @property
3508  def parameter_devices(self):
3509    raise RuntimeError("parameter_devices() method unsupported by default "
3510                       "tf.distribute.Strategy.")
3511
3512  def non_slot_devices(self, var_list):
3513    return min(var_list, key=lambda x: x.name)
3514
3515  def _in_multi_worker_mode(self):
3516    """Whether this strategy indicates working in multi-worker settings."""
3517    # Default strategy doesn't indicate multi-worker training.
3518    return False
3519
3520  @property
3521  def should_checkpoint(self):
3522    return True
3523
3524  @property
3525  def should_save_summary(self):
3526    return True
3527
3528  def _get_local_replica_id(self, replica_id_in_sync_group):
3529    return replica_id_in_sync_group
3530
3531  def _get_replica_id_in_sync_group(self, replica_id):
3532    return replica_id
3533
3534  # TODO(priyag): This should inherit from `InputIterator`, once dependency
3535  # issues have been resolved.
3536  class DefaultInputIterator(object):
3537    """Default implementation of `InputIterator` for default strategy."""
3538
3539    def __init__(self, dataset):
3540      self._dataset = dataset
3541      if eager_context.executing_eagerly():
3542        self._iterator = dataset_ops.make_one_shot_iterator(dataset)
3543      else:
3544        self._iterator = dataset_ops.make_initializable_iterator(dataset)
3545
3546    def get_next(self):
3547      return self._iterator.get_next()
3548
3549    def get_next_as_optional(self):
3550      return self._iterator.get_next_as_optional()
3551
3552    @deprecated(None, "Use the iterator's `initializer` property instead.")
3553    def initialize(self):
3554      """Initialize underlying iterators.
3555
3556      Returns:
3557        A list of any initializer ops that should be run.
3558      """
3559      if eager_context.executing_eagerly():
3560        self._iterator = self._dataset.make_one_shot_iterator()
3561        return []
3562      else:
3563        return [self._iterator.initializer]
3564
3565    @property
3566    def initializer(self):
3567      """Returns a list of ops that initialize the iterator."""
3568      return self.initialize()
3569
3570  # TODO(priyag): Delete this once all strategies use global batch size.
3571  @property
3572  def _global_batch_size(self):
3573    """Global and per-replica batching are equivalent for this strategy."""
3574    return True
3575
3576
3577class _DefaultReplicaContext(ReplicaContext):
3578  """ReplicaContext for _DefaultDistributionStrategy."""
3579
3580  @property
3581  def replica_id_in_sync_group(self):
3582    # Return 0 instead of a constant tensor to avoid creating a new node for
3583    # users who don't use distribution strategy.
3584    return 0
3585
3586
3587# ------------------------------------------------------------------------------
3588# We haven't yet implemented deserialization for DistributedVariables.
3589# So here we catch any attempts to deserialize variables
3590# when using distribution strategies.
3591# pylint: disable=protected-access
3592_original_from_proto = resource_variable_ops._from_proto_fn
3593
3594
3595def _from_proto_fn(v, import_scope=None):
3596  if distribution_strategy_context.has_strategy():
3597    raise NotImplementedError(
3598        "Deserialization of variables is not yet supported when using a "
3599        "tf.distribute.Strategy.")
3600  else:
3601    return _original_from_proto(v, import_scope=import_scope)
3602
3603resource_variable_ops._from_proto_fn = _from_proto_fn
3604# pylint: enable=protected-access
3605
3606
3607#-------------------------------------------------------------------------------
3608# Shorthand for some methods from distribution_strategy_context.
3609_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode  # pylint: disable=protected-access
3610_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode  # pylint: disable=protected-access
3611_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode  # pylint: disable=protected-access
3612_get_default_replica_mode = (
3613    distribution_strategy_context._get_default_replica_mode)  # pylint: disable=protected-access
3614
3615
3616# ------------------------------------------------------------------------------
3617# Metrics to track which distribution strategy is being called
3618distribution_strategy_gauge = monitoring.StringGauge(
3619    "/tensorflow/api/distribution_strategy",
3620    "Gauge to track the type of distribution strategy used.", "TFVersion")
3621distribution_strategy_replica_gauge = monitoring.IntGauge(
3622    "/tensorflow/api/distribution_strategy/replica",
3623    "Gauge to track the number of replica each distribution strategy used.",
3624    "CountType")
3625