1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various classes representing distributed values."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import contextlib
23import weakref
24import six
25
26from tensorflow.python.distribute import device_util
27from tensorflow.python.distribute import distribute_lib
28from tensorflow.python.distribute import distribution_strategy_context
29from tensorflow.python.distribute import reduce_util
30from tensorflow.python.eager import context
31from tensorflow.python.eager import tape
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import gen_resource_variable_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import variable_scope as vs
39from tensorflow.python.training import saver
40from tensorflow.python.training.tracking import base as trackable
41from tensorflow.python.util import nest
42
43
def _devices_match(d1, d2):
  """Returns True when `d1` and `d2` canonicalize to the same device string."""
  canonical_first = device_util.canonicalize(d1)
  canonical_second = device_util.canonicalize(d2)
  return canonical_first == canonical_second
46
47
class DeviceMap(object):
  """Abstract mapping from replicas and logical device ids to devices.

  Every member of this base class raises `NotImplementedError`; concrete
  device maps must override all of them.
  """

  @property
  def all_devices(self):
    """Every device in this map, as a tuple of strings."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  @property
  def devices_by_replica(self):
    """A tuple `t` in which `t[replica]` holds the devices of `replica`."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  @property
  def num_logical_devices(self):
    """How many devices a single replica may be spread across."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  @property
  def num_replicas_in_graph(self):
    """How many replicas are defined in this graph."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def logical_device_from_values(self, values):
    """Index of the logical device on which `values` live."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def logical_to_actual_devices(self, logical_device_id):
    """The sequence of `num_replicas_in_graph` devices for an id."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def select_for_current_replica(self, values, replica_context):
    """Picks the element of `values` that belongs to the current replica."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def replica_for_device(self, device):
    """The id of the replica that contains `device`."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def select_for_device(self, values, device):
    """Picks the element of `values` that should be accessed from `device`."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def is_device_in_replica(self, device, replica_id):
    """Whether `device` is a member of replica `replica_id`."""
    raise NotImplementedError("Required for DeviceMap implementations.")
94
95
class SingleDeviceMap(DeviceMap):
  """A device map holding exactly one non-computation device.

  Use `SingleDeviceMap` only when the device is not one of the replicas of the
  computation. Computation devices belong in a `ReplicaDeviceMap` instead,
  even when there is just a single one.
  """

  def __init__(self, device):
    """Builds the map from one device string.

    Args:
      device: A string naming the device; it is stored in canonical form.
    """
    assert isinstance(device, six.string_types)
    canonical = device_util.canonicalize(device)
    self._device = canonical
    self._devices = (canonical,)

  @property
  def all_devices(self):
    return self._devices

  @property
  def devices_by_replica(self):
    # This map has no notion of replicas.
    raise ValueError("SingleDeviceMap not indexed by replicas")

  @property
  def num_logical_devices(self):
    return 1

  @property
  def num_replicas_in_graph(self):
    return 1

  def logical_device_from_values(self, values):
    del values  # There is only one logical device, so `values` is irrelevant.
    return 0

  def logical_to_actual_devices(self, logical_device_id):
    assert logical_device_id == 0
    return self._devices

  def select_for_current_replica(self, values, replica_context):
    assert len(values) == 1
    del replica_context  # Unused: there is only one value to choose from.
    return values[0]

  def replica_for_device(self, device):
    raise ValueError("SingleDeviceMap not indexed by replicas")

  def select_for_device(self, values, device):
    """Returns the only value, provided `device` is this map's device.

    `device` is compared as-is against the canonical device, so callers are
    expected to pass a canonicalized string.

    Raises:
      ValueError: if `device` is not this map's device.
    """
    assert len(values) == 1
    if self._device == device:
      return values[0]
    raise ValueError("Device %s not found in %s (current device %s)" %
                     (device, self._devices, device_util.current()))

  def is_device_in_replica(self, device, replica_id):
    raise ValueError("SingleDeviceMap not indexed by replicas")

  def __repr__(self):
    class_name = self.__class__.__name__
    return "%s(%r)" % (class_name, self._device)
158
159
class ReplicaDeviceMap(DeviceMap):
  """A device map for 1 device per replica."""

  def __init__(self, devices):
    """Initialize a `ReplicaDeviceMap`.

    Args:
      devices: `devices[i]` is the string device for replica `i`.

    Raises:
      ValueError: if two entries of `devices` canonicalize to the same device.
    """
    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    if len(set(self._devices)) != len(self._devices):
      raise ValueError("Duplicate devices in %s, after canonicalization: %s" %
                       (devices, self._devices))
    # Reverse index: canonical device string -> replica id.
    self._device_to_replica = {d: r for r, d in enumerate(self._devices)}

  @property
  def all_devices(self):
    return self._devices

  @property
  def devices_by_replica(self):
    """Returns a tuple `t` where `t[replica]` is `(device_of_replica,)`."""
    # The `DeviceMap` contract promises an indexable tuple. A bare generator
    # could not be indexed, measured with len(), or iterated a second time.
    return tuple((d,) for d in self._devices)

  @property
  def num_logical_devices(self):
    return 1

  @property
  def num_replicas_in_graph(self):
    return len(self._devices)

  def logical_device_from_values(self, values):
    del values  # Only one logical device per replica.
    return 0

  def logical_to_actual_devices(self, logical_device_id):
    assert logical_device_id == 0
    return self._devices

  def select_for_current_replica(self, values, replica_context):
    """Returns `values[i]`, where `i` is the current replica id.

    `replica_context.replica_id_in_sync_group` may be a non-int (for example
    a tensor). In that case its constant value is used as the index.
    """
    assert len(values) == len(self._devices)
    replica_id = replica_context.replica_id_in_sync_group
    if not isinstance(replica_id, int):
      replica_id = tensor_util.constant_value(replica_id)
    return values[replica_id]

  def replica_for_device(self, device):
    """Returns the replica id of `device` (canonical string), or None."""
    return self._device_to_replica.get(device)

  def select_for_device(self, values, device):
    """Returns the element of `values` for the replica on `device`.

    Raises:
      ValueError: if `device` (expected canonical) is not in this map.
    """
    assert len(values) == len(self._devices)
    replica_id = self._device_to_replica.get(device)
    if replica_id is None:
      raise ValueError("Device %s not found in %s (current device %s)" %
                       (device, self._devices, device_util.current()))
    return values[replica_id]

  def is_device_in_replica(self, device, replica_id):
    # Canonicalizes both sides, so `device` need not be canonical here.
    return _devices_match(device, self._devices[replica_id])

  def __str__(self):
    return "[%s]" % (", ".join(self._devices))

  def __repr__(self):
    return "%s([%s])" % (self.__class__.__name__,
                         ", ".join(repr(d) for d in self._devices))
226
227
# Record of a `device_map` together with a `logical_device` index into it.
LogicalDeviceSpec = collections.namedtuple(
    "LogicalDeviceSpec", "device_map logical_device")
230
231
class DistributedValues(object):
  """Holds a map from device to values. Either PerReplica or Mirrored."""

  def __init__(self, device_map, values, logical_device=None):
    """Initializes the container.

    Args:
      device_map: A `DeviceMap` describing where each component lives.
      values: Iterable of per-device components; stored as a tuple.
      logical_device: Optional logical device index. When omitted it is
        computed with `device_map.logical_device_from_values`.
    """
    assert isinstance(device_map, DeviceMap)
    self._device_map = device_map
    self._values = tuple(values)
    if logical_device is None:
      logical_device = device_map.logical_device_from_values(self._values)
    self._logical_device = logical_device

  # TODO(josh11b): Split this into two functions, one with device, one without.
  def get(self, device=None):
    """Returns the value for the current device or raises a ValueError."""
    if device is None:
      replica_context = distribution_strategy_context.get_replica_context()
      if replica_context:
        return self._device_map.select_for_current_replica(
            self._values, replica_context)
      else:
        device = distribute_lib.get_update_device()
        if device is None:
          return self._get_cross_replica()
    device = device_util.canonicalize(device)
    return self._device_map.select_for_device(self._values, device)

  def _get_cross_replica(self):
    """Returns the component to use in cross-replica context.

    Subclasses that support cross-replica access (e.g. `Mirrored`) override
    this. The default exists so that `get()` fails with a clear error.
    Without it, `get()` raises a bare AttributeError. A `DistributedDelegate`
    subclass is worse off: its `__getattr__` calls `get()` again and
    recurses without bound.

    Raises:
      NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError(
        "This method should be overridden by sub-classes which support cross-"
        "replica accesses.")

  @property
  def primary(self):
    """Returns a representative component."""
    return self._values[0]

  @property
  def devices(self):
    return self._device_map.logical_to_actual_devices(self._logical_device)

  @property
  def logical_device(self):
    return self._logical_device

  @property
  def device_map(self):
    return self._device_map

  # TODO(josh11b): Replace experimental_local_results with this?
  @property
  def values(self):
    return self._values

  @property
  def is_tensor_like(self):
    """True iff every component is a tensor (True for no components)."""
    return all(tensor_util.is_tensor(v) for v in self._values)

  def __str__(self):
    devices = self.devices
    assert len(self._values) == len(devices)
    debug_str = ",\n".join(
        "  %d %s: %s" % (i, device, value)
        for i, (device, value) in enumerate(zip(devices, self._values)))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    devices = self.devices
    assert len(self._values) == len(devices)
    debug_repr = ",\n".join(
        "  %d %s: %r" % (i, device, value)
        for i, (device, value) in enumerate(zip(devices, self._values)))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)
300
301
302# NOTE(josh11b,apassos): It would be great if we could inspect the values this was
303# initialized with and use that to generate the overloaded operators here.
304# Unfortunately, Python's rules for special methods don't allow this, see
305# https://docs.python.org/3/reference/datamodel.html#special-method-names
306# "if a class defines a method named __getitem__(), and x is an instance of
307# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)."
308# In particular, these special methods don't go through __getattr__, and
309# it will only use those methods if they are defined in the class, not the
310# object.
class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values.

  Attribute lookups that miss on the delegate itself, and the operators
  defined below, are forwarded to the component returned by `self.get()`.
  """

  def __getattr__(self, name):
    # TODO(priyag): This needs to be made robust against pitfalls from mix use
    # __getattr__ and @property. See b/120402273.
    local_value = self.get()
    return getattr(local_value, name)

  # Python looks special methods up on the type and never routes them through
  # `__getattr__`, so every operator has to forward explicitly.
  def __add__(self, o):
    return self.get() + o

  def __radd__(self, o):
    return o + self.get()

  def __sub__(self, o):
    return self.get() - o

  def __rsub__(self, o):
    return o - self.get()

  def __mul__(self, o):
    return self.get() * o

  def __rmul__(self, o):
    return o * self.get()

  def __truediv__(self, o):
    return self.get() / o

  def __rtruediv__(self, o):
    return o / self.get()

  def __floordiv__(self, o):
    return self.get() // o

  def __rfloordiv__(self, o):
    return o // self.get()

  def __mod__(self, o):
    return self.get() % o

  def __rmod__(self, o):
    return o % self.get()

  def __lt__(self, o):
    return self.get() < o

  def __le__(self, o):
    return self.get() <= o

  def __gt__(self, o):
    return self.get() > o

  def __ge__(self, o):
    return self.get() >= o

  def __and__(self, o):
    return self.get() & o

  def __rand__(self, o):
    return o & self.get()

  def __or__(self, o):
    return self.get() | o

  def __ror__(self, o):
    return o | self.get()

  def __xor__(self, o):
    return self.get() ^ o

  def __rxor__(self, o):
    return o ^ self.get()

  def __getitem__(self, o):
    return self.get()[o]

  def __pow__(self, o, modulo=None):
    return pow(self.get(), o, modulo)

  def __rpow__(self, o):
    return pow(o, self.get())

  def __invert__(self):
    return ~self.get()

  def __neg__(self):
    return -self.get()

  def __abs__(self):
    return abs(self.get())

  # The operators below may be missing on the component. In that case
  # NotImplemented is returned so Python can try the reflected operation
  # (https://docs.python.org/3/library/constants.html#NotImplemented).
  def __div__(self, o):
    try:
      return self.get().__div__(o)
    except AttributeError:
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self.get().__rdiv__(o)
    except AttributeError:
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self.get().__matmul__(o)
    except AttributeError:
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self.get().__rmatmul__(o)
    except AttributeError:
      return NotImplemented

  # TODO(josh11b): Even more operator overloads.
381
382
class PerReplica(DistributedValues):
  """Holds a map from device to unsynchronized values.

  Unlike `Mirrored`, this is a plain `DistributedValues` and does not
  delegate attribute access or operators to a component.
  """
386
387
388# Note that unlike PerReplica, Mirrored values inherit from
389# DistributedDelegate and so can be used directly in cross-replica mode.
class Mirrored(DistributedDelegate):
  """Holds a map from device to values which are kept in sync."""

  def _get_cross_replica(self):
    """Returns the component on the current device, else the primary."""
    current = device_util.canonicalize(device_util.current())
    replica_id = self._device_map.replica_for_device(current)
    return self.primary if replica_id is None else self._values[replica_id]

  def _as_graph_element(self):
    """Returns the local component, converted if it supports conversion."""
    local_value = self.get()
    to_graph_element = getattr(local_value, "_as_graph_element", None)
    if to_graph_element and callable(to_graph_element):
      return to_graph_element()
    return local_value
406
407
def _assign_on_device(device, variable, tensor):
  """Assigns `tensor` to `variable`, creating the ops under `device`."""
  with ops.device(device):
    copied = array_ops.identity(tensor)
    return variable.assign(copied)
411
412
def _assert_strategy(strategy):
  """Checks that `strategy` is the distribution strategy currently in scope.

  Raises:
    RuntimeError: if no strategy is in scope, or a different one is.
  """
  if not distribution_strategy_context.has_strategy():
    raise RuntimeError(
        'Need to be inside "with strategy.scope()" for %s' % (strategy,))
  active = distribution_strategy_context.get_strategy()
  if active is strategy:
    return
  raise RuntimeError(
      "Mixing different tf.distribute.Strategy objects: %s is not %s" %
      (active, strategy))
423
424
@contextlib.contextmanager
def _enter_or_assert_strategy(strategy):
  """Context manager ensuring `strategy` is the active strategy.

  If no strategy is active, `strategy.scope()` is entered for the duration.
  Otherwise `_assert_strategy` verifies that the active one is `strategy`.
  """
  if distribution_strategy_context.has_strategy():
    _assert_strategy(strategy)
    yield
  else:
    with strategy.scope():
      yield
433
434
# Stand-in for the primary component's op (name, graph, type) that
# `DistributedVariable.op` returns in cross-replica context.
DistributedVarOp = collections.namedtuple(
    "DistributedVarOp", "name graph type")
437
438
class DistributedVariable(DistributedDelegate):
  """Holds a map from device to variables."""
  # TODO(josh11b): Support changing the set of variables if e.g. if new
  # devices are joining or a device is to leave.

  def __init__(self, strategy, device_map, values, logical_device=None):
    """Initializes the container and links each component back to it.

    Args:
      strategy: The `tf.distribute.Strategy` this variable belongs to.
      device_map: A `DeviceMap` describing where the components live.
      values: Iterable of component variables. It is materialized once by the
        base class, so a one-shot iterator is acceptable.
      logical_device: Optional logical device index; inferred from
        `device_map` when omitted.
    """
    self._distribute_strategy = strategy
    super(DistributedVariable, self).__init__(
        device_map, values, logical_device=logical_device)
    self._common_name = self.primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    # Iterate the materialized `self._values`, not the `values` argument. A
    # generator argument was already exhausted by the base class, which would
    # silently leave every component without a back-reference.
    for v in self._values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access
    # tf.keras keeps track of variables initialized using this attribute. When
    # tf.keras gets the default session, it initializes all uninitialized vars.
    # We need to make _keras_initialized a member of DistributedVariable because
    # without this it will use `__getattr__` which will delegate to a component
    # variable.
    self._keras_initialized = False
    # Typically, a `DistributedVariable`'s initializer is composed of the
    # initializers of the components variables. However, in some cases, such as
    # when restoring from a checkpoint, we may set the _initializer_op
    # property on the entire `DistributedVariable`.
    self._initializer_op = None

  def is_initialized(self, name=None):
    """Identifies if all the component variables are initialized.

    Args:
      name: Name of the final `logical_and` op.

    Returns:
      The op that evaluates to True or False depending on if all the
      component variables are initialized.
    """
    result = self.primary.is_initialized()
    # We iterate through the list of values except the last one to allow us to
    # name the final `logical_and` op the same name that is passed by the user
    # to the `is_initialized` op. For distributed variables, the
    # `is_initialized` op is a `logical_and` op.
    # With a single component, `self._values[-1]` is the primary itself, so
    # its check is simply and-ed with itself.
    for v in self._values[1:-1]:
      result = math_ops.logical_and(result, v.is_initialized())
    result = math_ops.logical_and(result, self._values[-1].is_initialized(),
                                  name=name)
    return result

  @property
  def initializer(self):
    """The explicit `_initializer_op` if set, else all component initializers.

    Uses an identity check rather than truthiness, because taking the truth
    value of a graph-mode Tensor raises.
    """
    if self._initializer_op is not None:
      init_op = self._initializer_op
    else:
      # return grouped ops of all the var initializations of component values of
      # the mirrored variable
      init_op = control_flow_ops.group(tuple(
          v.initializer for v in self._values))
    return init_op

  def _get_closest(self):
    """Return member in the same replica if possible, else the primary."""
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context:
      return self._device_map.select_for_current_replica(
          self._values, replica_context)
    device = distribute_lib.get_update_device()
    if device is None:
      device = device_util.canonicalize(device_util.current())
    replica_id = self._device_map.replica_for_device(device)
    if replica_id is None:
      return self.primary
    return self._values[replica_id]

  def initialized_value(self):
    return self._get_closest().initialized_value()

  @property
  def initial_value(self):
    return self._get_closest().initial_value

  @property
  def graph(self):
    return self.primary.graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def _unique_id(self):
    return self.primary._unique_id   # pylint: disable=protected-access

  @property
  def _graph_key(self):
    """Lets Optimizers know which graph this variable is from."""
    return self.primary._graph_key  # pylint: disable=protected-access

  @property
  def name(self):
    return self.primary.name

  @property
  def dtype(self):
    return self.primary.dtype

  @property
  def shape(self):
    return self.primary.shape

  @property
  def trainable(self):
    return self.primary.trainable

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def get_shape(self):
    return self.primary.get_shape()

  def to_proto(self, export_scope=None):
    return self.primary.to_proto(export_scope=export_scope)

  @property
  def op(self):
    # We want cross-replica code that does some var.op.X calls
    # to work (even if the current device isn't in self.devices), but
    # other uses of var.op in a cross-replica context to fail.
    if distribution_strategy_context.in_cross_replica_context():
      return DistributedVarOp(self.primary.op.name,
                              self.primary.op.graph,
                              self.primary.op.type)
    return self.get().op

  @property
  def _in_graph_mode(self):
    return self.primary._in_graph_mode   # pylint: disable=protected-access

  def read_value(self):
    """Reads the variable through the strategy's `extended.read_var`."""
    return self._distribute_strategy.extended.read_var(self)

  def value(self):
    return self._get_closest().value()

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass
585
586
# Register DistributedVariable with `ops` as a dense-tensor-like type.
ops.register_dense_tensor_like_type(DistributedVariable)
588
589
def _validate_colocate_extended(v, extended):
  """Checks that variable `v` was created under the strategy owning `extended`.

  Raises:
    ValueError: if `v`'s strategy has a different `extended` object.
  """
  owner = v._distribute_strategy  # pylint: disable=protected-access
  if owner.extended is extended:
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
      (v, owner))
597
598
def validate_colocate_distributed_variable(v, extended):
  """Validates `v` as a `colocate_vars_with` target for `extended`.

  Raises:
    ValueError: if `v` is not a `DistributedVariable` or was created under a
      different strategy.
  """
  if isinstance(v, DistributedVariable):
    _validate_colocate_extended(v, extended)
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not: %r" % (v,))
605
606
def validate_colocate_tpu_variable(v, extended):
  """Validates `v` as a TPU `colocate_vars_with` target for `extended`.

  Raises:
    ValueError: if `v` is not a `TPUMirroredVariable` or was created under a
      different strategy.
  """
  if isinstance(v, TPUMirroredVariable):
    _validate_colocate_extended(v, extended)
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not: %r" % (v,))
613
614
def validate_colocate(v, extended):
  """Validates any strategy-created variable `v` for `colocate_vars_with`.

  Raises:
    ValueError: if `v` has no `_distribute_strategy` attribute, or that
      strategy's `extended` is not `extended`.
  """
  if hasattr(v, "_distribute_strategy"):
    _validate_colocate_extended(v, extended)
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not: %r" % (v,))
621
622
def _apply_aggregation(strategy, value, aggregation, destinations):
  """Combines `value` across replicas as dictated by `aggregation`.

  `ONLY_FIRST_REPLICA` broadcasts the first local result to `destinations`.
  Any other aggregation is converted to a `ReduceOp` and reduced onto
  `destinations`.
  """
  if aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA:
    reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
    return strategy.extended.reduce_to(reduce_op, value, destinations)
  first_local = strategy.experimental_local_results(value)[0]
  return strategy.extended.broadcast_to(first_local, destinations=destinations)
630
631
class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
  """Saveable built from a MirroredVariable's primary component.

  On restore, the single restored tensor is assigned to every component.
  """

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    super(_MirroredSaveable, self).__init__(primary_variable, "", name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the restored tensor to each component on its own device."""
    (tensor,) = restored_tensors
    assign_ops = [_assign_on_device(component.device, component, tensor)
                  for component in self._mirrored_variable.values]
    return control_flow_ops.group(tuple(assign_ops))
645
646
class MirroredVariable(DistributedVariable, Mirrored,
                       trackable.Trackable):
  """Holds a map from device to variables whose values are kept in sync."""

  def __init__(
      self, strategy, device_map, values, aggregation, logical_device=None):
    super(MirroredVariable, self).__init__(
        strategy, device_map, values, logical_device=logical_device)
    self._aggregation = aggregation

  # The arguments to update() are automatically unwrapped so the update()
  # function would normally see regular variables, not MirroredVariables.
  # However, the update function can still operate on wrapped MirroredVariables
  # through object members, captured arguments, etc. This is more likely in an
  # update_non_slot() function (like OptimizerV2._finish), which can
  # update several non-slot variables in one call.
  def _assign_func(self, *args, **kwargs):
    """Dispatches the assign-style callable `kwargs["f"]`.

    - Cross-replica context inside an update: `f` is applied to the component
      on the update device.
    - Cross-replica context otherwise: `strategy.extended.update()` applies
      `f`.
    - Replica context: the argument is aggregated with `self._aggregation`
      inside a `merge_call`. The result is then passed to
      `strategy.extended.update()`.

    Raises:
      ValueError: in replica context when the aggregation is `NONE`.
    """
    with _enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if distribution_strategy_context.in_cross_replica_context():
        update_device = distribute_lib.get_update_device()
        if update_device is None:
          # Plain cross-replica call: let the strategy perform the update.
          return self._distribute_strategy.extended.update(
              self, f, args=args, kwargs=kwargs)
        # Already inside an update context: act on that device's component.
        component = self.get(device=update_device)
        return f(component, *args, **kwargs)

      # Replica context: the value to assign/add/sub must first be combined
      # across replicas, then applied via the strategy.
      _assert_replica_context(self._distribute_strategy)
      if self._aggregation == vs.VariableAggregation.NONE:
        raise ValueError("You must specify an aggregation method to update a "
                         "MirroredVariable in Replica Context.")

      def merge_fn(strategy, value, *other_args, **other_kwargs):
        aggregated = _apply_aggregation(
            strategy, value, self._aggregation, self)
        return strategy.extended.update(
            self, f, args=(aggregated,) + other_args, kwargs=other_kwargs)

      replica_context = distribution_strategy_context.get_replica_context()
      return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    def assign_sub_fn(var, *a, **kw):
      return var.assign_sub(*a, **kw)
    return self._assign_func(*args, f=assign_sub_fn, **kwargs)

  def assign_add(self, *args, **kwargs):
    def assign_add_fn(var, *a, **kw):
      return var.assign_add(*a, **kw)
    return self._assign_func(*args, f=assign_add_fn, **kwargs)

  def assign(self, *args, **kwargs):
    def assign_fn(var, *a, **kw):
      return var.assign(*a, **kw)
    return self._assign_func(*args, f=assign_fn, **kwargs)

  @property
  def aggregation(self):
    return self._aggregation

  def _get_cross_replica(self):
    """Returns an identity of the local component (or of the primary)."""
    current = device_util.canonicalize(device_util.current())
    replica_id = self._device_map.replica_for_device(current)
    if replica_id is None:
      component = self.primary
    else:
      component = self._values[replica_id]
    return array_ops.identity(component)

  def _as_graph_element(self):
    # pylint: disable=protected-access
    if not distribution_strategy_context.in_cross_replica_context():
      return self.get()._as_graph_element()
    return self.primary._as_graph_element()

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    Enables both name-based and object-based save/restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """
    def make_saveable(name=self._common_name):
      return _MirroredSaveable(self, self.primary, name)
    return {trackable.VARIABLE_VALUE_KEY: make_saveable}
739
740
def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
  """Converts a `MirroredVariable` to a tensor from its local component.

  `as_ref` must be False: handing out references would allow mutating
  MirroredVariable state outside a `DistributionStrategy.extended.update()`
  call.
  """
  assert not as_ref
  local_value = var.get()
  return ops.internal_convert_to_tensor(
      local_value, dtype=dtype, name=name, as_ref=as_ref)


# Reading the local value lets MirroredVariable instances be used as tensors.
ops.register_tensor_conversion_function(
    MirroredVariable, _tensor_conversion_mirrored)
753
754
def _enclosing_tpu_context():
  """Returns the innermost enclosing `XLAControlFlowContext`, or None.

  Walks outward from the default graph's current control-flow context.
  """
  # pylint: disable=protected-access
  ctx = ops.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access
  while ctx is not None:
    if isinstance(ctx, control_flow_ops.XLAControlFlowContext):
      return ctx
    ctx = ctx.outer_context
  return None
763
764
765# TODO(jhseu): Deduplicate code. We copy code because we don't want to
766# inherit from DistributedDelegate. DistributedDelegate will not work in a
767# tpu.replicate() because it assumes that you're in a device context where you
768# can operate on a single version of the variable, but a tpu.replicate()
769# operates on all variables and is replicated during a rewrite pass.
class TPUMirroredVariable(trackable.Trackable):
  """Holds a map from device to TPU variables whose values are kept in sync.

  Wraps one component variable per replica (`values`). Many properties
  (`name`, `dtype`, `shape`, `graph`, `op`, ...) delegate to `primary`, the
  first component. Inside a TPU context (`_enclosing_tpu_context()` is not
  None), `handle` is the replicated handle built from all components.
  """

  def __init__(
      self, strategy, device_map, values, aggregation, logical_device=None):
    """Creates the variable.

    Args:
      strategy: The distribution strategy that owns this variable.
      device_map: A `DeviceMap` (asserted) used to map replicas and devices
        to components.
      values: The component variables, stored as a tuple whose first element
        is `primary`.
      aggregation: The `vs.VariableAggregation` applied to updates made in
        replica context (see `_assign_func`).
      logical_device: Optional logical device id. When None it is derived
        from `values` via `device_map.logical_device_from_values`.
    """
    assert isinstance(device_map, DeviceMap)
    self._distribute_strategy = strategy
    self._device_map = device_map
    self._values = tuple(values)
    if logical_device is None:
      logical_device = device_map.logical_device_from_values(self._values)
    self._logical_device = logical_device

    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in self._values:
      v._mirrored_container = weakref.ref(self)  # pylint: disable=protected-access
    # Primary's name with any ":<suffix>" removed; used as `_shared_name`,
    # as the key for the TPU replicated handle and as the default checkpoint
    # name.
    self._common_name = self.primary.name.split(":")[0]
    self._aggregation = aggregation
    # Needed for GradientTape
    self._trainable = self.primary.trainable
    # Typically like `DistributedVariable`, a `TPUMirroredVariable`'s
    # initializer is composed of the initializers of the components variables.
    # However, in some cases, such as when restoring from a checkpoint, we may
    # set the _initializer_op property on the entire `TPUMirroredVariable`.
    self._initializer_op = None

  def _get(self, device=None):
    """Returns the value for the current device or raises a ValueError.

    Resolution order:
      1. An explicit `device`.
      2. Otherwise, in a replica context, the current replica's component.
      3. Otherwise, the update device, if inside an update.
      4. Otherwise, `_get_cross_replica()`.
    """
    if device is None:
      replica_context = distribution_strategy_context.get_replica_context()
      if replica_context:
        return self._device_map.select_for_current_replica(
            self._values, replica_context)
      else:
        device = distribute_lib.get_update_device()
        if device is None:
          return self._get_cross_replica()
    device = device_util.canonicalize(device)
    return self._device_map.select_for_device(self._values, device)

  @property
  def primary(self):
    """Returns a representative component."""
    return self._values[0]

  @property
  def devices(self):
    """Actual devices for `logical_device`, as reported by `device_map`."""
    return self._device_map.logical_to_actual_devices(self._logical_device)

  @property
  def logical_device(self):
    return self._logical_device

  @property
  def device_map(self):
    return self._device_map

  # TODO(josh11b): Replace experimental_local_results with this?
  @property
  def values(self):
    """The tuple of component variables."""
    return self._values

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  # pylint: disable=multiple-statements
  # The operators below act on the tensor returned by `read_value()`.
  def __add__(self, o): return self.read_value() + o
  def __radd__(self, o): return o + self.read_value()
  def __sub__(self, o): return self.read_value() - o
  def __rsub__(self, o): return o - self.read_value()
  def __mul__(self, o): return self.read_value() * o
  def __rmul__(self, o): return o * self.read_value()
  def __truediv__(self, o): return self.read_value() / o
  def __rtruediv__(self, o): return o / self.read_value()
  def __floordiv__(self, o): return self.read_value() // o
  def __rfloordiv__(self, o): return o // self.read_value()
  def __mod__(self, o): return self.read_value() % o
  def __rmod__(self, o): return o % self.read_value()
  def __lt__(self, o): return self.read_value() < o
  def __le__(self, o): return self.read_value() <= o
  def __gt__(self, o): return self.read_value() > o
  def __ge__(self, o): return self.read_value() >= o
  def __and__(self, o): return self.read_value() & o
  def __rand__(self, o): return o & self.read_value()
  def __or__(self, o): return self.read_value() | o
  def __ror__(self, o): return o | self.read_value()
  def __xor__(self, o): return self.read_value() ^ o
  def __rxor__(self, o): return o ^ self.read_value()
  def __getitem__(self, o): return self.read_value()[o]
  def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo)
  def __rpow__(self, o): return pow(o, self.read_value())
  def __invert__(self): return ~self.read_value()
  def __neg__(self): return -self.read_value()
  def __abs__(self): return abs(self.read_value())

  # The operators below forward to the same-named method of the value
  # returned by `read_value()`, and return `NotImplemented` when that value
  # does not define the method.
  def __div__(self, o):
    try:
      return self.read_value().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self.read_value().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self.read_value().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self.read_value().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __str__(self):
    """Lists each device index and name with its component (via `str`)."""
    devices = self.devices
    debug_str = ",\n".join("  %d %s: %s" % (i, devices[i], self._values[i])
                           for i in range(len(devices)))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    """Lists each device index and name with its component (via `repr`)."""
    devices = self.devices
    debug_repr = ",\n".join("  %d %s: %r" % (i, devices[i], self._values[i])
                            for i in range(len(devices)))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)

  @property
  def handle(self):
    """Returns the resource handle to use in the current context.

    Inside a TPU context this is the replicated handle built from all
    components. Otherwise it is the handle of the update device's component,
    or of `primary` when there is no update device.
    """
    # If we're in a tpu.rewrite(), return the replicated handle.
    tpu_context = _enclosing_tpu_context()
    if tpu_context is not None:
      return tpu_context.get_replicated_var_handle(
          self._common_name, self._values)

    device = distribute_lib.get_update_device()
    if device is None:
      return self.primary.handle
    return self._get(device=device).handle

  @property
  def device(self):
    """Device of `handle` in the current context."""
    return self.handle.device

  def eval(self, session=None):
    """Evaluates the primary component in `session`."""
    return self.primary.eval(session)

  # The arguments to update() are automatically unwrapped so the update()
  # function would normally see regular variables, not MirroredVariables.
  # However, the update function can still operate on wrapped MirroredVariables
  # through object members, captured arguments, etc. This is more likely in an
  # update_non_slot() function (like OptimizerV2._finish), which can
  # update several non-slot variables in one call.
  def _assign_func(self, *args, **kwargs):
    """Runs the update function `kwargs["f"]` in the appropriate context.

    In cross-replica context:
      - Inside a TPU context, delegates to `extended.update`.
      - With an update device, calls `f` directly on that device's component.
      - Otherwise, delegates to `extended.update`.

    In replica context, the value is aggregated with `_apply_aggregation`
    inside a `merge_call` and then applied through `extended.update`.

    Raises:
      ValueError: If called in replica context while `aggregation` is
        `VariableAggregation.NONE`.
    """
    with _enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if distribution_strategy_context.in_cross_replica_context():
        if _enclosing_tpu_context() is not None:
          return self._distribute_strategy.extended.update(
              self, f, args=args, kwargs=kwargs)

        update_device = distribute_lib.get_update_device()
        # We are calling update on the mirrored variable in cross replica
        # context.
        if update_device is not None:
          # We are calling an assign function on the mirrored variable in cross
          # replica context.
          v = self._get(device=update_device)
          return f(v, *args, **kwargs)

        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        _assert_replica_context(self._distribute_strategy)
        # We are calling an assign function on the mirrored variable in replica
        # context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function on each of the mirrored variables with the
        # reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError("You must specify an aggregation method to update a "
                           "TPUMirroredVariable in Replica Context.")

        def merge_fn(strategy, value, *other_args, **other_kwargs):
          v = _apply_aggregation(strategy, value, self._aggregation, self)
          return strategy.extended.update(
              self, f, args=(v,) + other_args, kwargs=other_kwargs)

        return distribution_strategy_context.get_replica_context().merge_call(
            merge_fn, args=args, kwargs=kwargs)

  @contextlib.contextmanager
  def _handle_graph(self, handle):
    """Enters `handle.graph` as the default graph only when needed.

    No graph switch happens in these cases:
      - when executing eagerly;
      - when `handle` is an `ops.EagerTensor`;
      - when `ops.has_default_graph()` is true.
    """
    # Note: might have an eager tensor but not be executing eagerly when
    # building functions.
    if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor)
        or ops.has_default_graph()):
      yield
    else:
      with handle.graph.as_default():
        yield

  @property
  def trainable(self):
    """Trainability of `primary`, captured at construction time."""
    return self._trainable

  def _read_variable_op(self, parent_op=None):
    """Reads the value through `handle`.

    Records the access on the gradient tape when trainable. If `parent_op` is
    given, the read op gets a control dependency on it.
    """
    if self.trainable:
      tape.variable_accessed(self)
    if parent_op is not None:
      with ops.control_dependencies([parent_op]):
        return gen_resource_variable_ops.read_variable_op(
            self.handle, self.dtype)

    return gen_resource_variable_ops.read_variable_op(
        self.handle, self.dtype)

  def read_value(self):
    """Returns a read of the variable's value (see `_read_variable_op`)."""
    return self._read_variable_op()

  def assign_sub(self, *args, **kwargs):
    """Subtracts a value from the variable via `_assign_func`.

    The update function consumes the `name` and `read_value` (default True)
    keyword arguments and ignores extra positional arguments after the delta.
    It returns a read that depends on the update when `read_value` is true,
    and the update op otherwise.
    """
    def assign_sub_fn(var, delta, *ar, **kw):
      del ar
      name = kw.pop("name", None)
      read_value = kw.pop("read_value", True)
      with self._handle_graph(var.handle):
        op = gen_resource_variable_ops.assign_sub_variable_op(
            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
            name=name)
      if read_value:
        return self._read_variable_op(parent_op=op)
      return op

    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    """Adds a value to the variable via `_assign_func`.

    The update function consumes the `name` and `read_value` (default True)
    keyword arguments and ignores extra positional arguments after the delta.
    It returns a read that depends on the update when `read_value` is true,
    and the update op otherwise.
    """
    def assign_add_fn(var, delta, *ar, **kw):
      del ar
      name = kw.pop("name", None)
      read_value = kw.pop("read_value", True)
      with self._handle_graph(var.handle):
        op = gen_resource_variable_ops.assign_add_variable_op(
            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
            name=name)
      if read_value:
        return self._read_variable_op(parent_op=op)
      return op

    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    """Assigns a new value to the variable via `_assign_func`.

    The update function consumes the `name` and `read_value` (default True)
    keyword arguments and ignores extra positional arguments after the value.
    It returns a read that depends on the update when `read_value` is true,
    and the update op otherwise.
    """
    def assign_fn(var, value, *ar, **kw):
      del ar
      name = kw.pop("name", None)
      read_value = kw.pop("read_value", True)
      with self._handle_graph(var.handle):
        op = gen_resource_variable_ops.assign_variable_op(
            var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
            name=name)
      if read_value:
        return self._read_variable_op(parent_op=op)
      return op

    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def constraint(self):
    """Always None: no constraint is exposed for this variable."""
    return None

  @property
  def initializer(self):
    """The explicitly set `_initializer_op`, else a group of the components'."""
    if self._initializer_op:
      init_op = self._initializer_op
    else:
      init_op = control_flow_ops.group(tuple(
          v.initializer for v in self._values))
    return init_op

  @property
  def graph(self):
    return self.primary.graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def _unique_id(self):
    return self.primary._unique_id  # pylint: disable=protected-access

  @property
  def name(self):
    return self.primary.name

  @property
  def dtype(self):
    return self.primary.dtype

  @property
  def shape(self):
    return self.primary.shape

  def get_shape(self):
    return self.primary.get_shape()

  def to_proto(self, export_scope=None):
    return self.primary.to_proto(export_scope=export_scope)

  def _get_cross_replica(self):
    """Returns the component for the current device's replica.

    Falls back to `primary` when `replica_for_device` returns None.
    """
    device = device_util.canonicalize(device_util.current())
    replica = self._device_map.replica_for_device(device)
    if replica is None:
      return self.primary
    return self._values[replica]

  def _as_graph_element(self):
    """Returns the graph element of the relevant component.

    Returns None inside a TPU context.
    """
    # pylint: disable=protected-access
    if _enclosing_tpu_context() is None:
      if distribution_strategy_context.in_cross_replica_context():
        return self.primary._as_graph_element()
      return self._get()._as_graph_element()
    return None

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """
    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self.primary, name)
    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  # Needed to pass ResourceVariable checks.
  @property
  def op(self):
    return self.primary.op

  # pylint: disable=protected-access
  @property
  def _save_slice_info(self):
    return self.primary._save_slice_info

  def _get_save_slice_info(self):
    return self.primary._get_save_slice_info()

  def _set_save_slice_info(self, save_slice_info):
    return self.primary._set_save_slice_info(save_slice_info)
  # pylint: enable=protected-access

  @property
  def _in_graph_mode(self):
    return self.primary._in_graph_mode   # pylint: disable=protected-access

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor.

    Outside a TPU context, delegates to the current component's
    `_dense_var_to_tensor`. Inside a TPU context it returns, in order of
    precedence:
      - `read_value()` cast to `dtype` when a different `dtype` is requested;
      - the replicated `handle` when `as_ref` is true;
      - `read_value()` otherwise.
    `name` is only used in the delegating case.
    """
    # pylint: disable=protected-access
    if _enclosing_tpu_context() is None:
      return self._get()._dense_var_to_tensor(dtype, name, as_ref)
    # pylint: enable=protected-access
    if dtype is not None and dtype != self.dtype:
      return math_ops.cast(self.read_value(), dtype)
    if as_ref:
      return self.handle
    else:
      return self.read_value()

  def is_initialized(self, name=None):
    """Identifies if all the component variables are initialized.

    Args:
      name: Name of the final `logical_and` op.

    Returns:
      The op that evaluates to True or False depending on if all the
      component variables are initialized.
    """
    # TODO(jhseu): Do we need TPU context implementation?

    result = self.primary.is_initialized()
    # We iterate through the list of values except the last one to allow us to
    # name the final `logical_and` op the same name that is passed by the user
    # to the `is_initialized` op. For distributed variables, the
    # `is_initialized` op is a `logical_and` op.
    # With a single component, `_values[-1]` is `primary`, so it is ANDed with
    # itself; the result is still an op named `name`.
    for v in self._values[1:-1]:
      result = math_ops.logical_and(result, v.is_initialized())
    result = math_ops.logical_and(result, self._values[-1].is_initialized(),
                                  name=name)
    return result
1181
1182
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False):
  """Tensor-conversion hook for `TPUMirroredVariable`.

  All the work is delegated to the variable's own `_dense_var_to_tensor`.
  """
  to_tensor = var._dense_var_to_tensor  # pylint: disable=protected-access
  return to_tensor(dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(
    TPUMirroredVariable, _tensor_conversion_tpu_mirrored)
ops.register_dense_tensor_like_type(TPUMirroredVariable)
1192
1193
class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
  """Class for defining how to restore a SyncOnReadVariable."""

  def __init__(self, sync_on_read_variable, name):
    self._sync_on_read_variable = sync_on_read_variable

    # Hand the saver a callable rather than a tensor. The (strategy-level)
    # read then only happens when a value is actually saved, never on restore.
    def tensor():
      var = sync_on_read_variable
      owner = var._distribute_strategy  # pylint: disable=protected-access
      return owner.extended.read_var(var)

    save_specs = [
        saver.BaseSaverBuilder.SaveSpec(
            tensor=tensor,
            slice_spec="",
            name=name,
            dtype=sync_on_read_variable.dtype),
    ]
    super(_SyncOnReadSaveable, self).__init__(tensor, save_specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    del restored_shapes  # Unused.
    (restored,) = restored_tensors
    return self._sync_on_read_variable.assign(restored)
1216
1217
def _assert_replica_context(strategy):
  """Raises unless called in a replica context that belongs to `strategy`.

  Args:
    strategy: The distribution strategy that owns the variable being assigned.

  Raises:
    RuntimeError: If there is no current replica context, or if the current
      replica context belongs to a different strategy.
  """
  replica_context = distribution_strategy_context.get_replica_context()
  if not replica_context:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")
  if replica_context.strategy is not strategy:
    # Keep the same leading sentence (callers may match on it), but report
    # the actual problem: a replica context exists, just for another strategy.
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context. "
        "The current replica context belongs to a different distribution "
        "strategy than the one that owns the variable.")
1226
1227
class SyncOnReadVariable(DistributedVariable, PerReplica, trackable.Trackable):
  """Holds a map from device to variables whose values are reduced on save.

  In replica context each replica updates only its own component. In
  cross-replica context, the value is the components combined according to
  `aggregation` (see `_get_cross_replica`), and `assign` writes the same value
  to every component.
  """

  def __init__(
      self, strategy, device_map, values, aggregation, logical_device=None):
    """Creates the variable.

    Args:
      strategy: The distribution strategy that owns this variable.
      device_map: Forwarded to the parent class initializer.
      values: The per-replica component variables.
      aggregation: The `vs.VariableAggregation` used to combine components
        when read in cross-replica context and to rescale values assigned in
        cross-replica context.
      logical_device: Optional logical device, forwarded to the parent class.
    """
    self._aggregation = aggregation
    super(SyncOnReadVariable, self).__init__(
        strategy, device_map, values, logical_device=logical_device)

  def assign_sub(self, *args, **kwargs):
    """Subtracts from the current replica's component.

    Raises:
      RuntimeError: If not in a replica context of this variable's strategy.
    """
    _assert_replica_context(self._distribute_strategy)
    return self.get().assign_sub(*args, **kwargs)

  def assign_add(self, *args, **kwargs):
    """Adds to the current replica's component.

    Raises:
      RuntimeError: If not in a replica context of this variable's strategy.
    """
    _assert_replica_context(self._distribute_strategy)
    return self.get().assign_add(*args, **kwargs)

  def assign(self, *args, **kwargs):
    """Assigns a value.

    In cross-replica context (e.g. from `_SyncOnReadSaveable.restore`), the
    value is written to every component. With `SUM` aggregation it is first
    divided by the number of devices, so that summing the components
    reproduces it. In replica context only the current replica's component
    is assigned.

    Raises:
      RuntimeError: In replica context, if the replica context does not belong
        to this variable's strategy.
    """
    if distribution_strategy_context.in_cross_replica_context():
      # To preserve the sum across save and restore, we have to divide the
      # total across all devices when restoring a variable that was summed
      # when saving.
      # Accept the value positionally or as the `value` keyword argument.
      tensor = args[0] if args else kwargs["value"]
      if self._aggregation == vs.VariableAggregation.SUM:
        # Divide out-of-place, so a caller-owned mutable value (e.g. a numpy
        # array) is never modified. For non-floating variables (e.g.
        # integers), cast the quotient back to the variable's dtype: a float
        # result could not be assigned to them directly.
        tensor = tensor / len(self.devices)
        if not self.dtype.is_floating:
          tensor = math_ops.cast(tensor, self.dtype)
      return control_flow_ops.group(tuple(
          _assign_on_device(v.device, v, tensor) for v in self._values))
    else:
      _assert_replica_context(self._distribute_strategy)
      return self.get().assign(*args, **kwargs)

  @property
  def aggregation(self):
    return self._aggregation

  def _get_cross_replica(self):
    """Returns `primary` for ONLY_FIRST_REPLICA, else the reduced components."""
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return self.primary
    return self._distribute_strategy.reduce(
        reduce_util.ReduceOp.from_variable_aggregation(self.aggregation), self)

  def _as_graph_element(self):
    # pylint: disable=protected-access
    if distribution_strategy_context.in_cross_replica_context():
      return self._get_cross_replica()
    return self.get()._as_graph_element()

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    `SyncOnReadVariable`s.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """
    def _saveable_factory(name=self._common_name):
      return _SyncOnReadSaveable(self, name)
    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
1287
1288
# Register a conversion function for SyncOnReadVariable which allows as_ref to
# be true.
def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
  """Converts a `SyncOnReadVariable` by converting its current component."""
  component = var.get()
  return ops.internal_convert_to_tensor(
      component, dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(
    SyncOnReadVariable, _tensor_conversion_sync_on_read)
1298
1299
def regroup(device_map, values, wrap_class=PerReplica):
  """Makes a nest per-replica into a nest of PerReplica/Mirrored values.

  Args:
    device_map: A `DeviceMap`. `values` must have exactly
      `device_map.num_replicas_in_graph` entries.
    values: A sequence with one (possibly nested) structure per replica. All
      entries must share the same list/tuple/namedtuple/dict structure.
    wrap_class: Class used to wrap leaves that differ across replicas.

  Returns:
    A structure shaped like `values[0]`. Each leaf becomes one of:
      - the leaf itself, when it is the same object on every replica (and is
        a `DistributedVariable` or has no `_distributed_container`);
      - the containing distributed variable, when the leaves are its
        components;
      - `wrap_class(device_map, leaves)` otherwise.
  """
  assert isinstance(device_map, DeviceMap)
  assert len(values) == device_map.num_replicas_in_graph
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [regroup(device_map, tuple(v[i] for v in values), wrap_class)
            for i in range(len(v0))]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    regrouped_tuple = tuple(
        regroup(device_map, tuple(v[i] for v in values), wrap_class)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(type(v0), "_make")
      return type(v0)._make(regrouped_tuple)
    else:
      return regrouped_tuple

  if isinstance(v0, dict):
    v0keys = set(v0.keys())
    for v in values[1:]:
      assert isinstance(v, dict), ("v[0]: %r  v[i]: %r" % (v0, v))
      assert set(v.keys()) == v0keys, ("v[0].keys: %s  v[i].keys: %s" %
                                       (v0keys, set(v.keys())))
    # Iterate over `v0` itself rather than the key set. The result then keeps
    # the first replica's key order instead of an arbitrary set order that
    # depends on the hash seed.
    return {key: regroup(device_map, tuple(v[key] for v in values), wrap_class)
            for key in v0}

  # If exactly the same object across all devices, return it unwrapped.
  same_id = True
  for v in values[1:]:
    if v is not v0:
      same_id = False
      break
  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it. We check DistributedVariable
  #   specifically since it can look like it has a
  #   _distributed_container member since its members do.
  # * If v0 is a member of a distributed variable, in which case
  #   hasattr(v0, "_distributed_container") is true, we want to
  #   return the DistributedVariable that contains it using the
  #   _distributed_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0.
  if same_id and (isinstance(v0, DistributedVariable) or
                  not hasattr(v0, "_distributed_container")):
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary.
  if hasattr(v0, "_distributed_container"):
    # pylint: disable=protected-access
    assert not isinstance(v0, MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    assert device_map.is_device_in_replica(v0.device, 0), (
        "v0.device = %s, device_map = %s" % (v0.device, device_map))
    distributed_container = v0._distributed_container()
    assert distributed_container is not None
    for r, v in enumerate(values[1:]):
      assert device_map.is_device_in_replica(v.device, r + 1), (
          "v.device = %s, r = %d, device_map = %s" %
          (v.device, r + 1, device_map))
      assert distributed_container is v._distributed_container()
    return distributed_container
  # pylint: enable=protected-access

  return wrap_class(device_map, values)
1383
1384
def select_replica(replica_id, structured):
  """Specialize a nest of regular & per-replica values for one replica.

  Every `DistributedValues` leaf is replaced by its `values[replica_id]`
  entry. All other leaves are passed through unchanged.
  """
  def _pick_replica(leaf):
    if not isinstance(leaf, DistributedValues):
      return leaf
    return leaf.values[replica_id]

  return nest.map_structure(_pick_replica, structured)
1391
1392
def select_device_mirrored(device, structured):
  """Specialize a nest of regular & mirrored values for one device.

  Raises:
    TypeError: If a leaf is a `DistributedValues` that is not `Mirrored`.
  """
  def _pick_device(leaf):
    # Plain values pass through untouched.
    if not isinstance(leaf, DistributedValues):
      return leaf
    if isinstance(leaf, Mirrored):
      return leaf.get(device)
    raise TypeError(
        "Expected value to be mirrored across replicas: %s in %s." %
        (leaf, structured))

  return nest.map_structure(_pick_device, structured)
1406
1407
def update_regroup(extended, device_map, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute.

  Args:
    extended: The strategy's extended object; its `_local_results` and
      `_group` are used.
    device_map: The `DeviceMap` passed to `regroup`.
    updates: Per-replica update results.
    group: If false, return each regrouped value's local results. If true,
      return grouped updates (see below).

  Returns:
    If `group` is true, the regrouped structure in which:
      - each `DistributedValues` leaf is replaced by `extended._group(leaf)`;
      - for tensor-like leaves, that grouped result is replaced in turn by a
        `Mirrored` of per-device identities that depend on it.
  """
  # TODO(josh11b): Replace "Mirrored" here with a function that does the following
  # so we can avoid all these nest operations.
  regrouped = regroup(device_map, updates, Mirrored)
  if not group:
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access

  def _group_leaf(leaf):
    if not isinstance(leaf, DistributedValues):
      return leaf
    grouped = extended._group(leaf)  # pylint: disable=protected-access
    if not leaf.is_tensor_like:
      return grouped
    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    per_device = []
    for dev in leaf.devices:
      with ops.device(dev), ops.control_dependencies([grouped]):
        per_device.append(array_ops.identity(leaf.get(dev)))
    return Mirrored(leaf.device_map, per_device)

  grouped_flat = [_group_leaf(leaf) for leaf in nest.flatten(regrouped)]
  return nest.pack_sequence_as(regrouped, grouped_flat)
1431
1432
def value_container(val):
  """Returns the container that this per-replica `value` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable
      created in `scope()`.

  Returns:
    A container that `value` belongs to.
    If value does not belong to any container (including the case of
    container having been destroyed), returns the value itself.
  """
  # Values without a container reference are returned unchanged.
  if not hasattr(val, "_distributed_container"):
    return val
  # DistributedVariable has _distributed_container defined, but we don't
  # want to return its container.
  if isinstance(val, DistributedVariable):
    return val
  # The attribute is called to get the container; a None result (e.g. the
  # container was destroyed) falls back to `val`.
  owner = val._distributed_container()  # pylint: disable=protected-access
  return val if owner is None else owner
1453
1454
1455# TODO(josh11b): Descend from Variable.
class AggregatingVariable(trackable.Trackable):
  """A wrapper around a variable that aggregates updates across replicas.

  Attributes not defined on this class are forwarded to the wrapped variable.
  Assignments made in replica context are combined according to
  `aggregation` (via `_apply_aggregation` inside a `merge_call`) before being
  applied through the strategy's `update`.
  """

  def __init__(self, strategy, v, aggregation):
    """Wraps `v`.

    Args:
      strategy: The distribution strategy that owns this variable.
      v: The variable to wrap. A weak reference to this wrapper is stored on
        it as `_aggregating_container`.
      aggregation: The `vs.VariableAggregation` applied to updates made in
        replica context.
    """
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def get(self):
    """Returns the wrapped variable."""
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    """Forwards attributes not found on the wrapper to the wrapped variable."""
    # `__getattr__` only runs when normal lookup fails. If `_v` itself is
    # missing (e.g. the instance was created without `__init__`, as
    # `copy`/`pickle` do), forwarding would call `__getattr__("_v")` again
    # and recurse until RecursionError. Fail with AttributeError instead.
    if name == "_v":
      raise AttributeError(name)
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    """Runs the update function `kwargs["f"]` in the appropriate context.

    In cross-replica context:
      - With an update device, `f` is applied directly to the wrapped
        variable.
      - Otherwise it is wrapped in `extended.update`.

    In replica context, the value is aggregated inside a `merge_call` and
    then applied through `extended.update`.

    Raises:
      ValueError: If called in replica context while `aggregation` is
        `VariableAggregation.NONE`.
    """
    with _enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if distribution_strategy_context.in_cross_replica_context():
        update_device = distribute_lib.get_update_device()
        if update_device is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = distribution_strategy_context.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError("You must specify an aggregation method to update a "
                           "variable in replica context.")

        def merge_fn(strategy, value, *other_args, **other_kwargs):
          v = _apply_aggregation(strategy, value, self._aggregation, self)
          return strategy.extended.update(
              self, f, args=(v,) + other_args, kwargs=other_kwargs)

        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    """Calls `assign_sub` on the wrapped variable via `_assign_func`."""
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    """Calls `assign_add` on the wrapped variable via `_assign_func`."""
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    """Calls `assign` on the wrapped variable via `_assign_func`."""
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def name(self):
    return self._v.name

  @property
  def dtype(self):
    return self._v.dtype

  # TODO(josh11b): Test saving & restoring.
  def _gather_saveables_for_checkpoint(self):
    return {trackable.VARIABLE_VALUE_KEY: self._v}

  # pylint: disable=multiple-statements
  # Operators are applied to the wrapped variable.
  def __add__(self, o): return self._v + o
  def __radd__(self, o): return o + self._v
  def __sub__(self, o): return self._v - o
  def __rsub__(self, o): return o - self._v
  def __mul__(self, o): return self._v * o
  def __rmul__(self, o): return o * self._v
  def __truediv__(self, o): return self._v / o
  def __rtruediv__(self, o): return o / self._v
  def __floordiv__(self, o): return self._v // o
  def __rfloordiv__(self, o): return o // self._v
  def __mod__(self, o): return self._v % o
  def __rmod__(self, o): return o % self._v
  def __lt__(self, o): return self._v < o
  def __le__(self, o): return self._v <= o
  def __gt__(self, o): return self._v > o
  def __ge__(self, o): return self._v >= o
  def __and__(self, o): return self._v & o
  def __rand__(self, o): return o & self._v
  def __or__(self, o): return self._v | o
  def __ror__(self, o): return o | self._v
  def __xor__(self, o): return self._v ^ o
  def __rxor__(self, o): return o ^ self._v
  def __getitem__(self, o): return self._v[o]
  def __pow__(self, o, modulo=None): return pow(self._v, o, modulo)
  def __rpow__(self, o): return pow(o, self._v)
  def __invert__(self): return ~self._v
  def __neg__(self): return -self._v
  def __abs__(self): return abs(self._v)

  # The operators below forward to the wrapped variable's same-named method
  # and return `NotImplemented` when it does not define that method.
  def __div__(self, o):
    try:
      return self._v.__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._v.__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._v.__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._v.__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass
1603
1604
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
  """Converts an `AggregatingVariable` by converting the wrapped variable."""
  wrapped = var.get()
  return ops.internal_convert_to_tensor(
      wrapped, dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(
    AggregatingVariable, _tensor_conversion_aggregate)
ops.register_dense_tensor_like_type(AggregatingVariable)
1615