1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions used by values.py and ps_values.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.distribute import distribute_lib
22from tensorflow.python.distribute import distribution_strategy_context as ds_context
23from tensorflow.python.distribute import reduce_util
24from tensorflow.python.eager import context
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import variable_scope as vs
30from tensorflow.python.saved_model import save_context
31from tensorflow.python.saved_model import save_options
32from tensorflow.python.training.saving import saveable_object
33
34
def write_object_proto(var, proto, options):
  """Record the components of a distributed variable in its saved proto.

  During saving, an object that supports this hook is handed the pre-built
  `SavedObject` proto that represents it, together with the `SaveOptions`
  instance in effect, and is free to modify that proto.

  When the variable policy held by `options` asks for distributed variables
  to be expanded, one entry per element of `var.values` is appended to the
  `experimental_distributed_variable_components` field of `proto.variable`.
  Each entry stores the component's device and its name truncated at the
  first `:`. Otherwise the proto is left untouched.

  Args:
    var: The DistributedVariable object.
    proto: A pre-built `SavedObject` proto for this object. It is assumed this
      will be a `SavedVariable` instance.
    options: A `SaveOptions` instance.
  """
  policy = options.experimental_variable_policy
  if not policy._expand_distributed_variables():  # pylint: disable=protected-access
    return
  for component in var.values:
    entry = proto.variable.experimental_distributed_variable_components.add()
    entry.name = component.name.split(":")[0]
    entry.device = component.device
61
62
def get_on_write_saveable(var, primary_var, name):
  """Build the save callable and `SaveSpec` list for AUTO/ON_WRITE variables.

  Args:
    var: The distributed variable; its strategy is used to read the value.
    primary_var: Component whose initialization state is checked (in eager
      mode) and whose device is recorded on the spec.
    name: Name given to the `SaveSpec`.

  Returns:
    A `(tensor, specs)` tuple: `tensor` is a zero-argument callable producing
    the value to save, and `specs` is a one-element list with the `SaveSpec`
    wrapping that same callable.
  """

  def read_for_save():
    # Kept lazy so that nothing is evaluated when we are restoring instead of
    # saving.
    uninitialized = (
        context.executing_eagerly() and not primary_var.is_initialized())
    if uninitialized:
      # `None` as the SaveSpec tensor value marks an uninitialized variable.
      return None
    return var.distribute_strategy.extended.read_var(var)

  save_spec = saveable_object.SaveSpec(
      tensor=read_for_save,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)
  return read_for_save, [save_spec]
83
84
def get_on_write_restore_ops(var, tensor):
  """Group the ops that assign `tensor` to an AUTO/ON_WRITE variable.

  Args:
    var: The distributed variable being restored.
    tensor: The value to assign.

  Returns:
    The result of `control_flow_ops.group` over one assignment per device of
    the packed variable when `var` has one, otherwise one assignment per
    element of `var.values` (each placed on that element's device).
  """
  packed = var._packed_variable  # pylint: disable=protected-access
  if packed is None:
    assignments = [assign_on_device(v.device, v, tensor) for v in var.values]
  else:
    assignments = [assign_on_device(d, packed, tensor) for d in packed.devices]
  return control_flow_ops.group(tuple(assignments))
97
98
def get_on_read_saveable(var, primary_var, name):
  """Build the save callable and `SaveSpec` list for an ON_READ variable.

  Args:
    var: The distributed variable; its cross-replica value is what gets saved.
    primary_var: Component whose device is recorded on the spec.
    name: Name given to the `SaveSpec`.

  Returns:
    A `(tensor, specs)` tuple: `tensor` is a zero-argument callable producing
    the value to save, and `specs` is a one-element list with the `SaveSpec`
    wrapping that same callable.
  """

  def read_cross_replica():
    # Kept lazy so that nothing is evaluated when we are restoring instead of
    # saving.
    return var._get_cross_replica()  # pylint: disable=protected-access

  save_spec = saveable_object.SaveSpec(
      tensor=read_cross_replica,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)
  return read_cross_replica, [save_spec]
115
116
def get_on_read_restore_ops(var, tensor, aggregation):
  """Group the ops that assign a restored value to an ON_READ variable.

  Args:
    var: The distributed variable being restored.
    tensor: The restored value.
    aggregation: The variable's `VariableAggregation`.

  Returns:
    The result of `control_flow_ops.group` over one assignment per element of
    `var.values`, each placed on that element's device.
  """
  if aggregation == vs.VariableAggregation.SUM:
    # A SUM variable was saved as the total over all devices. Giving every
    # component an equal share keeps that total unchanged across a
    # save/restore round trip.
    replica_count = var.distribute_strategy.num_replicas_in_sync
    tensor = math_ops.cast(tensor / replica_count, var.dtype)
  assignments = [assign_on_device(v.device, v, tensor) for v in var.values]
  return control_flow_ops.group(tuple(assignments))
130
131
def in_replica_update_context():
  """Indicate whether we are in an UpdateContext while running a replica fn.

  Returns:
    True when `distribute_lib.get_update_replica_id()` yields something other
    than `None`, False otherwise.
  """
  update_replica_id = distribute_lib.get_update_replica_id()
  return update_replica_id is not None
136
137
def on_write_assign(var, value, use_locking=False, name=None, read_value=True):
  """Route an `assign` through `var._update`.

  Args:
    var: Object exposing `_update(update_fn, value, ...)`.
    value: Forwarded to `_update` as `value`.
    use_locking: Forwarded to `_update`.
    name: Forwarded to `_update`.
    read_value: Forwarded to `_update`.

  Returns:
    Whatever `var._update` returns.
  """

  def _assign(component, *args, **kwargs):
    # Invoked by `_update`; everything after the first argument is passed on.
    return component.assign(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
146
147
def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  """Route an `assign_add` through `var._update`.

  Args:
    var: Object exposing `_update(update_fn, value, ...)`.
    value: Forwarded to `_update` as `value`.
    use_locking: Forwarded to `_update`.
    name: Forwarded to `_update`.
    read_value: Forwarded to `_update`.

  Returns:
    Whatever `var._update` returns.
  """

  def _assign_add(component, *args, **kwargs):
    # Invoked by `_update`; everything after the first argument is passed on.
    return component.assign_add(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign_add,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
157
158
def on_write_assign_sub(var, value, use_locking=False, name=None,
                        read_value=True):
  """Route an `assign_sub` through `var._update`.

  Args:
    var: Object exposing `_update(update_fn, value, ...)`.
    value: Forwarded to `_update` as `value`.
    use_locking: Forwarded to `_update`.
    name: Forwarded to `_update`.
    read_value: Forwarded to `_update`.

  Returns:
    Whatever `var._update` returns.
  """

  def _assign_sub(component, *args, **kwargs):
    # Invoked by `_update`; everything after the first argument is passed on.
    return component.assign_sub(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign_sub,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
168
169
def assign_on_each_device(var, assign_func, value, read_value):
  """Update the variable on each replica with the given assign_func and value.

  Args:
    var: The distributed variable to update.
    assign_func: Callable invoked as `assign_func(device, variable, value)`.
    value: The value handed to every `assign_func` call.
    read_value: Whether to return the variable's value instead of the update.

  Returns:
    The grouped update when `read_value` is false; otherwise
    `var.read_value()`, read under a control dependency on the grouped update
    whenever that update is truthy.
  """
  # pylint: disable=protected-access
  packed = var._packed_variable
  if packed is None:
    assignments = [assign_func(v.device, v, value) for v in var._values]
  else:
    assignments = [assign_func(d, packed, value) for d in var._devices]
  # pylint: enable=protected-access
  update = control_flow_ops.group(tuple(assignments))
  if read_value:
    dependencies = [update] if update else []
    with ops.control_dependencies(dependencies):
      return var.read_value()
  return update
183
184
def on_read_assign_sub_cross_replica(var, value, read_value=True):
  """Apply `assign_sub` to each component of `var` in cross-replica context.

  Args:
    var: The variable to update; its strategy is entered or asserted first.
    value: The value to subtract on every component.
    read_value: Passed on to `assign_on_each_device`.

  Returns:
    The result of `assign_on_each_device`, or `None` when not called in a
    cross-replica context.

  Raises:
    ValueError: If `var.aggregation` is `VariableAggregation.SUM`.
  """
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if not ds_context.in_cross_replica_context():
      return None
    if var.aggregation == vs.VariableAggregation.SUM:
      raise ValueError(
          "SyncOnReadVariable does not support `assign_sub` in "
          "cross-replica context when aggregation is set to "
          "`tf.VariableAggregation.SUM`.")
    return assign_on_each_device(
        var, assign_sub_on_device, value, read_value)
195
196
def on_read_assign_add_cross_replica(var, value, read_value=True):
  """Apply `assign_add` to each component of `var` in cross-replica context.

  Args:
    var: The variable to update; its strategy is entered or asserted first.
    value: The value to add on every component.
    read_value: Passed on to `assign_on_each_device`.

  Returns:
    The result of `assign_on_each_device`, or `None` when not called in a
    cross-replica context.

  Raises:
    ValueError: If `var.aggregation` is `VariableAggregation.SUM`.
  """
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if not ds_context.in_cross_replica_context():
      return None
    if var.aggregation == vs.VariableAggregation.SUM:
      raise ValueError(
          "SyncOnReadVariable does not support `assign_add` in "
          "cross-replica context when aggregation is set to "
          "`tf.VariableAggregation.SUM`.")
    return assign_on_each_device(
        var, assign_add_on_device, value, read_value)
207
208
def on_read_assign_cross_replica(var, value, read_value=True):
  """Assign `value` to each component of `var` in cross-replica context.

  Args:
    var: The variable to update; its strategy is entered or asserted first.
    value: The value to assign. For SUM aggregation it is first divided by the
      strategy's `num_replicas_in_sync` and cast to `var.dtype`.
    read_value: Passed on to `assign_on_each_device`.

  Returns:
    The result of `assign_on_each_device`, or `None` when not called in a
    cross-replica context.
  """
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if not ds_context.in_cross_replica_context():
      return None
    if var.aggregation == vs.VariableAggregation.SUM:
      # Every component receives an equal share so that summing the
      # components again yields the assigned total (this mirrors how a summed
      # variable is handled on restore).
      replica_count = var._distribute_strategy.num_replicas_in_sync  # pylint: disable=protected-access
      per_component = math_ops.cast(value / replica_count, var.dtype)
    else:
      per_component = value
    return assign_on_each_device(
        var, assign_on_device, per_component, read_value)
223
224
def scatter_sub(var, sparse_delta, use_locking=False, name=None):
  """Route a `scatter_sub` through `var._update`.

  Args:
    var: Object exposing `_update(update_fn, value, ...)`.
    sparse_delta: Forwarded to `_update` as `value`.
    use_locking: Forwarded to `_update`.
    name: Forwarded to `_update`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_sub(component, *args, **kwargs):
    return component.scatter_sub(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_sub,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
232
233
def scatter_add(var, sparse_delta, use_locking=False, name=None):
  """Route a `scatter_add` through `var._update`.

  Args:
    var: Object exposing `_update(update_fn, value, ...)`.
    sparse_delta: Forwarded to `_update` as `value`.
    use_locking: Forwarded to `_update`.
    name: Forwarded to `_update`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_add(component, *args, **kwargs):
    return component.scatter_add(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_add,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
241
242
def scatter_mul(var, sparse_delta, use_locking=False, name=None):
  """Route a `scatter_mul` through `var._update`.

  Args:
    var: Object exposing `_update(update_fn, value, ...)`.
    sparse_delta: Forwarded to `_update` as `value`.
    use_locking: Forwarded to `_update`.
    name: Forwarded to `_update`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_mul(component, *args, **kwargs):
    return component.scatter_mul(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_mul,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
250
251
def scatter_div(var, sparse_delta, use_locking=False, name=None):
  """Route a `scatter_div` through `var._update`.

  Args:
    var: Object exposing `_update(update_fn, value, ...)`.
    sparse_delta: Forwarded to `_update` as `value`.
    use_locking: Forwarded to `_update`.
    name: Forwarded to `_update`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_div(component, *args, **kwargs):
    return component.scatter_div(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_div,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
259
260
def scatter_min(var, sparse_delta, use_locking=False, name=None):
  """Route a `scatter_min` through `var._update`.

  Args:
    var: Object exposing `_update(update_fn, value, ...)`.
    sparse_delta: Forwarded to `_update` as `value`.
    use_locking: Forwarded to `_update`.
    name: Forwarded to `_update`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_min(component, *args, **kwargs):
    return component.scatter_min(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_min,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
268
269
def scatter_max(var, sparse_delta, use_locking=False, name=None):
  """Route a `scatter_max` through `var._update`.

  Args:
    var: Object exposing `_update(update_fn, value, ...)`.
    sparse_delta: Forwarded to `_update` as `value`.
    use_locking: Forwarded to `_update`.
    name: Forwarded to `_update`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_max(component, *args, **kwargs):
    return component.scatter_max(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_max,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
277
278
def scatter_update(var, sparse_delta, use_locking=False, name=None):
  """Route a `scatter_update` through `var._update`.

  Args:
    var: Object exposing `_update(update_fn, value, ...)`.
    sparse_delta: Forwarded to `_update` as `value`.
    use_locking: Forwarded to `_update`.
    name: Forwarded to `_update`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_update(component, *args, **kwargs):
    return component.scatter_update(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_update,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
286
287
def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`.

  With a replica context available, its `_replica_id` is used: returned as-is
  when it already is an `int`, otherwise passed through
  `tensor_util.constant_value`. Without a replica context, the update replica
  id reported by `distribute_lib` is returned instead.
  """
  replica_context = ds_context.get_replica_context()
  if not replica_context:
    return distribute_lib.get_update_replica_id()
  current_id = replica_context._replica_id  # pylint: disable=protected-access
  if isinstance(current_id, int):
    return current_id
  return tensor_util.constant_value(current_id)
298
299
def assign_on_device(device, variable, tensor):
  """Call `variable.assign(tensor)` inside an `ops.device(device)` scope."""
  with ops.device(device):
    assigned = variable.assign(tensor)
  return assigned
303
304
def assign_add_on_device(device, variable, tensor):
  """Call `variable.assign_add(tensor)` inside an `ops.device(device)` scope."""
  with ops.device(device):
    assigned = variable.assign_add(tensor)
  return assigned
308
309
def assign_sub_on_device(device, variable, tensor):
  """Call `variable.assign_sub(tensor)` inside an `ops.device(device)` scope."""
  with ops.device(device):
    assigned = variable.assign_sub(tensor)
  return assigned
313
314
def assert_replica_context(strategy):
  """Check that we are in a replica context that belongs to `strategy`.

  Args:
    strategy: The strategy the current replica context must belong to
      (compared by identity).

  Raises:
    RuntimeError: If there is no replica context, or if the current replica
      context's strategy is not `strategy`.
  """
  replica_context = ds_context.get_replica_context()
  # Both failure modes report the same message.
  if not replica_context or replica_context.strategy is not strategy:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")
323
324
def apply_aggregation(strategy, value, aggregation, destinations):
  """Aggregate `value` as requested and deliver it to `destinations`.

  Args:
    strategy: The distribution strategy performing the aggregation.
    value: The value to aggregate.
    aggregation: A `VariableAggregation` selecting how to combine `value`.
    destinations: Where the aggregated result should be placed.

  Returns:
    For `ONLY_FIRST_REPLICA`, the first local result of `value` broadcast to
    `destinations`; for any other aggregation, the result of reducing `value`
    to `destinations` with the matching `ReduceOp`.
  """
  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    first_local_result = strategy.experimental_local_results(value)[0]
    return strategy.extended.broadcast_to(
        first_local_result, destinations=destinations)
  op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
  return strategy.extended.reduce_to(op, value, destinations)
332
333
# Error-message template explaining that an aggregation method is required to
# update a variable in replica context. Format it with `variable_type`.
# Every fragment that ends a sentence carries its own trailing space so the
# concatenated sentences do not run together.
aggregation_error_msg = (
    "You must specify an aggregation method to update a "
    "{variable_type} in Replica Context. You can do so by passing "
    "an explicit value for argument `aggregation` to tf.Variable(..). "
    "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)` "
    "`tf.VariableAggregation` lists the possible aggregation methods. "
    "This is required because {variable_type} should always be "
    "kept in sync. When updating them or assigning to them in a "
    "replica context, we automatically try to aggregate the values "
    "before updating the variable. For this aggregation, we need to "
    "know the aggregation method. "
    "Another alternative is to not try to update such "
    "{variable_type} in replica context, but in cross replica "
    "context. You can enter cross replica context by calling "
    "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`. "
    "Inside `merge_fn`, you can then update the {variable_type} "
    "using `tf.distribute.StrategyExtended.update()`.")
351
352
# Error-message template for scatter ops used with an unsupported aggregation.
# Format it with `op_name` and `aggregation`.
scatter_error_msg = (
    "{op_name} is only supported for mirrored variable (variable created "
    "within certain `tf.distribute.Strategy` scope) with NONE or "
    "`ONLY_FIRST_REPLICA` aggregation, got: {aggregation}.")
357
358
def is_saving_non_distributed():
  """Returns whether we're saving a non-distributed version of the model.

  The result is True iff a save context is active and the
  `experimental_variable_policy` of its save options is anything other than
  `VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES`.

  Returns:
    A boolean.
  """
  if save_context.in_save_context():
    policy = save_context.get_save_options().experimental_variable_policy
    return policy != save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
  return False
374
375
def mark_as_unsaveable():
  """Marks the function as unsaveable if not inside save context.

  Nothing happens unless we are inside a function and outside a save context;
  in that case the default graph is marked as unsaveable with an explanatory
  message suggesting a `tf.function` with an `input_signature`.
  """
  if not ops.inside_function() or save_context.in_save_context():
    return
  reason = """
ConcreteFunction that uses distributed variables in certain way cannot be saved.
If you're saving with

tf.saved_model.save(..., signatures=f.get_concrete_function())

do

@tf.function(input_signature=...)
def f_with_input_signature():
  ...

tf.saved_model.save(..., signatures=f_with_input_signature)`

instead."""
  ops.get_default_graph().mark_as_unsaveable(reason)
394