1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class implementing utilities used by tf.distribute.Strategy."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from collections import abc
22
23from tensorflow.python.distribute import tpu_values as tpu_values_lib
24from tensorflow.python.distribute import values as values_lib
25from tensorflow.python.eager import context
26from tensorflow.python.eager import tape
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import variable_scope as vs
32from tensorflow.python.util import nest
33
34
def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
  """Regroups a sequence of per-replica nests into one nest of wrapped values.

  The structure of `values[0]` drives the recursion and every other entry is
  asserted to have a matching structure. Lists, tuples (including namedtuples)
  and mappings are rebuilt element by element; anything else is a leaf.

  Args:
    values: Sequence holding one nest per replica. Must be non-empty, because
      `values[0]` is read unconditionally.
    wrap_class: Callable applied to the sequence of leaf values whenever the
      leaves have to be wrapped. Defaults to `values_lib.PerReplica`.
    always_wrap: If true, leaves are wrapped with `wrap_class` even when every
      entry is the very same object. A `DistributedVariable` shared by all
      entries is still returned unwrapped.

  Returns:
    A nest shaped like `values[0]` whose leaves are either the original
    object, the distributed container owning the leaf components, or the
    result of calling `wrap_class`.
  """
  first = values[0]
  others = values[1:]

  def _regroup_slot(slot):
    # Collects the entry stored under `slot` (an index or a key) from every
    # nest and regroups those entries recursively.
    return regroup(
        tuple(entry[slot] for entry in values), wrap_class, always_wrap)

  if isinstance(first, list):
    for other in others:
      assert isinstance(other, list)
      assert len(other) == len(first), (
          "len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
          (len(other), len(first), other, first))
    return [_regroup_slot(index) for index in range(len(first))]

  if isinstance(first, tuple):
    for other in others:
      assert isinstance(other, tuple)
      assert len(other) == len(first)
    regrouped = tuple(_regroup_slot(index) for index in range(len(first)))
    if not hasattr(first, "_fields"):
      return regrouped
    # A tuple exposing `_fields` is a namedtuple: rebuild the same namedtuple
    # type from the regrouped elements.
    assert hasattr(first, "_make")
    return first._make(regrouped)

  if isinstance(first, abc.Mapping):
    first_keys = first.keys()
    for other in others:
      assert isinstance(other, abc.Mapping), (
          "v[0]: %r  v[i]: %r" % (first, other))
      assert set(other.keys()) == set(first_keys), (
          "v[0].keys: %s  v[i].keys: %s" %
          (set(first_keys), set(other.keys())))
    regrouped_items = {key: _regroup_slot(key) for key in first_keys}
    # Instantiate the concrete type so that dict subclasses are preserved.
    return type(first)(regrouped_items)

  # Leaf handling. When every entry is the identical object:
  # * a DistributedVariable is returned as is. It is checked explicitly
  #   because it can itself appear to have a `_distributed_container` member.
  # * a component of a distributed variable (it has `_distributed_container`)
  #   falls through to the container lookup below; this can happen when
  #   there is only a single entry.
  # * anything else is returned unwrapped unless `always_wrap` is set.
  all_same_object = all(other is first for other in others)
  if all_same_object:
    if isinstance(first, values_lib.DistributedVariable):
      return first
    if not (always_wrap or hasattr(first, "_distributed_container")):
      return first

  # Each entry is a component of one distributed variable: return the
  # containing variable after checking that all components agree on it.
  if hasattr(first, "_distributed_container"):
    # pylint: disable=protected-access
    assert not isinstance(first, values_lib.MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    container = first._distributed_container()
    assert container is not None
    for other in others:
      assert container is other._distributed_container()
    # pylint: enable=protected-access
    return container

  return wrap_class(values)
128
129
def select_replica(replica_id, structured):
  """Picks the component belonging to `replica_id` out of a nest.

  Args:
    replica_id: Index used to select from the `values` of each distributed
      value.
    structured: Nest that may mix regular values and distributed values.

  Returns:
    A nest with the same structure in which every `DistributedValues` that is
    not a `DistributedVariable` is replaced by `values[replica_id]`.
  """

  def _pick(item):
    # Only plain `DistributedValues` are sliced. A `DistributedVariable` is
    # passed through untouched because it can be handled directly in the
    # replica context; non-distributed values are passed through as well.
    needs_slicing = (
        isinstance(item, values_lib.DistributedValues) and
        not isinstance(item, values_lib.DistributedVariable))
    return item.values[replica_id] if needs_slicing else item

  return nest.map_structure(_pick, structured)
144
145
def select_replica_mirrored(replica_id, structured):
  """Like `select_replica`, but first requires the nest to be mirrored.

  Args:
    replica_id: Index of the replica whose components are selected.
    structured: Nest of regular and mirrored values.

  Returns:
    The result of `select_replica(replica_id, structured)`.

  Raises:
    TypeError: Propagated from `assert_mirrored` when a distributed value in
      `structured` is not mirrored.
  """
  # Validation runs before any selection takes place.
  assert_mirrored(structured)
  selected = select_replica(replica_id, structured)
  return selected
150
151
def assert_mirrored(structured):
  """Checks that every distributed value in `structured` is mirrored.

  Args:
    structured: Nest of regular and distributed values.

  Raises:
    TypeError: If a `DistributedValues` leaf is found for which `is_mirrored`
      returns a false value.
  """

  def _check_leaf(leaf):
    # Regular (non-distributed) values are always accepted.
    if not isinstance(leaf, values_lib.DistributedValues):
      return
    if is_mirrored(leaf):
      return
    raise TypeError(
        "Expected value to be mirrored across replicas: %s in %s." %
        (leaf, structured))

  nest.map_structure(_check_leaf, structured)
162
163
def update_regroup(extended, updates, group):
  """Regroups update results, optionally tying all updates together.

  Args:
    extended: Object whose `_local_results` method is mapped over the
      regrouped nest when `group` is false.
    updates: Sequence of per-replica update results, as accepted by `regroup`.
    group: If false, the updates are regrouped as `Mirrored` values and
      unwrapped with `extended._local_results`. If true, the per-replica
      leaves are additionally grouped so that all updates execute.

  Returns:
    The regrouped nest of update results.
  """

  def _group_and_mirror(replica_values):
    """Wraps per-replica leaves as `Mirrored`, grouped across replicas."""
    if len(replica_values) == 1:
      return values_lib.Mirrored(replica_values)

    # Grouping makes sure all updates run. Without it, something like
    # session.run(extended.update(...)) may only update one replica.
    grouped = control_flow_ops.group(replica_values)

    # If the leaves are just ops the group op is sufficient. All leaves are
    # expected to be of the same kind since every replica performs the same
    # computation.
    if any(not tensor_util.is_tf_type(item) for item in replica_values):
      return grouped

    # Otherwise produce tensors carrying the same values that additionally
    # depend on the group op, each placed on the device of its source.
    pinned = []
    for item in replica_values:
      with ops.device(item.device):
        with ops.control_dependencies([grouped]):
          pinned.append(array_ops.identity(item))
    return values_lib.Mirrored(pinned)

  if group:
    return regroup(updates, _group_and_mirror)

  mirrored = regroup(updates, values_lib.Mirrored)
  return nest.map_structure(extended._local_results, mirrored)  # pylint: disable=protected-access
195
196
def value_container(val):
  """Looks up the distributed container that `val` is a component of.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created in
      `scope()`.

  Returns:
    The object returned by `val._distributed_container()` when `val` has that
    attribute, is not itself a `DistributedVariable`, and the call does not
    return None (e.g. because the container has been destroyed). In every
    other case `val` is returned unchanged.
  """
  if not hasattr(val, "_distributed_container"):
    return val
  # A DistributedVariable also defines `_distributed_container`, but it is
  # already the container, so it is returned as is.
  if isinstance(val, values_lib.DistributedVariable):
    return val
  owner = val._distributed_container()  # pylint: disable=protected-access
  return val if owner is None else owner
217
218
def is_distributed_variable(v):
  """Returns whether `v` is an instance of `values_lib.DistributedVariable`."""
  distributed_variable_cls = values_lib.DistributedVariable
  return isinstance(v, distributed_variable_cls)
222
223
def _validate_colocate_extended(v, extended):
  """Checks that `v` was created under the strategy owning `extended`.

  Args:
    v: Object exposing a `_distribute_strategy` attribute.
    extended: The extended object that `v._distribute_strategy.extended` must
      be identical to.

  Raises:
    ValueError: If the strategy recorded on `v` has a different `extended`.
  """
  owning_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if owning_strategy.extended is extended:
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
      (v, owning_strategy))
231
232
def validate_colocate_distributed_variable(v, extended):
  """Validates a `colocate_vars_with` argument that must be distributed.

  Args:
    v: Candidate variable; must be a `values_lib.DistributedVariable`.
    extended: The extended object of the strategy `v` must belong to.

  Raises:
    ValueError: If `v` is not a `DistributedVariable`, or if it was created
      under a strategy with a different `extended`.
  """
  if isinstance(v, values_lib.DistributedVariable):
    _validate_colocate_extended(v, extended)
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not: %r" % (v,))
239
240
def validate_colocate(v, extended):
  """Validates a `colocate_vars_with` argument by its recorded strategy.

  Args:
    v: Candidate variable; must expose a `_distribute_strategy` attribute.
    extended: The extended object of the strategy `v` must belong to.

  Raises:
    ValueError: If `v` has no `_distribute_strategy` attribute, or if it was
      created under a strategy with a different `extended`.
  """
  if hasattr(v, "_distribute_strategy"):
    _validate_colocate_extended(v, extended)
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not: %r" % (v,))
247
248
249# Variable creation function for sync strategies.
def _validate_synchronization(kwargs):
  """Validates the `synchronization` entry of variable creation kwargs.

  Args:
    kwargs: Dict of variable creation arguments. "synchronization" defaults to
      `VariableSynchronization.AUTO` when absent; "name" is only used to build
      error messages and may be absent.

  Returns:
    The effective synchronization mode: `ON_WRITE` when the requested mode is
    `AUTO`, otherwise the requested mode.

  Raises:
    ValueError: If the mode is `NONE` or is not one of `ON_READ`, `ON_WRITE`
      or `AUTO`.
  """
  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.AUTO)
  # Read with .get() so that a missing "name" cannot replace the intended
  # ValueError with a KeyError while the message is being formatted.
  var_name = kwargs.get("name")
  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with "
        "tf.distribute strategy. Please change the `synchronization` for "
        "variable: " + str(var_name))
  if synchronization not in (vs.VariableSynchronization.ON_READ,
                             vs.VariableSynchronization.ON_WRITE,
                             vs.VariableSynchronization.AUTO):
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, var_name))
  if synchronization == vs.VariableSynchronization.AUTO:
    return vs.VariableSynchronization.ON_WRITE
  return synchronization
268
269
def _validate_aggregation(kwargs):
  """Validates the `aggregation` entry of variable creation kwargs.

  Args:
    kwargs: Dict of variable creation arguments. "aggregation" defaults to
      `VariableAggregation.NONE` when absent; "name" is only used to build the
      error message and may be absent.

  Returns:
    The requested aggregation mode.

  Raises:
    ValueError: If the mode is not one of `NONE`, `SUM`, `MEAN` or
      `ONLY_FIRST_REPLICA`.
  """
  aggregation = kwargs.get("aggregation", vs.VariableAggregation.NONE)

  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    # .get() keeps a missing "name" from masking this ValueError with a
    # KeyError.
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs.get("name")))
  return aggregation
280
281
def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Create distributed variables with given synchronization and aggregation.

  Args:
    strategy: Passed as first argument to the selected variable class. Its
      `extended._use_var_policy` attribute (False when missing) selects the
      policy-based construction path.
    real_mirrored_creator: Callable invoked with the processed `kwargs`; its
      result is passed to the variable class as the list of component values.
    class_mapping: Mapping from synchronization mode to variable class. The
      "VariableClass" entry is used instead when variable policies are
      enabled.
    policy_mapping: Mapping from synchronization mode to policy class; only
      consulted when variable policies are enabled.
    **kwargs: Variable creation arguments. "collections", "synchronization",
      "aggregation", "caching_device" and "trainable" are interpreted here;
      "collections" is replaced by an empty list and "caching_device" is
      removed before `real_mirrored_creator` is called.

  Returns:
    The variable object built from the component values.

  Raises:
    ValueError: If the synchronization or aggregation mode is invalid.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the distributed variable to those collections instead of its
  # components.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    if use_var_policy:
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      # Extend a private copy: appending in place would modify the sequence
      # the caller passed as `collections`, and would fail outright for a
      # tuple or a set.
      var_collections = list(var_collections)
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      trainable_ref = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        # Remove the first entry that is this exact component object.
        for i, trainable_variable in enumerate(trainable_ref):
          if value is trainable_variable:
            del trainable_ref[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
340
341
342# Utility functions
343# Return True if the Value is Mirrored or the Variable is replicated and kept in
344# sync.
def is_mirrored(val):
  """Returns whether `val` is mirrored.

  A `DistributedVariable` with a truthy `_policy` defers to that policy's
  `_is_mirrored()`. Every other value is mirrored exactly when it is an
  instance of `values_lib.Mirrored`.
  """
  # pylint: disable=protected-access
  policy = (
      val._policy if isinstance(val, values_lib.DistributedVariable) else None)
  if policy:
    return policy._is_mirrored()
  # pylint: enable=protected-access
  return isinstance(val, values_lib.Mirrored)
350
351
def is_sync_on_read(val):
  """Returns whether `val` is considered sync-on-read.

  A `DistributedVariable` with a truthy `_policy` is sync-on-read when that
  policy's `_is_mirrored()` is false. Every other value is reported as
  sync-on-read exactly when it is not an instance of `values_lib.Mirrored`.
  """
  # pylint: disable=protected-access
  policy = (
      val._policy if isinstance(val, values_lib.DistributedVariable) else None)
  if policy:
    return not policy._is_mirrored()
  # pylint: enable=protected-access
  return not isinstance(val, values_lib.Mirrored)
357
# The following mappings select the policy and the variable class used by
# `create_mirrored_variable` for a given variable `synchronization` mode.
# The `aggregation` mode (NONE, SUM, MEAN or ONLY_FIRST_REPLICA) does not
# affect the selection.
# OnWritePolicy is used for:
# (synchronization=AUTO, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# There is no AUTO key: `_validate_synchronization` converts AUTO to ON_WRITE
# before the lookup happens.
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

# Variable classes keyed by synchronization mode. The "VariableClass" entry is
# the class `create_mirrored_variable` instantiates (together with a policy)
# when the strategy enables variable policies.
VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}

# Same layout as VARIABLE_POLICY_MAPPING, using the policy classes from
# `tpu_values_lib`.
TPU_VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
}

# Same layout as VARIABLE_CLASS_MAPPING, using the variable classes from
# `tpu_values_lib`.
TPU_VARIABLE_CLASS_MAPPING = {
    "VariableClass": tpu_values_lib.TPUDistributedVariable,
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
}
386