1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains AutoCastVariable, a variable which automatically casts itself."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import threading
21
22from tensorflow.python.distribute import distribute_utils
23from tensorflow.python.distribute import ps_values as ps_distribute_values
24from tensorflow.python.eager import context
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import resource_variable_ops
28from tensorflow.python.ops import variables
29from tensorflow.python.types import core
30
31
# Thread-local holder for the autocast dtype. `_autocast_dtype.dtype` is the
# dtype AutoCastVariables should be cast to, or None if AutoCastVariables should
# not be cast. The attribute is written by `enable_auto_cast_variables` when the
# context manager is entered and exited. Readers in this file access it with
# `getattr(_autocast_dtype, 'dtype', None)`, so a thread in which the attribute
# was never set sees None.
_autocast_dtype = threading.local()
35
36
def numpy_text(tensor, is_repr=False):
  """Returns a human readable rendering of a tensor's numpy value.

  Args:
    tensor: Object exposing `dtype.is_numpy_compatible` and `_numpy()`.
    is_repr: If true, the numpy value is rendered with `repr`; otherwise with
      `str`.

  Returns:
    The rendered text, or '<unprintable>' when the dtype is not numpy
    compatible. Multi-line renderings are prefixed with a newline so that the
    first row lines up with the rows that follow it.
  """
  if not tensor.dtype.is_numpy_compatible:
    rendered = '<unprintable>'
  else:
    formatter = repr if is_repr else str
    rendered = formatter(tensor._numpy())  # pylint: disable=protected-access
  return '\n' + rendered if '\n' in rendered else rendered
48
49
class AutoCastVariable(variables.Variable, core.Tensor):
  """Variable that will cast itself to a different dtype in applicable contexts.

  This class wraps a floating-point `tf.Variable`. It emulates the variable
  interface and delegates to the wrapped variable, but it additionally will cast
  the wrapped variable under an `enable_auto_cast_variables(dtype)` context
  manager.

  For example:

  >>> v = tf.Variable(1.0, dtype=tf.float32)
  >>> v = AutoCastVariable(v)
  >>> tf.identity(v).dtype
  tf.float32
  >>> with enable_auto_cast_variables(tf.float16):
  ...   tf.identity(v).dtype
  tf.float16

  The purpose of this class is to allow Keras layers to create variables in
  float32, and automatically cast them to float16 or bfloat16 when the layer is
  called.
  """

  def __init__(self, variable):
    """Creates an AutoCastVariable instance.

    Args:
      variable: A floating-point resource variable to wrap.

    Raises:
      ValueError: If `variable` is not a `variables.Variable` instance, or if
        its dtype is not floating-point.
    """
    if not isinstance(variable, variables.Variable):
      raise ValueError('variable must be of type tf.ResourceVariable, but got: '
                       '%s' % variable)
    if not variable.dtype.is_floating:
      raise ValueError('variable must be a floating point variable but has '
                       'type: %s' % variable.dtype.name)
    self._variable = variable
    # 'delegate' means AutoCastVariable.op returns self._variable.op, which will
    # raise an AttributeError in Eager (as intended). If set to any other value,
    # AutoCastVariable.op returns that value instead, which is used to set the
    # op attribute in AutoCastVariable.assign().
    self._op = 'delegate'

  def _should_cast(self):
    """Returns True if this variable should be casted when accessed.

    That is the case when the thread-local autocast dtype is set (not None) and
    differs from the dtype of the wrapped variable.
    """
    autocast_dtype = getattr(_autocast_dtype, 'dtype', None)
    return autocast_dtype is not None and self.dtype != autocast_dtype

  @property
  def dtype(self):
    """The dtype of the underlying variable, before any casts are done."""
    return self._variable.dtype

  @property
  def true_dtype(self):
    """Deprecated alias of `dtype`."""
    return self._variable.dtype

  @property
  def _cast_dtype(self):
    """The dtype that reads of this variable are cast to.

    This is the thread-local autocast dtype when one is set, and the wrapped
    variable's own dtype otherwise.
    """
    dtype = getattr(_autocast_dtype, 'dtype', None)
    return dtype or self._variable.dtype

  def value(self):
    """Returns the wrapped variable's value, cast if autocasting applies."""
    val = self._variable.value()
    if not self._should_cast():
      return val
    return math_ops.cast(val, self._cast_dtype)

  def read_value(self):
    """Reads the wrapped variable and casts the result to `_cast_dtype`."""
    val = self._variable.read_value()
    # The cast is requested unconditionally; outside an autocast scope
    # `_cast_dtype` is the variable's own dtype.
    return math_ops.cast(val, self._cast_dtype)

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    val = self._variable.sparse_read(indices, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def gather_nd(self, indices, name=None):
    """Gather slices of the variable into a Tensor."""
    val = self._variable.gather_nd(indices, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def __getattr__(self, name):
    """Forwards attributes not found on this object to the wrapped variable."""
    # Python only calls __getattr__ after normal attribute lookup has failed,
    # so attributes defined on this class take precedence over the wrapped
    # variable's attributes.
    return getattr(self._variable, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts this variable to a tensor.

    Args:
      dtype: Optional dtype requested by the caller of the conversion.
      name: Optional name forwarded to the conversion of the wrapped variable.
      as_ref: Must be false; reference conversion is not supported.

    Returns:
      The wrapped variable converted to a tensor. When autocasting applies, the
      tensor is additionally cast to `_cast_dtype`.

    Raises:
      ValueError: If `as_ref` is true, or if autocasting applies and `dtype` is
        not compatible with the dtype this variable is cast to.
    """
    if as_ref:
      # This ValueError should not occur in practice since it is impossible to
      # pass as_ref=True using public APIs.
      raise ValueError('Cannot convert AutoCastVariable to a tensor if '
                       'as_ref=True is passed to convert_to_tensor')
    if not self._should_cast():
      return ops.convert_to_tensor_v2_with_dispatch(self._variable, dtype=dtype,
                                                    name=name)
    if dtype is not None and not dtype.is_compatible_with(self._cast_dtype):
      raise ValueError(
          'Incompatible type conversion requested to type {!r} for '
          'AutoCastVariable which is casted to type {!r}'.format(
              dtype.name, self._cast_dtype.name))
    # Convert using the variable's own dtype first, then cast explicitly.
    val = ops.convert_to_tensor_v2_with_dispatch(
        self._variable, dtype=self._variable.dtype, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    # The body is intentionally empty. Presumably only the presence of this
    # attribute is inspected by the check -- confirm against
    # resource_variable_ops.is_resource_variable.
    pass

  def __repr__(self):
    """Returns a description including both dtypes and, eagerly, the value."""
    if context.executing_eagerly() and not self._in_graph_mode:
      repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, '
                  'numpy={np_repr}>')
      return repr_str.format(
          v=self, np_repr=numpy_text(self.read_value(), is_repr=True))
    else:
      repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>')
      return repr_str.format(v=self)

  # Method delegations: We delegate the following methods to self._variable.
  # Each of these methods simply calls the same method on self._variable. The
  # base Variable raises NotImplementedError for most of these, so we must
  # override them.
  #
  # We do not define the following methods from Variable for the following
  # reasons:
  #   * 'count_up_to': This method only applies to int variables, which cannot
  #     be wrapped with an AutoCastVariable.
  #   * 'ref': Instead we inherit the definition from Variable.
  #     If we defined and delegated to Variable, the ref of an AutoCastVariable
  #     would be the same as the ref of the underlying variable, which would be
  #     strange as they are different Python objects.

  def set_shape(self, shape):
    # `self._variable.set_shape` is already bound to the wrapped variable, so
    # only `shape` is passed. Passing `self` as well would supply one positional
    # argument too many.
    return self._variable.set_shape(shape)

  @property
  def trainable(self):
    return self._variable.trainable

  @property
  def synchronization(self):
    return self._variable.synchronization

  @property
  def aggregation(self):
    return self._variable.aggregation

  def eval(self, session=None):
    return self._variable.eval(session)

  def initialized_value(self):
    return self._variable.initialized_value()

  @property
  def initial_value(self):
    return self._variable.initial_value

  @property
  def constraint(self):
    return self._variable.constraint

  def _apply_assign_update(self,
                           update_fn,
                           value,
                           use_locking=None,
                           name=None,
                           read_value=True):
    """Runs an assign-style update and wraps the result where applicable.

    Args:
      update_fn: Callable invoked as
        `update_fn(value, use_locking, name, read_value)`. The callers in this
        class pass the wrapped variable's `assign`, `assign_add` or
        `assign_sub`.
      value: Value forwarded to `update_fn`.
      use_locking: Forwarded to `update_fn`.
      name: Forwarded to `update_fn`.
      read_value: Whether the caller wants the updated variable back.

    Returns:
      When executing eagerly outside functions: if `read_value` is true, a new
      AutoCastVariable around the same wrapped variable whose `op` is the
      result of `update_fn`; otherwise the result of `update_fn` itself.
      In all other cases: the result of `update_fn`, wrapped with
      `create_autocast_variable` if `read_value` is true and the result is a
      resource variable.
    """
    # TODO(b/146181571): This logic can be simplified once
    # DistributedVariable.assign returns a DistributedVariable. Currently for
    # MirroredStrategy, it returns a Mirrored value.
    if ops.executing_eagerly_outside_functions():
      # read_value=False is always passed here; the value to return is built
      # below from the wrapped variable instead.
      assign_op = update_fn(value, use_locking, name, False)
      if read_value:
        # We create a new AutoCastVariable with the same underlying tf.Variable.
        # The new AutoCastVariable is identical except the 'op' attribute is
        # defined. This matches the behavior of tf.Variable.assign.
        var = create_autocast_variable(self._variable)
        var._op = assign_op  # pylint:disable=protected-access
        return var
      return assign_op

    # Fallback to wrapping the returned variable in graph mode if possible
    assign_var = update_fn(value, use_locking, name, read_value)
    if read_value and resource_variable_ops.is_resource_variable(assign_var):
      return create_autocast_variable(assign_var)
    return assign_var

  def _apply_update(self, update_fn, *args, **kwargs):
    """Runs a scatter-style update and wraps the result where applicable.

    Args:
      update_fn: Callable invoked as `update_fn(*args, **kwargs)`. The callers
        in this class pass the wrapped variable's `scatter_*` methods.
      *args: Positional arguments forwarded to `update_fn`.
      **kwargs: Keyword arguments forwarded to `update_fn`.

    Returns:
      `self` when executing eagerly outside functions. Otherwise the result of
      `update_fn`, wrapped with `create_autocast_variable` if it is a resource
      variable.
    """
    update_var = update_fn(*args, **kwargs)
    if ops.executing_eagerly_outside_functions():
      return self

    # Fallback to wrapping the returned variable in graph mode if possible
    if resource_variable_ops.is_resource_variable(update_var):
      return create_autocast_variable(update_var)
    return update_var

  def assign(self, value, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign, value, use_locking,
                                     name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign_add, delta,
                                     use_locking, name, read_value)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign_sub, delta,
                                     use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_sub, sparse_delta,
                              use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_add, sparse_delta,
                              use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_max, sparse_delta,
                              use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_min, sparse_delta,
                              use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_mul, sparse_delta,
                              use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_div, sparse_delta,
                              use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_update, sparse_delta,
                              use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.batch_scatter_update, sparse_delta,
                              use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_sub, indices, updates,
                              name)

  def scatter_nd_add(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_add, indices, updates,
                              name)

  def scatter_nd_update(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_update, indices,
                              updates, name)

  def load(self, value, session=None):
    return self._variable.load(value, session)

  @property
  def name(self):
    return self._variable.name

  @property
  def _shared_name(self):
    return self._variable._shared_name  # pylint:disable=protected-access

  @property
  def initializer(self):
    return self._variable.initializer

  @property
  def device(self):
    return self._variable.device

  @property
  def op(self):
    # `_op` holds the sentinel 'delegate' unless it was replaced in
    # `_apply_assign_update`; only then does this variable carry its own op.
    if self._op == 'delegate':
      return self._variable.op
    return self._op

  def _as_graph_element(self):
    # Prefer the wrapped variable's graph element and fall back to `_op` when
    # the wrapped variable reports None.
    graph_element = self._variable._as_graph_element()  # pylint:disable=protected-access
    if graph_element is None:
      return self._op
    return graph_element

  @property
  def graph(self):
    return self._variable.graph

  @property
  def shape(self):
    return self._variable.shape

  def get_shape(self):
    return self._variable.get_shape()

  def _gather_saveables_for_checkpoint(self):
    # By delegating this method to the wrapped variable, checkpoints with
    # AutoCastVariables are identical to checkpoints with normal variables.
    # Therefore models checkpointed with AutoCastVariables can be restored on
    # models with normal variables, and vice versa.
    return self._variable._gather_saveables_for_checkpoint()  # pylint:disable=protected-access

  def _map_resources(self, save_options):
    # By delegating this method to the wrapped variable, SavedModel with
    # AutoCastVariables are identical to SavedModel with normal variables.
    obj_map, resource_map = self._variable._map_resources(save_options)  # pylint:disable=protected-access
    # Map this wrapper to the same object the wrapped variable is mapped to.
    obj_map[self] = obj_map[self._variable]
    return obj_map, resource_map

  # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in
  # to_proto().
  def to_proto(self, export_scope=None):
    return self._variable.to_proto(export_scope)

  def from_proto(self, variable_def, import_scope=None):
    return self._variable.from_proto(variable_def, import_scope)

  # Delegate the private attributes _handle_name and _initializer_op to
  # self._variable. SavedModel sets these attributes when loading a model. For
  # example, it sets _handle_name here:
  # https://github.com/tensorflow/tensorflow/blob/db26bd574fa95b5bdd53c08463dd19407cc0297e/tensorflow/python/keras/saving/saved_model/load.py#L211
  # We need to expose these attributes on AutoCastVariable as well for
  # SavedModel to work properly.
  # TODO(reedwm/kathywu): Find a better way to support SavedModel. Exposing
  # private attributes is hacky and difficult to maintain.
  @property
  def _handle_name(self):
    return self._variable._handle_name  # pylint: disable=protected-access

  @_handle_name.setter
  def _handle_name(self, handle_name):
    self._variable._handle_name = handle_name  # pylint: disable=protected-access

  @property
  def _initializer_op(self):
    return self._variable._initializer_op  # pylint: disable=protected-access

  @_initializer_op.setter
  def _initializer_op(self, initializer_op):
    self._variable._initializer_op = initializer_op  # pylint: disable=protected-access

  # Operator overloads:
  # Note we only overload operators that support floating-point types, as
  # non-float variables cannot be wrapped with an AutoCastVariable.
  # Also note: We call read_value() instead of value(), because value() causes
  # gradients not to work properly when TPUStrategy is used: b/143380936

  def __add__(self, o):
    return self.read_value() + o

  def __radd__(self, o):
    return o + self.read_value()

  def __sub__(self, o):
    return self.read_value() - o

  def __rsub__(self, o):
    return o - self.read_value()

  def __mul__(self, o):
    return self.read_value() * o

  def __rmul__(self, o):
    return o * self.read_value()

  def __truediv__(self, o):
    return self.read_value() / o

  def __rtruediv__(self, o):
    return o / self.read_value()

  def __floordiv__(self, o):
    return self.read_value() // o

  def __rfloordiv__(self, o):
    return o // self.read_value()

  def __mod__(self, o):
    return self.read_value() % o

  def __rmod__(self, o):
    return o % self.read_value()

  def __lt__(self, o):
    return self.read_value() < o

  def __le__(self, o):
    return self.read_value() <= o

  def __gt__(self, o):
    return self.read_value() > o

  def __ge__(self, o):
    return self.read_value() >= o

  def __getitem__(self, o):
    return self.read_value()[o]

  def __pow__(self, o, modulo=None):
    return pow(self.read_value(), o, modulo)

  def __rpow__(self, o):
    return pow(o, self.read_value())

  def __neg__(self):
    return -self.read_value()

  def __abs__(self):
    return abs(self.read_value())

  # The remaining operators are looked up on the tensor returned by
  # read_value(). If that tensor does not define the operator, the
  # AttributeError is translated into NotImplemented so Python can try the
  # other operand.

  def __div__(self, o):
    try:
      return self.read_value().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self.read_value().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self.read_value().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self.read_value().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  # pylint: enable=multiple-statements
494
495
# Register AutoCastVariable._dense_var_to_tensor as the tensor conversion
# function for AutoCastVariable, so that converting an AutoCastVariable to a
# tensor goes through the autocast logic implemented by that method.
ops.register_tensor_conversion_function(AutoCastVariable,
                                        AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access
498
499
def create_autocast_variable(variable):
  """Wraps `variable` in an AutoCastVariable.

  For most variables the result is simply `AutoCastVariable(variable)`. When
  the variable is a DistributedVariable (or a subclass) or an
  AggregatingVariable, a class deriving from both AutoCastVariable and
  variable.__class__ is created on the fly instead. The returned object then
  still passes `isinstance(variable, variable.__class__)`, which is required
  for DistributedVariables and its subclasses to work properly.

  Args:
    variable: A floating-point resource variable to wrap.

  Returns:
    An AutoCastVariable that wraps the variable.
  """
  needs_dynamic_subclass = (
      distribute_utils.is_distributed_variable(variable) or
      isinstance(variable, ps_distribute_values.AggregatingVariable))
  if not needs_dynamic_subclass:
    return AutoCastVariable(variable)

  wrapped_cls = variable.__class__

  class AutoCastDistributedVariable(AutoCastVariable, wrapped_cls):
    """An AutoCastVariable that also subclasses from variable.__class__.

    variable.__class__ is either a DistributedVariable or an
    AggregatingVariable.
    """

    def __repr__(self):
      aggregating_cls = ps_distribute_values.AggregatingVariable
      if issubclass(aggregating_cls, wrapped_cls):
        # AggregatingVariable's __repr__ simply calls super.__repr__. So we do
        # the same here for consistency, which calls AutoCastVariable.__repr__.
        return super(AutoCastDistributedVariable, self).__repr__()

      # pylint: disable=missing-format-attribute
      template = ('<AutoCastDistributedVariable dtype={v.dtype.name} '
                  'dtype_to_cast_to={v._cast_dtype.name} '
                  'inner_variable={v._variable}>')
      return template.format(v=self)
      # pylint: enable=missing-format-attribute

  return AutoCastDistributedVariable(variable)
542
543
class enable_auto_cast_variables(object):  # pylint:disable=invalid-name
  """Context manager which enables the autocasting of `AutoCastVariable`s.

  While the context is active, `AutoCastVariable`s are cast to `dtype` when
  `dtype` is floating-point. For any other `dtype` no casting takes place. On
  exit, the autocast dtype that was in effect on entry is put back.
  """

  __slots__ = ['_dtype', '_prev_dtype']

  def __init__(self, dtype):
    # A truthy dtype that is not floating-point is replaced with None, which
    # turns autocasting off for the duration of the context.
    self._dtype = dtype if (not dtype or dtype.is_floating) else None

  def __enter__(self):
    # Remember whatever was active before so that __exit__ can restore it.
    previous = getattr(_autocast_dtype, 'dtype', None)
    self._prev_dtype = previous
    _autocast_dtype.dtype = self._dtype

  def __exit__(self, type_arg, value_arg, traceback_arg):
    restored = self._prev_dtype
    _autocast_dtype.dtype = restored
564