1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various classes representing TPU distributed values.
16
Note that the tests are in values_test.py.
18
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import contextlib
26
27from tensorflow.python.distribute import packed_distributed_variable as packed
28from tensorflow.python.distribute import tpu_util
29from tensorflow.python.distribute import values
30from tensorflow.python.distribute import values_util
31from tensorflow.python.eager import context
32from tensorflow.python.eager import tape
33from tensorflow.python.framework import ops
34from tensorflow.python.ops import gen_resource_variable_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import variable_scope
37
38
@contextlib.contextmanager
def _maybe_enter_graph(tensor):
  """Makes `tensor`'s graph the default graph when no graph is in effect.

  Nothing extra is entered when executing eagerly, when `tensor` is an
  `ops.EagerTensor`, or when `ops.has_default_graph()` reports a default
  graph; otherwise the body runs under `tensor.graph.as_default()`.

  Args:
    tensor: A tensor (e.g. a variable handle) whose graph may be entered.

  Yields:
    None.
  """
  # An eager tensor can be seen even when not executing eagerly (e.g. while a
  # function is being built), so the tensor type is checked on its own.
  skip_graph = (
      context.executing_eagerly() or
      isinstance(tensor, ops.EagerTensor) or
      ops.has_default_graph())
  if skip_graph:
    yield
    return
  with tensor.graph.as_default():
    yield
49
50
@contextlib.contextmanager
def _maybe_on_device(var):
  """Opens a device scope on `var.device` when `var` is a packed variable.

  Args:
    var: A variable. Only `packed.PackedVarAndDevice` instances get a device
      scope; for anything else the body runs without an extra scope.

  Yields:
    None.
  """
  if not isinstance(var, packed.PackedVarAndDevice):
    yield
    return
  with ops.device(var.device):
    yield
59
60
def _make_raw_assign_fn(raw_assign_fn):
  """Wraps a raw resource-variable assign op into a variable-level function.

  Args:
    raw_assign_fn: An op such as `gen_resource_variable_ops.assign_variable_op`,
      invoked as `raw_assign_fn(handle, value_tensor, name=name)`.

  Returns:
    A function `assign_fn(var, value, use_locking=False, name=None,
    read_value=True)`. It applies `raw_assign_fn` to `var.handle`. It returns
    `var._read_variable_op()` (ordered after the assignment) when `read_value`
    is true, and the assign op itself otherwise.
  """

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):
    """Applies the wrapped raw op to `var`; `use_locking` is ignored."""
    del use_locking  # The raw ops take no locking argument.

    var_handle = var.handle
    with _maybe_enter_graph(var_handle), _maybe_on_device(var):
      value_tensor = ops.convert_to_tensor(value, dtype=var.dtype)
      assign_op = raw_assign_fn(var_handle, value_tensor, name=name)
      with ops.control_dependencies([assign_op]):
        if not read_value:
          return assign_op
        # The control dependency guarantees the read observes the assignment.
        return var._read_variable_op()  # pylint: disable=protected-access

  return assign_fn
74
75
class TPUVariableMixin(object):
  """Mixin for TPU variables.

  Outside a TPU context (`tpu_util.enclosing_tpu_context()` is None) most
  methods defer to the next class in the MRO. Inside one, reads are issued by
  `_read_variable_op()` against `handle`, which is then the replicated handle
  obtained from the enclosing TPU context (unless executing eagerly).
  """

  def __init__(self, *args, **kwargs):
    super(TPUVariableMixin, self).__init__(*args, **kwargs)

    # `get_replicated_var_handle` caches handles under `_handle_id`. In eager
    # mode distinct variables can share a name, so the primary's `id()` is
    # appended to keep the key unique.
    if not ops.executing_eagerly_outside_functions():
      self._handle_id = self._common_name
    else:
      self._handle_id = self._common_name + "_" + str(id(self._primary))

  def __getattr__(self, name):
    # Only reached for attributes that normal lookup did not find.
    if tpu_util.enclosing_tpu_context() is not None:
      raise AttributeError(
          "'{}' not accessible within a TPU context.".format(name))
    return super(TPUVariableMixin, self).__getattr__(name)

  def get(self):
    if tpu_util.enclosing_tpu_context() is not None:
      raise NotImplementedError(
          "`TPUVariableMixin.get()` is not supported within a TPU context.")
    return super(TPUVariableMixin, self).get()

  def _get_as_operand(self):
    return self.read_value()

  def _is_mirrored(self):
    """Whether this variable is mirrored; concrete subclasses must override."""
    raise NotImplementedError(
        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    tpu_context = tpu_util.enclosing_tpu_context()
    inside_tpu_rewrite = (
        tpu_context is not None and not context.executing_eagerly())

    if inside_tpu_rewrite:
      # Inside a tpu.rewrite(): hand out the replicated handle, built from the
      # packed variable when one exists, else from the per-replica values.
      is_packed = self._packed_var is not None
      handle_values = [self._packed_var] if is_packed else self._values
      return tpu_context.get_replicated_var_handle(
          self._handle_id, handle_values, self._is_mirrored(), is_packed)

    local_var = self._get_on_device_or_primary()
    if isinstance(local_var, packed.PackedVarAndDevice):
      return local_var.on_device_handle()
    return local_var.handle

  @property
  def device(self):
    current_handle = self.handle
    return current_handle.device

  def _read_variable_op(self):
    """Reads the value of this variable."""
    if self.trainable:
      tape.variable_accessed(self)

    var_handle = self.handle
    if not getattr(var_handle, "is_packed", False):
      return gen_resource_variable_ops.read_variable_op(var_handle, self.dtype)

    # A packed handle needs an explicit device scope for the read.
    with ops.device(self._get_on_device_or_primary().device):
      return gen_resource_variable_ops.read_variable_op(var_handle, self.dtype)

  def read_value(self):
    if tpu_util.enclosing_tpu_context() is not None:
      return self._read_variable_op()
    return super(TPUVariableMixin, self).read_value()

  def value(self):
    if tpu_util.enclosing_tpu_context() is not None:
      return self._read_variable_op()
    return super(TPUVariableMixin, self).value()

  def _as_graph_element(self):
    if tpu_util.enclosing_tpu_context() is not None:
      return None
    return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access

  @property
  def op(self):
    if values_util.is_saving_non_distributed():
      return self._primary.op
    primary_op = self._primary.op
    return values.DistributedVarOp(primary_op.name, primary_op.graph,
                                   primary_op.traceback, primary_op.type)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    if tpu_util.enclosing_tpu_context() is None:
      # pylint: disable=protected-access
      return super(TPUVariableMixin, self)._dense_var_to_tensor(
          dtype=dtype, name=name, as_ref=as_ref)
      # pylint: enable=protected-access
    if dtype is not None and dtype != self.dtype:
      return math_ops.cast(self.read_value(), dtype)
    if as_ref:
      return self.handle
    return self.read_value()
186
187
class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
  """DistributedVariable subclass for TPUStrategy.

  Every update is delegated to `self._policy`, except while
  `values_util.is_saving_non_distributed()` is true. In that case the update
  is applied directly to the primary component variable.
  """

  def _is_mirrored(self):
    # Whether the variable is mirrored is decided by its policy.
    return self._policy._is_mirrored()  # pylint: disable=protected-access

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if not values_util.is_saving_non_distributed():
      return self._policy.assign_sub(
          self, value, use_locking=use_locking, name=name,
          read_value=read_value)
    return self._primary.assign_sub(value, use_locking, name, read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if not values_util.is_saving_non_distributed():
      return self._policy.assign_add(
          self, value, use_locking=use_locking, name=name,
          read_value=read_value)
    return self._primary.assign_add(value, use_locking, name, read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if not values_util.is_saving_non_distributed():
      return self._policy.assign(
          self, value, use_locking=use_locking, name=name,
          read_value=read_value)
    return self._primary.assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    if not values_util.is_saving_non_distributed():
      return self._policy.scatter_sub(
          self, sparse_delta, use_locking=use_locking, name=name)
    return self._primary.scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    if not values_util.is_saving_non_distributed():
      return self._policy.scatter_add(
          self, sparse_delta, use_locking=use_locking, name=name)
    return self._primary.scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    if not values_util.is_saving_non_distributed():
      return self._policy.scatter_mul(
          self, sparse_delta, use_locking=use_locking, name=name)
    return self._primary.scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    if not values_util.is_saving_non_distributed():
      return self._policy.scatter_div(
          self, sparse_delta, use_locking=use_locking, name=name)
    return self._primary.scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    if not values_util.is_saving_non_distributed():
      return self._policy.scatter_min(
          self, sparse_delta, use_locking=use_locking, name=name)
    return self._primary.scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    if not values_util.is_saving_non_distributed():
      return self._policy.scatter_max(
          self, sparse_delta, use_locking=use_locking, name=name)
    return self._primary.scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    if not values_util.is_saving_non_distributed():
      return self._policy.scatter_update(
          self, sparse_delta, use_locking=use_locking, name=name)
    return self._primary.scatter_update(sparse_delta, use_locking, name)
253
254
class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
  """Holds a map from replica to TPU variables whose values are kept in sync.

  Inside a TPU context, a variable whose aggregation is
  `VariableAggregation.NONE` is updated by applying the raw resource op
  directly. All other assignments go through the module-level `assign*`
  helpers, which call `self._update`. Scatter updates are only available
  while `values_util.is_saving_non_distributed()` is true.
  """

  def _assigns_with_raw_op(self):
    """Whether assignments should apply the raw op directly to this variable."""
    return (tpu_util.enclosing_tpu_context() and
            self.aggregation == variable_scope.VariableAggregation.NONE)

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if self._assigns_with_raw_op():
      update_fn = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)
    else:
      update_fn = assign_sub  # Module-level helper, not this method.
    return update_fn(
        self, value=value, use_locking=use_locking, name=name,
        read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if self._assigns_with_raw_op():
      update_fn = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)
    else:
      update_fn = assign_add  # Module-level helper, not this method.
    return update_fn(
        self, value=value, use_locking=use_locking, name=name,
        read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if self._assigns_with_raw_op():
      update_fn = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)
    else:
      update_fn = assign  # Module-level helper, not this method.
    return update_fn(
        self, value=value, use_locking=use_locking, name=name,
        read_value=read_value)

  def scatter_sub(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      raise NotImplementedError
    return self._primary.scatter_sub(*args, **kwargs)

  def scatter_add(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      raise NotImplementedError
    return self._primary.scatter_add(*args, **kwargs)

  def scatter_max(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      raise NotImplementedError
    return self._primary.scatter_max(*args, **kwargs)

  def scatter_min(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      raise NotImplementedError
    return self._primary.scatter_min(*args, **kwargs)

  def scatter_mul(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      raise NotImplementedError
    return self._primary.scatter_mul(*args, **kwargs)

  def scatter_div(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      raise NotImplementedError
    return self._primary.scatter_div(*args, **kwargs)

  def scatter_update(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      raise NotImplementedError
    return self._primary.scatter_update(*args, **kwargs)

  def _is_mirrored(self):
    return True
333
334
class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
  """Holds a map from replica to variables whose values are reduced on save.

  Inside a TPU context every assignment applies the raw resource op to this
  variable directly. Outside one, the `values.SyncOnReadVariable`
  implementation is used.
  """

  def assign_sub(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is not None:
      raw_assign_sub = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)
      return raw_assign_sub(self, *args, **kwargs)
    return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is not None:
      raw_assign_add = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)
      return raw_assign_add(self, *args, **kwargs)
    return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)

  def assign(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is not None:
      raw_assign = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)
      return raw_assign(self, *args, **kwargs)
    return values.SyncOnReadVariable.assign(self, *args, **kwargs)

  def _is_mirrored(self):
    return False
363
364
# Update helpers shared by TPUMirroredVariable and TPUOnWritePolicy.
def assign_sub(var, value, use_locking=False, name=None, read_value=True):
  """Subtracts `value` from `var` through the variable's `_update` method.

  The raw `assign_sub_variable_op` is wrapped by `_make_raw_assign_fn` and
  passed to `var._update` as `update_fn`. The remaining arguments are
  forwarded unchanged.

  Returns:
    Whatever `var._update` returns.
  """
  raw_update = _make_raw_assign_fn(
      gen_resource_variable_ops.assign_sub_variable_op)
  update_kwargs = dict(
      value=value, use_locking=use_locking, name=name, read_value=read_value)
  return var._update(update_fn=raw_update, **update_kwargs)  # pylint: disable=protected-access
375
376
def assign_add(var, value, use_locking=False, name=None, read_value=True):
  """Adds `value` to `var` through the variable's `_update` method.

  The raw `assign_add_variable_op` is wrapped by `_make_raw_assign_fn` and
  passed to `var._update` as `update_fn`. The remaining arguments are
  forwarded unchanged.

  Returns:
    Whatever `var._update` returns.
  """
  raw_update = _make_raw_assign_fn(
      gen_resource_variable_ops.assign_add_variable_op)
  update_kwargs = dict(
      value=value, use_locking=use_locking, name=name, read_value=read_value)
  return var._update(update_fn=raw_update, **update_kwargs)  # pylint: disable=protected-access
386
387
def assign(var, value, use_locking=False, name=None, read_value=True):
  """Assigns `value` to `var` through the variable's `_update` method.

  The raw `assign_variable_op` is wrapped by `_make_raw_assign_fn` and passed
  to `var._update` as `update_fn`. The remaining arguments are forwarded
  unchanged.

  Returns:
    Whatever `var._update` returns.
  """
  raw_update = _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)
  update_kwargs = dict(
      value=value, use_locking=use_locking, name=name, read_value=read_value)
  return var._update(update_fn=raw_update, **update_kwargs)  # pylint: disable=protected-access
396
397
class TPUOnWritePolicy(values.OnWritePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.AUTO` or `tf.VariableSynchronization.ON_WRITE`.

  Inside a TPU context, variables with `VariableAggregation.NONE` are updated
  by applying the raw resource op directly. Otherwise the module-level
  `assign*` helpers, which call `var._update`, are used. Scatter updates are
  not supported.
  """

  @staticmethod
  def _assigns_with_raw_op(var):
    """Whether assignments to `var` should apply the raw op directly."""
    return (tpu_util.enclosing_tpu_context() and
            var.aggregation == variable_scope.VariableAggregation.NONE)

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if self._assigns_with_raw_op(var):
      update_fn = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)
    else:
      update_fn = assign_sub  # Module-level helper, not this method.
    return update_fn(
        var, value=value, use_locking=use_locking, name=name,
        read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if self._assigns_with_raw_op(var):
      update_fn = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)
    else:
      update_fn = assign_add  # Module-level helper, not this method.
    return update_fn(
        var, value=value, use_locking=use_locking, name=name,
        read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    if self._assigns_with_raw_op(var):
      update_fn = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)
    else:
      update_fn = assign  # Module-level helper, not this method.
    return update_fn(
        var, value=value, use_locking=use_locking, name=name,
        read_value=read_value)

  # Scatter updates are not supported by this policy.

  def scatter_sub(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    raise NotImplementedError

  def _is_mirrored(self):
    return True
476
477
class TPUOnReadPolicy(values.OnReadPolicy):
  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
  values allowed by the `tf.VariableAggregation` enum, such as `NONE`, `SUM`,
  `MEAN` or `ONLY_FIRST_REPLICA`, when creating a `tf.Variable` in
  `tf.distribute` scope.

  Inside a TPU context every assignment applies the raw resource op to `var`
  directly. Outside one, the `values.OnReadPolicy` implementation is used.
  Scatter updates are not supported.
  """

  def assign_sub(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is not None:
      raw_assign_sub = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)
      return raw_assign_sub(var, *args, **kwargs)
    return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)

  def assign_add(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is not None:
      raw_assign_add = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)
      return raw_assign_add(var, *args, **kwargs)
    return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)

  def assign(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is not None:
      raw_assign = _make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)
      return raw_assign(var, *args, **kwargs)
    return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)

  def _is_mirrored(self):
    return False

  # Scatter updates are not supported by this policy.

  def scatter_sub(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    raise NotImplementedError
534