1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various classes representing distributed values."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import weakref
23
24from tensorflow.python.distribute import device_util
25from tensorflow.python.distribute import distribute_lib
26from tensorflow.python.distribute import distribution_strategy_context as ds_context
27from tensorflow.python.distribute import packed_distributed_variable as packed
28from tensorflow.python.distribute import reduce_util
29from tensorflow.python.distribute import values_util
30from tensorflow.python.eager import context
31from tensorflow.python.framework import composite_tensor
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import type_spec
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import variable_scope as vs
38from tensorflow.python.ops import variables as variables_lib
39from tensorflow.python.saved_model import save_context
40from tensorflow.python.training.saving import saveable_object
41from tensorflow.python.training.tracking import base as trackable
42from tensorflow.python.types import core
43from tensorflow.python.util.tf_export import tf_export
44
45
def _on_write_update_replica(var, update_fn, value, **kwargs):
  """Applies `update_fn` to an ON_WRITE variable from a replica context.

  Args:
    var: The distributed variable to update. Its `aggregation` decides whether
      the update stays local or goes through a merge call.
    update_fn: Callable invoked as `update_fn(component, value, **kwargs)` when
      no aggregation is configured; otherwise it is forwarded to
      `var._update_cross_replica`.
    value: The value passed to the update.
    **kwargs: Extra keyword arguments forwarded to the update.

  Returns:
    The result of `update_fn` on the local component when aggregation is NONE,
    otherwise the result of the replica context's `merge_call`.

  Raises:
    ValueError: Inside the merge call, if the aggregation is MEAN, the variable
      dtype is not floating and `value` is a `PerReplica`.
  """
  if var.aggregation == vs.VariableAggregation.NONE:
    # No aggregation requested: update only the component for this replica.
    local_component = var._get_on_device_or_primary()  # pylint: disable=protected-access
    return update_fn(local_component, value, **kwargs)

  def merge_fn(strategy, value, **kwargs):
    """Aggregates `value` and updates every component in cross-replica mode."""
    # MEAN on a non-float dtype would silently lose precision (Python 3 and
    # NumPy upcast integers to float on division, but the variable dtype must
    # be preserved), so it is rejected.
    #
    # For backward compatibility the update is still allowed when the value is
    # identical on every replica, i.e. it is not a PerReplica. See regroup()
    # for how values are grouped.
    uses_mean = var.aggregation == vs.VariableAggregation.MEAN
    if uses_mean and not var.dtype.is_floating and isinstance(
        value, PerReplica):
      raise ValueError(
          "Cannot update non-float variables with "
          "tf.VariableAggregation.MEAN aggregation in replica context. "
          "Either change the variable dtype to float or update it in "
          "cross-replica context.")

    assert strategy == var.distribute_strategy
    aggregated = values_util.apply_aggregation(strategy, value,
                                               var.aggregation, var)
    return var._update_cross_replica(update_fn, aggregated, **kwargs)  # pylint: disable=protected-access

  replica_context = ds_context.get_replica_context()
  return replica_context.merge_call(merge_fn, args=(value,), kwargs=kwargs)
74
75
@tf_export("distribute.DistributedValues", v1=[])
class DistributedValues(object):
  """Base class for representing distributed values.

  Instances of `tf.distribute.DistributedValues` subclasses are produced when
  variables are created under a distribution strategy, when iterating over a
  `tf.distribute.DistributedDataset`, or as the result of
  `tf.distribute.Strategy.run`. This base class should never be instantiated
  directly. A `tf.distribute.DistributedValues` holds one value per replica;
  depending on the subclass those values are synced on update, synced on
  demand, or never synced.

  A `tf.distribute.DistributedValues` can be reduced to a single value across
  replicas, passed as input to `tf.distribute.Strategy.run`, or have its
  per-replica values inspected with
  `tf.distribute.Strategy.experimental_local_results`.

  Example usage:

  1. Created from a `tf.distribute.DistributedDataset`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)

  2. Returned by `run`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> @tf.function
  ... def run():
  ...   ctx = tf.distribute.get_replica_context()
  ...   return ctx.replica_id_in_sync_group
  >>> distributed_values = strategy.run(run)

  3. As input into `run`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)
  >>> @tf.function
  ... def run(input):
  ...   return input + 1.0
  >>> updated_value = strategy.run(run, args=(distributed_values,))

  4. Reduce value:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)
  >>> reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
  ...                                 distributed_values,
  ...                                 axis = 0)

  5. Inspect local replica values:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)
  >>> per_replica_values = strategy.experimental_local_results(
  ...    distributed_values)
  >>> per_replica_values
  (<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
   <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)

  """

  def __init__(self, values):
    """Stores `values` as a tuple; should only be called by subclass __init__."""
    self._values = tuple(values)

  def _get(self):
    """Returns the value for the current replica, or the cross-replica value.

    When no replica id is available the lookup is delegated to
    `_get_cross_replica`, which raises `NotImplementedError` in this base
    class.
    """
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is not None:
      return self._values[replica_id]
    return self._get_cross_replica()

  def _get_cross_replica(self):
    """Hook for subclasses that support access from cross-replica context."""
    raise NotImplementedError(
        "This method should be overridden by sub-classes which support cross-"
        "replica accesses.")

  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is not None:
      return self._values[replica_id]

    # No replica id: prefer a component placed on the current device.
    current_device = device_util.canonicalize(device_util.current())
    for candidate in self._values:
      if device_util.canonicalize(candidate.device) == current_device:
        return candidate
    return self._primary

  @property
  def _primary(self):
    """Returns a representative component (the first stored value)."""
    return self._values[0]

  @property
  def _devices(self):
    """Returns the `device` of every stored value, in order, as a tuple."""
    return tuple(component.device for component in self._values)

  def __str__(self):
    """Formats every value with `%s`, one per line, prefixed by its index."""
    entries = ["  %d: %s" % (index, component)
               for index, component in enumerate(self._values)]
    return "%s:{\n%s\n}" % (self.__class__.__name__, ",\n".join(entries))

  def __repr__(self):
    """Formats every value with `%r`, one per line, prefixed by its index."""
    entries = ["  %d: %r" % (index, component)
               for index, component in enumerate(self._values)]
    return "%s:{\n%s\n}" % (self.__class__.__name__, ",\n".join(entries))
192
193
194# NOTE(josh11b,apassos): It would be great if we could inspect the values this was
195# initialized with and use that to generate the overloaded operators here.
196# Unfortunately, Python's rules for special methods don't allow this, see
197# https://docs.python.org/3/reference/datamodel.html#special-method-names
198# "if a class defines a method named __getitem__(), and x is an instance of
199# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)."
200# In particular, these special methods don't go through __getattr__, and
201# it will only use those methods if they are defined in the class, not the
202# object.
class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values.

  Attribute lookups that fail on the container itself, and the operator
  special methods defined below, are forwarded to the component returned by
  `_get()` / `_get_as_operand()`.
  """

  def __getattr__(self, name):
    """Forwards a failed attribute lookup to the current component value.

    Args:
      name: Name of the attribute that normal lookup did not find.

    Returns:
      The attribute `name` of the value returned by `self._get()`.

    Raises:
      AttributeError: If `name` is one of the names that must never be
        delegated (see the comments below) and no class later in the MRO
        provides a `__getattr__` that resolves it.
    """
    # The '_use_resource_variables' and the attrs starts with '_self' are used
    # for restoring the saved_model proto, and '_attribute_sentinel' is used for
    # Layer tracking. At the point these attrs are queried, the variable has not
    # been initialized. Thus it should not query those of the underlying
    # components.
    if name.startswith("_self_") or name in ("_use_resource_variables",
                                             "_attribute_sentinel",
                                             "_distributed_container"):
      # Stay cooperative with any class later in the MRO that defines
      # `__getattr__`. None of the bases visible here does, so calling
      # `super().__getattr__` unconditionally failed with the misleading
      # "'super' object has no attribute '__getattr__'".
      parent_getattr = getattr(
          super(DistributedDelegate, self), "__getattr__", None)
      if parent_getattr is not None:
        return parent_getattr(name)
      raise AttributeError(
          "%r object has no attribute %r" % (type(self).__name__, name))

    # This allows copy.copy(DistributedDelegate). When copying an object,
    # copy.copy doesn't invoke its __init__ method, instead it makes a new
    # empty object, then copies the attributes over. copy.copy looks for
    # attributes like "__getstate__" in case the object implements its custom
    # copying. Since DistributedDelegate doesn't have those attributes defined,
    # __getattr__ will be invoked, which tries to access "_values" attributes,
    # but that doesn't exist either because this is an empty object, and again
    # __getattr__ is invoked, leading to an infinite recursion.
    if name == "_values":
      raise AttributeError(
          "%r object has no attribute %r" % (type(self).__name__, name))

    # TODO(priyag): This needs to be made robust against pitfalls from mix use
    # __getattr__ and @property. See b/120402273.
    return getattr(self._get(), name)

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

  def _get_as_operand(self):
    """Returns the value for operations for the current device.

    Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
    value type within a replica context. They can, however, return a value that
    can be used by the operations below.
    """
    return self._get()

  # Operator overloads. Each one applies the Python operator to the value
  # returned by `_get_as_operand()`; they must live on the class because
  # special methods are not looked up through `__getattr__`.
  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._get_as_operand() + o

  def __radd__(self, o):
    return o + self._get_as_operand()

  def __sub__(self, o):
    return self._get_as_operand() - o

  def __rsub__(self, o):
    return o - self._get_as_operand()

  def __mul__(self, o):
    return self._get_as_operand() * o

  def __rmul__(self, o):
    return o * self._get_as_operand()

  def __truediv__(self, o):
    return self._get_as_operand() / o

  def __rtruediv__(self, o):
    return o / self._get_as_operand()

  def __floordiv__(self, o):
    return self._get_as_operand() // o

  def __rfloordiv__(self, o):
    return o // self._get_as_operand()

  def __mod__(self, o):
    return self._get_as_operand() % o

  def __rmod__(self, o):
    return o % self._get_as_operand()

  def __lt__(self, o):
    return self._get_as_operand() < o

  def __le__(self, o):
    return self._get_as_operand() <= o

  def __gt__(self, o):
    return self._get_as_operand() > o

  def __ge__(self, o):
    return self._get_as_operand() >= o

  def __and__(self, o):
    return self._get_as_operand() & o

  def __rand__(self, o):
    return o & self._get_as_operand()

  def __or__(self, o):
    return self._get_as_operand() | o

  def __ror__(self, o):
    return o | self._get_as_operand()

  def __xor__(self, o):
    return self._get_as_operand() ^ o

  def __rxor__(self, o):
    return o ^ self._get_as_operand()

  def __getitem__(self, o):
    return self._get_as_operand()[o]

  def __pow__(self, o, modulo=None):
    return pow(self._get_as_operand(), o, modulo)

  def __rpow__(self, o):
    return pow(o, self._get_as_operand())

  def __invert__(self):
    return ~self._get_as_operand()

  def __neg__(self):
    return -self._get_as_operand()

  def __abs__(self):
    return abs(self._get_as_operand())

  # The operand may not implement the special methods called explicitly below;
  # in that case NotImplemented is returned so Python can try the other
  # operand.
  def __div__(self, o):
    try:
      return self._get_as_operand().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._get_as_operand().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._get_as_operand().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._get_as_operand().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented
358
359  # TODO(josh11b): Even more operator overloads.
360
361
class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
  """Holds a map from replica to unsynchronized values."""

  @property
  def _type_spec(self):
    """Builds a `PerReplicaSpec` from the type spec of each stored value."""
    component_specs = [
        type_spec.type_spec_from_value(component) for component in self._values
    ]
    return PerReplicaSpec(*component_specs)

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values
374
375
class PerReplicaSpec(type_spec.TypeSpec):
  """Type specification for a `PerReplica`."""

  __slots__ = ["_value_specs"]

  @property
  def value_type(self):
    """The Python type this spec describes (`PerReplica`)."""
    return PerReplica

  def __init__(self, *value_specs):
    """Stores the given per-value specs as a tuple."""
    self._value_specs = tuple(value_specs)

  def _serialize(self):
    """Returns the tuple of per-value specs."""
    return self._value_specs

  @property
  def _component_specs(self):
    """Returns the tuple of per-value specs."""
    return self._value_specs

  def _to_components(self, value):
    """Returns the values stored in `value`.

    Raises:
      ValueError: If called in a replica context with more than one replica in
        sync.
    """
    replica_context = ds_context.get_replica_context()
    multi_replica = (
        replica_context is not None and
        replica_context.num_replicas_in_sync > 1)
    if multi_replica:
      raise ValueError(
          "Flattening a PerReplica to components is not supported in replica "
          "context.")
    return value._values  # pylint: disable=protected-access

  def _from_components(self, tensor_list):
    """Wraps `tensor_list` in a new `PerReplica`."""
    return PerReplica(tensor_list)
403
404
405# Note that unlike PerReplica, Mirrored values inherit from
406# DistributedDelegate and so can be used directly in cross-replica mode.
407# TODO(tomhennigan) Should this extend CompositeTensor?
class Mirrored(DistributedDelegate):
  """Holds a map from replica to values which are kept in sync."""

  def _get_cross_replica(self):
    """Returns the value on the current device, else the primary value."""
    return self._get_on_device_or_primary()

  def _as_graph_element(self):
    """Returns the current value converted to a graph element when possible.

    If the current value exposes a callable `_as_graph_element`, its result is
    returned; otherwise the value itself is returned.
    """
    current = self._get()
    converter = getattr(current, "_as_graph_element", None)
    if not converter or not callable(converter):
      return current
    return converter()
420
421
class DistributedVarOp(object):
  """A class that looks like `tf.Operation`."""

  def __init__(self, name, graph, traceback, typ):
    """Records the op-like fields.

    Args:
      name: Stored as `self.name`.
      graph: Stored as `self.graph`.
      traceback: Stored as `self.traceback`; must be iterable with hashable
        items for `__hash__` to work.
      typ: Stored as `self.type`.
    """
    self.name = name
    self.graph = graph
    self.traceback = traceback
    self.type = typ

  def __eq__(self, o):
    """Compares all four fields; defers to Python for unrelated types."""
    if not isinstance(o, self.__class__):
      # Return the NotImplemented sentinel instead of raising
      # NotImplementedError, so that comparisons against unrelated objects
      # (e.g. `op == None`) evaluate to False rather than blowing up.
      return NotImplemented
    return (self.name == o.name and self.graph == o.graph and
            self.traceback == o.traceback and self.type == o.type)

  def __hash__(self):
    """Hashes the same four fields that `__eq__` compares."""
    return hash((self.name, self.graph, tuple(self.traceback), self.type))
439
440
441class DistributedVariable(DistributedDelegate, variables_lib.Variable,
442                          core.Tensor):
443  """Holds a map from replica to variables."""
444
445  def __init__(self, strategy, values, aggregation, var_policy=None):
446    if (aggregation == variables_lib.VariableAggregation.MEAN and
447        not values[0].dtype.is_floating):
448      raise ValueError(
449          "creating distributed tf.Variable with aggregation=MEAN and a "
450          "non-floating dtype is not supported, please use a different "
451          "aggregation or dtype")
452    self._distribute_strategy = strategy
453    self._aggregation = aggregation
454    super(DistributedVariable, self).__init__(values)
455    self._common_name = self._primary.name.split(":")[0]
456    # Use a weakref to make it easy to map from the contained values
457    # to the container without introducing a reference cycle.
458    for v in values:
459      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access
460
461    # Packed variable is used to reduce the overhead of function execution.
462    # For a DistributedVariable, only one variable handle is captured into a
463    # function graph. It's only supported in eager mode.
464    if ops.executing_eagerly_outside_functions() and getattr(
465        strategy, "_enable_packed_variable_in_eager_mode", False):
466      name = "%s/packed/" % self._common_name
467      self._packed_var = packed.PackedDistributedVariable(values, name=name)
468    else:
469      self._packed_var = None
470
471    # tf.keras keeps track of variables initialized using this attribute. When
472    # tf.keras gets the default session, it initializes all uninitialized vars.
473    # We need to make _keras_initialized a member of DistributedVariable because
474    # without this it will use `__getattr__` which will delegate to a component
475    # variable.
476    self._keras_initialized = False
477    # Typically, a `DistributedVariable`'s initializer is composed of the
478    # initializers of the components variables. However, in some cases, such as
479    # when restoring from a checkpoint, we may set the _initializer_op
480    # property on the entire `DistributedVariable`.
481    self._initializer_op = None
482    # Set a VariablePolicy which decides how we replicate/aggregate the given
483    # variable.
484    self._policy = var_policy
485
486  def __deepcopy__(self, memo):
487    """Perform a deepcopy of the `DistributedVariable`.
488
489    Unlike the deepcopy of a regular tf.Variable, this keeps the original
490    strategy and devices of the `DistributedVariable`.  To avoid confusion
491    with the behavior of deepcopy on a regular `Variable` (which does
492    copy into new devices), we only allow a deepcopy of a `DistributedVariable`
493    within its originating strategy scope.
494
495    Args:
496      memo: The memoization object for `deepcopy`.
497
498    Returns:
499      A deep copy of the current `DistributedVariable`.
500
501    Raises:
502      RuntimeError: If trying to deepcopy into a different strategy.
503    """
504    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
505      new_values = []
506
507      for value in self._values:
508        with ops.device(value.device):
509          new_values.append(copy.deepcopy(value, memo))
510
511    copied_variable = type(self)(
512        strategy=self._distribute_strategy,
513        values=new_values,
514        aggregation=self._aggregation,
515        var_policy=copy.deepcopy(self._policy, memo))
516
517    memo[id(self)] = copied_variable
518
519    return copied_variable
520
521  def _use_packed_variable(self):
522    # Don't use packed variable when under a SaveContext to avoid explicit
523    # device placement on variable consuming ops.
524    return self._packed_var is not None and not save_context.in_save_context()
525
526  def is_initialized(self, name=None):
527    """Identifies if all the component variables are initialized.
528
529    Args:
530      name: Name of the final `logical_and` op.
531
532    Returns:
533      The op that evaluates to True or False depending on if all the
534      component variables are initialized.
535    """
536    if values_util.is_saving_non_distributed():
537      return self._primary.is_initialized()
538    if self._use_packed_variable():
539      return self._packed_var.is_initialized()
540    result = self._primary.is_initialized()
541    # We iterate through the list of values except the last one to allow us to
542    # name the final `logical_and` op the same name that is passed by the user
543    # to the `is_initialized` op. For distributed variables, the
544    # `is_initialized` op is a `logical_and` op.
545    for v in self._values[1:-1]:
546      result = math_ops.logical_and(result, v.is_initialized())
547    result = math_ops.logical_and(
548        result, self._values[-1].is_initialized(), name=name)
549    return result
550
551  @property
552  def initializer(self):
553    if values_util.is_saving_non_distributed():
554      return self._primary.initializer
555    if self._initializer_op:
556      init_op = self._initializer_op
557    else:
558      # return grouped ops of all the var initializations of component values of
559      # the mirrored variable
560      init_op = control_flow_ops.group(
561          tuple(v.initializer for v in self._values))
562    return init_op
563
564  def initialized_value(self):
565    return self._get_on_device_or_primary().initialized_value()
566
567  @property
568  def initial_value(self):
569    return self._get_on_device_or_primary().initial_value
570
571  @property
572  def constraint(self):
573    return self._primary.constraint
574
575  @property
576  def graph(self):
577    return self._primary.graph
578
579  @property
580  def _shared_name(self):
581    return self._common_name
582
583  @property
584  def _unique_id(self):
585    return self._primary._unique_id  # pylint: disable=protected-access
586
587  @property
588  def _graph_key(self):
589    """Lets Optimizers know which graph this variable is from."""
590    return self._primary._graph_key  # pylint: disable=protected-access
591
592  @property
593  def name(self):
594    return self._primary.name
595
596  @property
597  def dtype(self):
598    return self._primary.dtype
599
600  @property
601  def shape(self):
602    return self._primary.shape
603
604  @property
605  def synchronization(self):
606    return self._primary.synchronization
607
608  @property
609  def aggregation(self):
610    return self._aggregation
611
612  @property
613  def _packed_variable(self):
614    if self._use_packed_variable():
615      return self._packed_var
616    return None
617
618  @property
619  def handle(self):
620    if values_util.is_saving_non_distributed():
621      return self._primary.handle
622    replica_id = values_util.get_current_replica_id_as_int()
623    if replica_id is None:
624      raise ValueError("`handle` is not available outside the replica context"
625                       " or a `tf.distribute.Strategy.update()` call.")
626    else:
627      if self._use_packed_variable():
628        return self._packed_var.handle
629      return self._values[replica_id].handle
630
631  def eval(self, session=None):
632    return self._get_on_device_or_primary().eval(session)
633
634  @property
635  def _save_slice_info(self):
636    return self._primary._save_slice_info  # pylint: disable=protected-access
637
638  def _get_save_slice_info(self):
639    return self._primary._get_save_slice_info()  # pylint: disable=protected-access
640
641  def _set_save_slice_info(self, save_slice_info):
642    for v in self._values:
643      v._set_save_slice_info(save_slice_info)  # pylint: disable=protected-access
644
645  @property
646  def device(self):
647    return self._get_on_device_or_primary().device
648
649  @property
650  def trainable(self):
651    return self._primary.trainable
652
653  @property
654  def distribute_strategy(self):
655    return self._distribute_strategy
656
657  def get_shape(self):
658    return self._primary.get_shape()
659
660  def to_proto(self, export_scope=None):
661    return self._primary.to_proto(export_scope=export_scope)
662
663  @property
664  def op(self):
665    if values_util.is_saving_non_distributed():
666      return self._primary.op
667    # We want cross-replica code that does some var.op.X calls
668    # to work (even if the current device isn't in self._devices), but
669    # other uses of var.op in a cross-replica context to fail.
670    if ds_context.in_cross_replica_context():
671      return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
672                              self._primary.op.traceback, self._primary.op.type)
673    return self._get().op
674
675  @property
676  def _in_graph_mode(self):
677    return self._primary._in_graph_mode  # pylint: disable=protected-access
678
679  def _get_replica(self, replica_id):
680    """Returns the value on a device with the given replica_id."""
681    if self._use_packed_variable():
682      return self._packed_var.on_device(self._devices[replica_id])
683    return self._values[replica_id]
684
685  def _get(self):
686    """Returns the value for the current device or raises a ValueError."""
687    if values_util.is_saving_non_distributed():
688      return self._primary
689    replica_id = values_util.get_current_replica_id_as_int()
690    if replica_id is None:
691      return self._get_cross_replica()
692    else:
693      return self._get_replica(replica_id)
694
695  def _get_on_device_or_primary(self):
696    """Returns value in same replica or device if possible, else the _primary."""
697    if values_util.is_saving_non_distributed():
698      return self._primary
699    replica_id = values_util.get_current_replica_id_as_int()
700    if replica_id is None:
701      # Try to find a value on the current device.
702      current_device = device_util.canonicalize(device_util.current())
703      for i, value in enumerate(self._values):
704        if device_util.canonicalize(value.device) == current_device:
705          return self._get_replica(i)
706      return self._get_replica(0)
707    else:
708      return self._get_replica(replica_id)
709
710  def read_value(self):
711    if values_util.is_saving_non_distributed():
712      return self._primary.read_value()
713    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
714      return array_ops.identity(self._get())
715
716  def value(self):
717    if values_util.is_saving_non_distributed():
718      return self._primary.value()
719    if self._policy:
720      return self._policy.value(self)
721    return self._get_on_device_or_primary().value()
722
723  def numpy(self):
724    if context.executing_eagerly():
725      return self.read_value().numpy()
726    else:
727      raise NotImplementedError(
728          "numpy() is only available when eager execution is enabled.")
729
730  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
731    if values_util.is_saving_non_distributed():
732      return self._primary.assign_sub(value, use_locking, name, read_value)
733    if self._policy:
734      return self._policy.assign_sub(
735          self,
736          value,
737          use_locking=use_locking,
738          name=name,
739          read_value=read_value)
740    return values_util.on_write_assign_sub(
741        self, value, use_locking=use_locking, name=name, read_value=read_value)
742
743  def assign_add(self, value, use_locking=False, name=None, read_value=True):
744    if values_util.is_saving_non_distributed():
745      return self._primary.assign_add(value, use_locking, name, read_value)
746    if self._policy:
747      return self._policy.assign_add(
748          self,
749          value,
750          use_locking=use_locking,
751          name=name,
752          read_value=read_value)
753    return values_util.on_write_assign_add(
754        self, value, use_locking=use_locking, name=name, read_value=read_value)
755
756  def assign(self, value, use_locking=False, name=None, read_value=True):
757    if values_util.is_saving_non_distributed():
758      return self._primary.assign(value, use_locking, name, read_value)
759    if self._policy:
760      return self._policy.assign(
761          self,
762          value,
763          use_locking=use_locking,
764          name=name,
765          read_value=read_value)
766    return values_util.on_write_assign(
767        self, value, use_locking=use_locking, name=name, read_value=read_value)
768
769  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
770    if values_util.is_saving_non_distributed():
771      return self._primary.scatter_sub(sparse_delta, use_locking, name)
772    if self._policy:
773      return self._policy.scatter_sub(
774          self, sparse_delta, use_locking=use_locking, name=name)
775    return values_util.scatter_sub(
776        self, sparse_delta, use_locking=use_locking, name=name)
777
778  def scatter_add(self, sparse_delta, use_locking=False, name=None):
779    if values_util.is_saving_non_distributed():
780      return self._primary.scatter_add(sparse_delta, use_locking, name)
781    if self._policy:
782      return self._policy.scatter_add(
783          self, sparse_delta, use_locking=use_locking, name=name)
784    return values_util.scatter_add(
785        self, sparse_delta, use_locking=use_locking, name=name)
786
787  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
788    if values_util.is_saving_non_distributed():
789      return self._primary.scatter_mul(sparse_delta, use_locking, name)
790    if self._policy:
791      return self._policy.scatter_mul(
792          self, sparse_delta, use_locking=use_locking, name=name)
793    return values_util.scatter_mul(
794        self, sparse_delta, use_locking=use_locking, name=name)
795
796  def scatter_div(self, sparse_delta, use_locking=False, name=None):
797    if values_util.is_saving_non_distributed():
798      return self._primary.scatter_div(sparse_delta, use_locking, name)
799    if self._policy:
800      return self._policy.scatter_div(
801          self, sparse_delta, use_locking=use_locking, name=name)
802    return values_util.scatter_div(
803        self, sparse_delta, use_locking=use_locking, name=name)
804
805  def scatter_min(self, sparse_delta, use_locking=False, name=None):
806    if values_util.is_saving_non_distributed():
807      return self._primary.scatter_min(sparse_delta, use_locking, name)
808    if self._policy:
809      return self._policy.scatter_min(
810          self, sparse_delta, use_locking=use_locking, name=name)
811    return values_util.scatter_min(
812        self, sparse_delta, use_locking=use_locking, name=name)
813
814  def scatter_max(self, sparse_delta, use_locking=False, name=None):
815    if values_util.is_saving_non_distributed():
816      return self._primary.scatter_max(sparse_delta, use_locking, name)
817    if self._policy:
818      return self._policy.scatter_max(
819          self, sparse_delta, use_locking=use_locking, name=name)
820    return values_util.scatter_max(
821        self, sparse_delta, use_locking=use_locking, name=name)
822
823  def scatter_update(self, sparse_delta, use_locking=False, name=None):
824    if values_util.is_saving_non_distributed():
825      return self._primary.scatter_update(sparse_delta, use_locking, name)
826    if self._policy:
827      return self._policy.scatter_update(
828          self, sparse_delta, use_locking=use_locking, name=name)
829    return values_util.scatter_update(
830        self, sparse_delta, use_locking=use_locking, name=name)
831
832  def _gather_saveables_for_checkpoint(self):
833    """Overrides Trackable method.
834
835    This allows both name-based and object-based save and restore of
836    DistributedVariables.
837
838    Returns:
839      A dictionary mapping attribute names to `SaveableObject` factories.
840    """
841
842    def _saveable_factory(name=self._common_name):
843      return _DistributedVariableSaveable(self, self._primary, name)
844
845    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
846
847  def _as_graph_element(self):
848    if values_util.is_saving_non_distributed():
849      return self._primary._as_graph_element()  # pylint: disable=protected-access
850    if self._policy:
851      return self._policy._as_graph_element(self)  # pylint: disable=protected-access
852
853    raise NotImplementedError("No policy set for calling _as_graph_element.")
854
855  def _get_cross_replica(self):
856    if values_util.is_saving_non_distributed():
857      return self._primary
858    if self._policy:
859      return self._policy._get_cross_replica(self)  # pylint: disable=protected-access
860
861    raise NotImplementedError(
862        "This method should be overridden by sub-classes which support cross-"
863        "replica accesses.")
864
865  def _update_cross_replica(self, update_fn, value, **kwargs):
866    """Applies updates across replicas.
867
868    Args:
869      update_fn: A callable to pass to `strategy.extended.update` to update the
870        variable. It should has the same signature as `Variable.assign()`.
871      value: value to be passed to `update_fn`.
872      **kwargs: remaining arguments to `update_fn`.
873
874    Returns:
875      Updated variable or `tf.Operation`.
876    """
877    values_util.mark_as_unsaveable()
878    return self.distribute_strategy.extended.update(
879        self, update_fn, args=(value,), kwargs=kwargs, group=True)
880
881  def _update_replica(self, update_fn, value, **kwargs):
882    """Applies updates in one replica.
883
884    Args:
885      update_fn: A callable to update the variable. It should has the same
886        signature as `Variable.assign()`.
887      value: value to be passed to `update_fn`.
888      **kwargs: remaining arguments to `update_fn`.
889
890    Returns:
891      Updated variable or `tf.Operation`.
892    """
893    if self._policy:
894      return self._policy._update_replica(self, update_fn, value, **kwargs)  # pylint: disable=protected-access
895    raise NotImplementedError("should be implemented by subclass.")
896
897  def _update(self, update_fn, value, **kwargs):
898    """Applies updates depending on the context.
899
900    The method calls `_update_replica` in replica context,
901    `_update_cross_replica` in cross replica context, and `update_fn` in update
902    context.
903
904    If `read_value` is True, the method returns the updated Variable. If
905    `read_value` is False, the method returns the update `tf.Operation`.
906
907    Args:
908      update_fn: A callable to pass to `strategy.extended.update` to update the
909        variable. It should have the same signature as `Variable.assign()`.
910      value: value to be passed to `update_fn`.
911      **kwargs: keyword arguments to `update_fn`.
912
913    Returns:
914      Updated variable or `tf.Operation`.
915
916    """
917    if values_util.is_saving_non_distributed():
918      return update_fn(self._primary, value, **kwargs)
919    with ds_context.enter_or_assert_strategy(self.distribute_strategy):
920      if ds_context.in_cross_replica_context():
921        update_replica_id = distribute_lib.get_update_replica_id()
922        if update_replica_id is not None:
923          replica_value = self._get_replica(update_replica_id)
924          return update_fn(replica_value, value, **kwargs)
925        return self._update_cross_replica(update_fn, value, **kwargs)
926      else:
927        values_util.assert_replica_context(self.distribute_strategy)
928        return self._update_replica(update_fn, value, **kwargs)
929
930  def _should_act_as_resource_variable(self):
931    """Pass resource_variable_ops.is_resource_variable check."""
932    pass
933
934  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
935    """Converts a variable to a tensor."""
936    if values_util.is_saving_non_distributed():
937      return ops.convert_to_tensor(
938          self._primary, dtype=dtype, name=name, as_ref=as_ref)
939    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
940      return ops.convert_to_tensor(
941          self._get(), dtype=dtype, name=name, as_ref=as_ref)
942
  def _map_resources(self, save_options):
    """For implementing `Trackable`.

    Builds the object and resource maps for this variable, starting from the
    maps produced by the primary component and then adding entries for the
    other components, for this object and for the packed variable (if any).

    Args:
      save_options: Save options; its `experimental_variable_policy` decides
        whether non-primary components are mapped individually or aliased to
        the primary component.

    Returns:
      A tuple `(obj_map, resource_map)`.
    """
    # Initialize for self._primary first, so that obj_map[self._primary] and
    # resource_map[self._primary.handle] contain mapped values.
    obj_map, resource_map = self._primary._map_resources(save_options)  # pylint:disable=protected-access
    # Visit every component other than the primary one.
    for v in [v for v in self._values if v != self._primary]:

      if (save_options.experimental_variable_policy  # pylint:disable=protected-access
          ._expand_distributed_variables()):
        # Expanded: each component contributes its own mapped entries.
        v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
        obj_map.update(v_obj_map)
        resource_map.update(v_resource_map)
      else:
        # Not expanded: the component shares the primary's mapped entries.
        obj_map[v] = obj_map[self._primary]
        resource_map[v.handle] = resource_map[self._primary.handle]
    # This object itself always resolves to the primary's mapped entries.
    obj_map[self] = obj_map[self._primary]
    resource_map[self] = resource_map[self._primary.handle]
    if self._packed_var is not None:
      # The packed handle also resolves to the primary's mapped handle.
      resource_map[self._packed_var.packed_handle] = resource_map[
          self._primary.handle]
    return obj_map, resource_map
964
965  def _write_object_proto(self, proto, options):
966    """Update a SavedObject proto for the caller.
967
968    If a DistributedVariable object supports this method, it will be called when
969    saving with a pre-built `SavedObject` proto representing the object, plus an
970    instance of `SaveOptions`. This method is then free to modify that proto
971    instance.
972
973    `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
974    write out information about their components to the
975    `experimental_distributed_variable_components` field of a
976    `SavedVariable` (depending on the `SaveOptions` variable policy).
977
978    Args:
979      proto: A pre-built `SavedObject` proto for this object. It is assumed this
980        will be a `SavedVariable` instance.
981      options: A `SaveOptions` instance.
982    """
983    if self._policy:
984      if self._policy._is_mirrored():  # pylint: disable=protected-access
985        self._policy._write_object_proto(self, proto, options)  # pylint: disable=protected-access
986    else:
987      self._write_object_proto(proto, options)
988
989
990# We extend from `saveable_object.SaveableObject` instead of
991# `saveable_object_util.ResourceVariableSaveable` since we need to read the
# value of ON_READ variables when saving. `SaveableObject` provides a way to
993# specify the function to run to get the value of the variable or tensor at
994# saving time. We can use this for both ON_READ and ON_WRITE variables.
995# TODO(b/164586507): Consolidate ON_WRITE and ON_READ saving/restoring logic
996# if possible.
class _DistributedVariableSaveable(saveable_object.SaveableObject):
  """Saveable that saves and restores a DistributedVariable via its policy."""

  def __init__(self, distributed_variable, primary_variable, name):
    """Builds the saveable from the variable's policy.

    Args:
      distributed_variable: The distributed variable being saved.
      primary_variable: Passed through to the policy's `get_saveable`.
      name: Name for the saveable; also passed to the policy's `get_saveable`.

    Raises:
      ValueError: If `distributed_variable` has no policy set.
    """
    self._distributed_variable = distributed_variable
    policy = self._distributed_variable._policy  # pylint: disable=protected-access
    if not policy:
      raise ValueError("VariablePolicy has not been set for the distributed "
                       "variable.")
    tensor, spec = policy.get_saveable(
        distributed_variable, primary_variable, name)
    super(_DistributedVariableSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    # Exactly one restored tensor is expected; unpacking enforces that.
    (tensor,) = restored_tensors
    policy = self._distributed_variable._policy  # pylint: disable=protected-access
    return policy.get_restore_ops(self._distributed_variable, tensor)
1014
1015
class _MirroredSaveable(saveable_object.SaveableObject):
  """Saveable that saves and restores a MirroredVariable."""

  def __init__(self, mirrored_variable, primary_variable, name):
    """Builds the saveable via `values_util.get_on_write_saveable`.

    Args:
      mirrored_variable: The mirrored variable being saved.
      primary_variable: Passed through to `get_on_write_saveable`.
      name: Name for the saveable; also passed to `get_on_write_saveable`.
    """
    self._mirrored_variable = mirrored_variable
    tensor, spec = values_util.get_on_write_saveable(
        self._mirrored_variable, primary_variable, name)
    super(_MirroredSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    # Exactly one restored tensor is expected; unpacking enforces that.
    (tensor,) = restored_tensors
    restore_ops = values_util.get_on_write_restore_ops(
        self._mirrored_variable, tensor)
    return restore_ops
1031
1032
class MirroredVariable(DistributedVariable, Mirrored):
  """Holds a map from replica to variables whose values are kept in sync."""

  def _update_replica(self, update_fn, value, **kwargs):
    """Applies a per-replica update via `_on_write_update_replica`."""
    return _on_write_update_replica(self, update_fn, value, **kwargs)

  def scatter_min(self, *args, **kwargs):
    """Runs `scatter_min`, rejecting unsupported aggregations.

    While saving in non-distributed mode the call goes to the primary
    component. Otherwise only `ONLY_FIRST_REPLICA` and `NONE` aggregations are
    accepted before deferring to the parent implementation.

    Raises:
      NotImplementedError: For any other aggregation.
    """
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    aggregation = self._aggregation
    unsupported = (
        aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        aggregation != vs.VariableAggregation.NONE)
    if unsupported:
      message = values_util.scatter_error_msg.format(
          op_name="scatter_min", aggregation=aggregation)
      raise NotImplementedError(message)
    return super(MirroredVariable, self).scatter_min(*args, **kwargs)

  def scatter_max(self, *args, **kwargs):
    """Runs `scatter_max`, rejecting unsupported aggregations.

    While saving in non-distributed mode the call goes to the primary
    component. Otherwise only `ONLY_FIRST_REPLICA` and `NONE` aggregations are
    accepted before deferring to the parent implementation.

    Raises:
      NotImplementedError: For any other aggregation.
    """
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    aggregation = self._aggregation
    unsupported = (
        aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        aggregation != vs.VariableAggregation.NONE)
    if unsupported:
      message = values_util.scatter_error_msg.format(
          op_name="scatter_max", aggregation=aggregation)
      raise NotImplementedError(message)
    return super(MirroredVariable, self).scatter_max(*args, **kwargs)

  def scatter_update(self, *args, **kwargs):
    """Runs `scatter_update`, rejecting unsupported aggregations.

    While saving in non-distributed mode the call goes to the primary
    component. Otherwise only `ONLY_FIRST_REPLICA` and `NONE` aggregations are
    accepted before deferring to the parent implementation.

    Raises:
      NotImplementedError: For any other aggregation.
    """
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    aggregation = self._aggregation
    unsupported = (
        aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        aggregation != vs.VariableAggregation.NONE)
    if unsupported:
      message = values_util.scatter_error_msg.format(
          op_name="scatter_update", aggregation=aggregation)
      raise NotImplementedError(message)
    return super(MirroredVariable, self).scatter_update(*args, **kwargs)

  def _get_cross_replica(self):
    """Returns an identity of the `Mirrored` cross-replica value."""
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    cross_replica_value = Mirrored._get_cross_replica(self)
    return array_ops.identity(cross_replica_value)

  def _as_graph_element(self):
    """Returns the graph element of the on-device (or primary) component."""
    component = self._get_on_device_or_primary()
    return component._as_graph_element()  # pylint: disable=protected-access

  def _gather_saveables_for_checkpoint(self):
    """Overrides the `Trackable` hook that lists this object's saveables.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    # The default for `name` is captured here, when the factory is defined.
    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary, name)

    saveable_factories = {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
    return saveable_factories

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called when
    saving with a pre-built `SavedObject` proto representing the object, plus an
    instance of `SaveOptions`. This method is then free to modify that proto
    instance.

    `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    This implementation delegates to `values_util.write_object_proto`.

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    values_util.write_object_proto(self, proto, options)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor.

    Args:
      dtype: Forwarded to `ops.convert_to_tensor`.
      name: Forwarded to `ops.convert_to_tensor`.
      as_ref: Must be falsy; reference conversion is rejected.

    Returns:
      The result of converting `self._get()` to a tensor.

    Raises:
      ValueError: If `as_ref` is truthy.
    """
    # TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ
    # and ON_WRITE.
    # Try to avoid assignments to and other mutations of MirroredVariable
    # state except through a DistributionStrategy.extended.update() or any of
    # the `assign*` and `scatter*` calls.
    if as_ref:
      # A TF 1.x case where the variable is a boolean variable and used like:
      # tf.cond(v, true_fn, false_fn).
      raise ValueError(
          "You may be using variable created under distribute strategy in TF "
          "1.x control flows. Try explicitly converting the variable to Tensor "
          "using variable.read_value(), or switch to TF 2.x.")
    current_value = self._get()
    return ops.convert_to_tensor(
        current_value, dtype=dtype, name=name, as_ref=as_ref)
1125
1126
class _SyncOnReadSaveable(saveable_object.SaveableObject):
  """Saveable that saves and restores a SyncOnReadVariable."""

  def __init__(self, sync_on_read_variable, name):
    """Builds the saveable via `values_util.get_on_read_saveable`.

    Args:
      sync_on_read_variable: The variable being saved; its primary component
        is also passed to `get_on_read_saveable`.
      name: Name for the saveable; also passed to `get_on_read_saveable`.
    """
    self._sync_on_read_variable = sync_on_read_variable
    tensor, spec = values_util.get_on_read_saveable(
        sync_on_read_variable,
        sync_on_read_variable._primary,  # pylint: disable=protected-access
        name)
    super(_SyncOnReadSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    # Exactly one restored tensor is expected; unpacking enforces that.
    (tensor,) = restored_tensors
    variable = self._sync_on_read_variable
    return values_util.get_on_read_restore_ops(
        variable, tensor, variable.aggregation)
1143
1144
class SyncOnReadVariable(DistributedVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def _update_replica(self, update_fn, value, **kwargs):
    """Applies `update_fn` to the on-device (or primary) component."""
    component = self._get_on_device_or_primary()
    return update_fn(component, value, **kwargs)

  # TODO(b/154017756): Make assign behavior in cross replica context consistent
  # with MirroredVariable.
  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    """Subtracts `value`, with special handling in cross-replica context.

    While saving in non-distributed mode the primary component is updated
    directly. In cross-replica context outside a replica update, the variable
    is marked unsaveable and the cross-replica helper is used. Otherwise the
    parent implementation runs.
    """
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (not ds_context.in_cross_replica_context() or
          values_util.in_replica_update_context()):
        return super(SyncOnReadVariable, self).assign_sub(
            value, use_locking, name, read_value)
      values_util.mark_as_unsaveable()
      return values_util.on_read_assign_sub_cross_replica(
          self, value, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    """Adds `value`, with special handling in cross-replica context.

    While saving in non-distributed mode the primary component is updated
    directly. In cross-replica context outside a replica update, the variable
    is marked unsaveable and the cross-replica helper is used. Otherwise the
    parent implementation runs.
    """
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (not ds_context.in_cross_replica_context() or
          values_util.in_replica_update_context()):
        return super(SyncOnReadVariable, self).assign_add(
            value, use_locking, name, read_value)
      values_util.mark_as_unsaveable()
      return values_util.on_read_assign_add_cross_replica(
          self, value, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    """Assigns `value`, with special handling in cross-replica context.

    While saving in non-distributed mode the primary component is updated
    directly. In cross-replica context outside a replica update, the variable
    is marked unsaveable and the cross-replica helper is used. Otherwise the
    parent implementation runs.
    """
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (not ds_context.in_cross_replica_context() or
          values_util.in_replica_update_context()):
        return super(SyncOnReadVariable, self).assign(
            value, use_locking, name, read_value)
      values_util.mark_as_unsaveable()
      return values_util.on_read_assign_cross_replica(
          self, value, read_value=read_value)

  def _scatter_not_implemented(self, method):
    """Raises `NotImplementedError` naming the unsupported scatter `method`."""
    raise NotImplementedError(
        "Variables with `synchronization=ON_READ` doesn't support `%s`" %
        method)

  def scatter_sub(self, *args, **kwargs):
    """Forwards to the primary while saving; otherwise unsupported."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(*args, **kwargs)
    return self._scatter_not_implemented("scatter_sub")

  def scatter_add(self, *args, **kwargs):
    """Forwards to the primary while saving; otherwise unsupported."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(*args, **kwargs)
    return self._scatter_not_implemented("scatter_add")

  def scatter_mul(self, *args, **kwargs):
    """Forwards to the primary while saving; otherwise unsupported."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(*args, **kwargs)
    return self._scatter_not_implemented("scatter_mul")

  def scatter_div(self, *args, **kwargs):
    """Forwards to the primary while saving; otherwise unsupported."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(*args, **kwargs)
    return self._scatter_not_implemented("scatter_div")

  def scatter_min(self, *args, **kwargs):
    """Forwards to the primary while saving; otherwise unsupported."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    return self._scatter_not_implemented("scatter_min")

  def scatter_max(self, *args, **kwargs):
    """Forwards to the primary while saving; otherwise unsupported."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    return self._scatter_not_implemented("scatter_max")

  def scatter_update(self, *args, **kwargs):
    """Forwards to the primary while saving; otherwise unsupported."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    return self._scatter_not_implemented("scatter_update")

  def value(self):
    """Returns the variable's value for the current context.

    While saving in non-distributed mode this is the primary component's
    value. In cross-replica context outside a replica update it is replica 0's
    value for `ONLY_FIRST_REPLICA` aggregation and the cross-replica value
    otherwise. In all remaining cases it is the on-device (or primary)
    component's value.
    """
    if values_util.is_saving_non_distributed():
      return self._primary.value()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (not ds_context.in_cross_replica_context() or
          values_util.in_replica_update_context()):
        # _get_on_device_or_primary() returns a Variable.
        return self._get_on_device_or_primary().value()
      if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
        return self._get_replica(0).value()
      return self._get_cross_replica()

  def _get_cross_replica(self):
    """Returns the cross-replica view of this variable.

    For `ONLY_FIRST_REPLICA` aggregation this is replica 0's component;
    otherwise it is the strategy's reduction of this variable using the
    reduce op derived from the aggregation. `SUM` aggregation additionally
    marks the variable as unsaveable.
    """
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      # Consider returning a tensor value here to make the return value of
      # _get_cross_replica consistent.
      return self._get_replica(0)
    if self._aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      reduce_op = reduce_util.ReduceOp.from_variable_aggregation(
          self._aggregation)
      return self._distribute_strategy.reduce(reduce_op, self, axis=None)

  def _as_graph_element(self):
    """Returns the graph element to use in place of this variable."""
    # pylint: disable=protected-access
    if values_util.is_saving_non_distributed():
      return self._primary._as_graph_element()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(self._get_cross_replica())
    # Reached only outside cross-replica context, after leaving the scope.
    return self._get()._as_graph_element()

  def _gather_saveables_for_checkpoint(self):
    """Overrides the `Trackable` hook that lists this object's saveables.

    This allows both name-based and object-based save and restore of
    `SyncOnReadVariable`s.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    # The default for `name` is captured here, when the factory is defined.
    def _saveable_factory(name=self._common_name):
      return _SyncOnReadSaveable(self, name)

    saveable_factories = {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
    return saveable_factories

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    This implementation does nothing and leaves `proto` unchanged.

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
1301
1302
# Register conversion functions that read the value of the variable,
1304# allowing instances of the class to be used as tensors.
1305# DistributedVariable
def _tensor_conversion_distributed_var(var, dtype=None, name=None,
                                       as_ref=False):
  """Converts `var` to a tensor through its `_dense_var_to_tensor` method."""
  tensor = var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
  return tensor


ops.register_tensor_conversion_function(
    DistributedVariable, _tensor_conversion_distributed_var)
1313
1314
1315# MirroredVariables
def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
  """Converts `var` to a tensor through its `_dense_var_to_tensor` method."""
  tensor = var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
  return tensor


ops.register_tensor_conversion_function(
    MirroredVariable, _tensor_conversion_mirrored)
1322
1323
1324# Mirrored Values
def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False):
  """Converts the result of `value._get()` to a tensor."""
  current_value = value._get()  # pylint: disable=protected-access
  return ops.convert_to_tensor(
      current_value, dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(
    Mirrored, _tensor_conversion_mirrored_val)
1332
1333
1334# SyncOnReadVariables
def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
  """Converts `var` to a tensor through its `_dense_var_to_tensor` method."""
  tensor = var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
  return tensor


ops.register_tensor_conversion_function(
    SyncOnReadVariable, _tensor_conversion_sync_on_read)
1341
1342
class VariablePolicy(object):
  """Policy defining synchronization and aggregation of a distributed variable.

  Given `synchronization` and `aggregation` parameters set on a `tf.Variable`
  during variable creation within `tf.distribute` scope, `tf.distribute` creates
  an appropriate policy object and assigns it to the distributed variable. All
  variable operations are delegated to the respective policy object.

  This base class only stores the aggregation; every operation below raises
  `NotImplementedError` and is meant to be overridden by sub-classes.
  """

  def __init__(self, aggregation):
    """Stores the aggregation that sub-classes consult.

    Args:
      aggregation: The aggregation setting of the variable this policy serves.
    """
    self._aggregation = aggregation

  def value(self, var=None):
    """Returns the value of `var`; must be overridden.

    Args:
      var: The distributed variable to read. Sub-classes take it as a required
        argument; it is optional here only so that callers of the former
        zero-argument signature keep getting `NotImplementedError`.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError(
        "This method should be overridden by sub-classes.")

  def _is_mirrored(self):
    """Reports whether the policy is a mirrored one; must be overridden.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError(
        "This method should be overridden by sub-classes.")

  def _as_graph_element(self, _):
    """Returns the graph element for a variable; must be overridden.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError(
        "This method should be overridden by sub-classes.")

  def _get_cross_replica(self, var):
    """Returns the cross-replica value of `var`; must be overridden.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError(
        "This method should be overridden by sub-classes.")

  def _update_replica(self, var, update_fn, value, **kwargs):
    """Applies `update_fn` to `var` in one replica; must be overridden.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError(
        "This method should be overridden by sub-classes.")
1374
1375
class OnReadPolicy(VariablePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
  `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in
  `tf.distribute` scope.
  """

  def _is_mirrored(self):
    """Returns False: this policy is not a mirrored one."""
    return False

  def value(self, var):
    """Returns the value of `var` for the current context.

    In cross-replica context outside a replica update this is replica 0's
    value for `ONLY_FIRST_REPLICA` aggregation and the cross-replica value
    otherwise. In all remaining cases it is the on-device (or primary)
    component's value.
    """
    # pylint: disable=protected-access
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (not ds_context.in_cross_replica_context() or
          values_util.in_replica_update_context()):
        return var._get_on_device_or_primary().value()
      if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
        return var._get_replica(0).value()
      return var._get_cross_replica()

  def _as_graph_element(self, var):
    """Returns the graph element to use in place of `var`."""
    # pylint: disable=protected-access
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(var._get_cross_replica())
    # Reached only outside cross-replica context, after leaving the scope.
    return var._get()._as_graph_element()

  def _get_cross_replica(self, var):
    """Returns the cross-replica view of `var`.

    For `ONLY_FIRST_REPLICA` aggregation this is replica 0's component;
    otherwise it is the strategy's reduction of `var` using the reduce op
    derived from the aggregation. `SUM` aggregation additionally marks the
    variable as unsaveable.
    """
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return var._get_replica(0)  # pylint: disable=protected-access
    if self._aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      reduce_op = reduce_util.ReduceOp.from_variable_aggregation(
          self._aggregation)
      return var.distribute_strategy.reduce(reduce_op, var, axis=None)

  def _update_replica(self, var, update_fn, value, **kwargs):
    """Applies `update_fn` to the on-device (or primary) component of `var`."""
    component = var._get_on_device_or_primary()  # pylint: disable=protected-access
    return update_fn(component, value, **kwargs)

  def _scatter_not_implemented(self, method):
    """Raises `NotImplementedError` naming the unsupported scatter `method`."""
    raise NotImplementedError(
        "ON_READ variables doesn't support `%s` in cross replica context" %
        method)

  def assign_sub(self, var, value, use_locking=False, name=None,
                 read_value=True):
    """Subtracts a value from this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (not ds_context.in_cross_replica_context() or
          values_util.in_replica_update_context()):
        return values_util.on_write_assign_sub(
            var, value, use_locking=use_locking, name=name,
            read_value=read_value)
      # Cross-replica context outside a replica update.
      values_util.mark_as_unsaveable()
      return values_util.on_read_assign_sub_cross_replica(
          var, value, read_value=read_value)

  def assign_add(self, var, value, use_locking=False, name=None,
                 read_value=True):
    """Adds a value to this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (not ds_context.in_cross_replica_context() or
          values_util.in_replica_update_context()):
        return values_util.on_write_assign_add(
            var, value, use_locking=use_locking, name=name,
            read_value=read_value)
      # Cross-replica context outside a replica update.
      values_util.mark_as_unsaveable()
      return values_util.on_read_assign_add_cross_replica(
          var, value, read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    """Assigns a value to this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (not ds_context.in_cross_replica_context() or
          values_util.in_replica_update_context()):
        return values_util.on_write_assign(
            var, value, use_locking=use_locking, name=name,
            read_value=read_value)
      # Cross-replica context outside a replica update.
      values_util.mark_as_unsaveable()
      return values_util.on_read_assign_cross_replica(
          var, value, read_value=read_value)

  def scatter_sub(self, *args, **kwargs):
    """Unsupported; raises via `_scatter_not_implemented`."""
    del args, kwargs  # Unused.
    return self._scatter_not_implemented("scatter_sub")

  def scatter_add(self, *args, **kwargs):
    """Unsupported; raises via `_scatter_not_implemented`."""
    del args, kwargs  # Unused.
    return self._scatter_not_implemented("scatter_add")

  def scatter_mul(self, *args, **kwargs):
    """Unsupported; raises via `_scatter_not_implemented`."""
    del args, kwargs  # Unused.
    return self._scatter_not_implemented("scatter_mul")

  def scatter_div(self, *args, **kwargs):
    """Unsupported; raises via `_scatter_not_implemented`."""
    del args, kwargs  # Unused.
    return self._scatter_not_implemented("scatter_div")

  def scatter_min(self, *args, **kwargs):
    """Unsupported; raises via `_scatter_not_implemented`."""
    del args, kwargs  # Unused.
    return self._scatter_not_implemented("scatter_min")

  def scatter_max(self, *args, **kwargs):
    """Unsupported; raises via `_scatter_not_implemented`."""
    del args, kwargs  # Unused.
    return self._scatter_not_implemented("scatter_max")

  def scatter_update(self, *args, **kwargs):
    """Unsupported; raises via `_scatter_not_implemented`."""
    del args, kwargs  # Unused.
    return self._scatter_not_implemented("scatter_update")

  def get_saveable(self, var, primary_var, name):
    """Create a saveable object for the given variable."""
    saveable = values_util.get_on_read_saveable(var, primary_var, name)
    return saveable

  def get_restore_ops(self, var, tensor):
    """Restore the same value into all variables."""
    aggregation = self._aggregation
    return values_util.get_on_read_restore_ops(var, tensor, aggregation)
1500
1501
class OnWritePolicy(VariablePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.

  This policy is used for a `tf.Variable` created in `tf.distribute` scope
  whose `synchronization` is `tf.VariableSynchronization.ON_WRITE` or
  `tf.VariableSynchronization.AUTO`.
  """

  def _is_mirrored(self):
    """Returns `True` unconditionally."""
    return True

  def value(self, var):
    """Returns `value()` of what `var._get_on_device_or_primary()` yields."""
    component = var._get_on_device_or_primary()  # pylint: disable=protected-access
    return component.value()

  def _as_graph_element(self, var):
    """Returns the graph element of `var._get_on_device_or_primary()`."""
    component = var._get_on_device_or_primary()  # pylint: disable=protected-access
    return component._as_graph_element()  # pylint: disable=protected-access

  def _get_cross_replica(self, var):
    """Returns an identity of `var._get_on_device_or_primary()`."""
    # The identity wrapper keeps the variable itself out of the user's hands,
    # so it cannot be modified by mistake.
    component = var._get_on_device_or_primary()  # pylint: disable=protected-access
    return array_ops.identity(component)

  def _update_replica(self, var, update_fn, value, **kwargs):
    """Applies `update_fn` with `value` according to `var.aggregation`.

    Args:
      var: the variable to update; its `aggregation` selects the path.
      update_fn: callable invoked as `update_fn(target, value, **kwargs)` on
        the direct path, or handed to `_on_write_update_replica` otherwise.
      value: the value passed along to the update.
      **kwargs: extra keyword arguments forwarded unchanged.

    Returns:
      The result of `update_fn` when aggregation is `NONE`; otherwise the
      result of `_on_write_update_replica`.
    """
    skip_aggregation = (
        var.aggregation == variables_lib.VariableAggregation.NONE)
    if not skip_aggregation:
      return _on_write_update_replica(var, update_fn, value, **kwargs)
    target = var._get_on_device_or_primary()  # pylint: disable=protected-access
    return update_fn(target, value, **kwargs)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    """Delegates to `values_util.on_write_assign` with the same arguments."""
    return values_util.on_write_assign(
        var,
        value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self, var, value, use_locking=False, name=None,
                 read_value=True):
    """Delegates to `values_util.on_write_assign_add` with the same arguments."""
    return values_util.on_write_assign_add(
        var,
        value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_sub(self, var, value, use_locking=False, name=None,
                 read_value=True):
    """Delegates to `values_util.on_write_assign_sub` with the same arguments."""
    return values_util.on_write_assign_sub(
        var,
        value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
    """Delegates to `values_util.scatter_sub` with the same arguments."""
    return values_util.scatter_sub(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
    """Delegates to `values_util.scatter_add` with the same arguments."""
    return values_util.scatter_add(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
    """Delegates to `values_util.scatter_mul` with the same arguments."""
    return values_util.scatter_mul(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
    """Delegates to `values_util.scatter_div` with the same arguments."""
    return values_util.scatter_div(
        var, sparse_delta, use_locking=use_locking, name=name)

  def _require_scatter_aggregation(self, op_name):
    """Raises unless `_aggregation` is `ONLY_FIRST_REPLICA` or `NONE`.

    Shared guard for `scatter_min`, `scatter_max` and `scatter_update`.

    Args:
      op_name: name of the scatter op, used only in the error message.

    Raises:
      NotImplementedError: for any other aggregation; the message is built
        from `values_util.scatter_error_msg`.
    """
    aggregation = self._aggregation
    if (aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(values_util.scatter_error_msg.format(
          op_name=op_name, aggregation=aggregation))

  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
    """Checks the aggregation, then calls `values_util.scatter_min`.

    Raises:
      NotImplementedError: if `_aggregation` is neither `ONLY_FIRST_REPLICA`
        nor `NONE`.
    """
    self._require_scatter_aggregation("scatter_min")
    return values_util.scatter_min(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
    """Checks the aggregation, then calls `values_util.scatter_max`.

    Raises:
      NotImplementedError: if `_aggregation` is neither `ONLY_FIRST_REPLICA`
        nor `NONE`.
    """
    self._require_scatter_aggregation("scatter_max")
    return values_util.scatter_max(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    """Checks the aggregation, then calls `values_util.scatter_update`.

    Raises:
      NotImplementedError: if `_aggregation` is neither `ONLY_FIRST_REPLICA`
        nor `NONE`.
    """
    self._require_scatter_aggregation("scatter_update")
    return values_util.scatter_update(
        var, sparse_delta, use_locking=use_locking, name=name)

  def get_saveable(self, var, primary_var, name):
    """Saveable ops for AUTO variables.

    Returns:
      Whatever `values_util.get_on_write_saveable` returns for these
      arguments, which are forwarded unchanged.
    """
    saveable = values_util.get_on_write_saveable(var, primary_var, name)
    return saveable

  def get_restore_ops(self, var, tensor):
    """Delegates to `values_util.get_on_write_restore_ops(var, tensor)`."""
    restore_ops = values_util.get_on_write_restore_ops(var, tensor)
    return restore_ops

  def _write_object_proto(self, var, proto, options):
    """Updates a `SavedObject` proto on behalf of the caller.

    When a DistributedVariable object supports this method, saving calls it
    with a pre-built `SavedObject` proto representing the object and an
    instance of `SaveOptions`; the method is free to modify that proto.

    A `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization
    optionally writes information about its components to the
    `experimental_distributed_variable_components` field of a `SavedVariable`
    (depending on the `SaveOptions` variable policy).

    Args:
      var: A DistributedVariable object.
      proto: A pre-built `SavedObject` proto for this object. It is assumed
        this will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    # All of the work is done by the shared helper; nothing is returned.
    values_util.write_object_proto(var, proto, options)
1612