1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains AutoCastVariable, a variable which automatically casts itself."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.distribute import values as distribute_values
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.ops import resource_variable_ops
24
25
26# TODO(reedwm): Make checkpointable?
class AutoCastVariable(object):
  """Variable that will cast itself to a different dtype in applicable contexts.

  This class wraps a floating-point tf.Variable. It emulates the variable
  interface and delegates to the wrapped variable, but it additionally will cast
  the wrapped variable under a `Graph._enable_variable_auto_cast(dtype)` context
  manager.

  For example:

  ```
  v = tf.Variable(1.0, dtype=tf.float32)
  v = AutoCastVariable(v)
  print(tf.identity(v).dtype)  # tf.float32
  with ops.get_default_graph()._enable_variable_auto_cast(tf.float16):
    print(tf.identity(v).dtype)  # tf.float16, as v will cast itself to float16
    print(v.dtype)  # tf.float16, as v.dtype also changes under the ctx manager.
  ```

  The purpose of this class is to allow Keras layers to create variables in
  float32, and automatically cast them to float16 or bfloat16 when the layer is
  called.
  """

  def __init__(self, variable):
    """Creates an AutoCastVariable instance.

    Args:
      variable: A floating-point resource variable to wrap.

    Raises:
      ValueError: If `variable` is not a floating-point resource variable.
    """
    if not resource_variable_ops.is_resource_variable(variable):
      raise ValueError('variable must be of type tf.ResourceVariable, but got: '
                       '%s' % variable)
    if not variable.dtype.is_floating:
      raise ValueError('variable must be a floating point variable but has '
                       'type: %s' % variable.dtype.name)
    self._variable = variable

  @property
  def name(self):
    """The name of the wrapped variable."""
    return self._variable.name

  def _should_cast(self):
    """Returns True if this variable should be casted when accessed.

    A cast is needed only when the default graph has an auto-cast read dtype
    set and that dtype differs from the wrapped variable's own dtype.
    """
    g = ops.get_default_graph()
    # pylint:disable=protected-access
    return (g._auto_cast_variable_read_dtype is not None and
            self.true_dtype != g._auto_cast_variable_read_dtype)
    # pylint:enable=protected-access

  @property
  def dtype(self):
    """The dtype this variable will be casted to when read."""
    if self._should_cast():
      return ops.get_default_graph()._auto_cast_variable_read_dtype  # pylint:disable=protected-access
    else:
      return self._variable.dtype

  @property
  def true_dtype(self):
    """The dtype of the underlying variable, before any casts are done."""
    return self._variable.dtype

  def _cast_on_variable_device(self, val):
    """Casts `val` to `self.dtype`, placing the cast op on `val`'s device.

    We colocate_with(None) to ignore the existing device constraints, so that
    the cast is always done on the variable's device.
    """
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(val.device):
        return math_ops.cast(val, self.dtype)

  def value(self):
    """Returns the wrapped variable's value, cast to `self.dtype` if needed."""
    val = self._variable.value()
    if not self._should_cast():
      return val
    return self._cast_on_variable_device(val)

  def read_value(self):
    """Reads the wrapped variable, casting the result if needed."""
    val = self._variable.read_value()
    if not self._should_cast():
      return val
    # NOTE(review): unlike value(), this cast is not pinned to the variable's
    # device; confirm whether that asymmetry is intended.
    return math_ops.cast(val, self.dtype)

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    val = self._variable.sparse_read(indices, name=name)
    if not self._should_cast():
      return val
    return math_ops.cast(val, self.dtype)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Assigns `value` to the wrapped variable; `value` is not cast here."""
    return self._variable.assign(
        value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    """Adds `delta` to the wrapped variable; `delta` is not cast here."""
    return self._variable.assign_add(
        delta, use_locking=use_locking, name=name, read_value=read_value)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    """Subtracts `delta` from the wrapped variable; `delta` is not cast here."""
    return self._variable.assign_sub(
        delta, use_locking=use_locking, name=name, read_value=read_value)

  # TODO(reedwm): Support assigning variables with tf.assign(), var.scatter_add,
  # etc.

  def __getattr__(self, name):
    """Delegates lookups of attributes not found here to the wrapped variable.

    Raises:
      AttributeError: If `name` is `_variable` and it has not been set yet.
        It also propagates any AttributeError raised by the wrapped variable.
    """
    # __getattr__ only runs when normal lookup fails. If `_variable` itself is
    # missing (e.g. on an instance built via __new__ by copy/pickle before its
    # state is restored), evaluating `self._variable` below would re-enter
    # __getattr__ forever. Fail with a normal AttributeError instead, so that
    # hasattr()/getattr(obj, name, default) keep working.
    if name == '_variable':
      raise AttributeError('%r object has no attribute %r' %
                           (type(self).__name__, name))
    return getattr(self._variable, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts this variable to a tensor.

    Args:
      dtype: Optional requested dtype. It must be compatible with `self.dtype`
        when a cast is active.
      name: Optional name for the conversion.
      as_ref: Whether a reference is requested. This is only supported when no
        cast is active.

    Returns:
      A tensor holding the variable's value, cast to `self.dtype` if needed.

    Raises:
      NotImplementedError: If `as_ref` is True while a cast is active.
      ValueError: If `dtype` is incompatible with `self.dtype`.
    """
    if not self._should_cast():
      return ops.internal_convert_to_tensor(self._variable, dtype, name,
                                            as_ref)
    # TODO(reedwm): Support as_ref?
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if as_ref:
      raise NotImplementedError(
          'as_ref=True is not supported while an AutoCastVariable is being '
          'cast to a different dtype')
    if dtype is not None and not dtype.is_compatible_with(self.dtype):
      raise ValueError(
          'Incompatible type conversion requested to type {!r} for variable '
          'of type {!r}'.format(dtype.name, self.dtype.name))
    val = ops.internal_convert_to_tensor(self._variable,
                                         self._variable.dtype, name,
                                         as_ref=False)
    return self._cast_on_variable_device(val)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  # TODO(reedwm): Define operator overloads.
157
158
# Allow AutoCastVariable to be used wherever a tensor is expected: conversion
# goes through _dense_var_to_tensor, and the class is registered as a dense
# tensor-like type.
# pylint:disable=protected-access
ops.register_tensor_conversion_function(
    base_type=AutoCastVariable,
    conversion_func=AutoCastVariable._dense_var_to_tensor)
# pylint:enable=protected-access
ops.register_dense_tensor_like_type(AutoCastVariable)
162
163
# We have a DistributedVariable subclass so that an AutoCastVariable wrapping a
# DistributedVariable still passes isinstance(..., DistributedVariable) checks.
# TODO(reedwm): We should not wrap DistributedVariable, but instead have
# DistributedVariable wrap AutoCastVariable. Subclassing DistributedVariable is
# messy, because we do not fully implement the interface of DistributedVariable.
class AutoCastDistributedVariable(AutoCastVariable,
                                  distribute_values.DistributedVariable):
  """Version of AutoCastVariable that subclasses DistributedVariable.

  Only the wrapping/casting behavior of AutoCastVariable is provided; the
  DistributedVariable base is used for its type, and its __init__ is not run
  by this class.
  """

  def __init__(self, variable):
    """Wraps a distributed variable.

    Args:
      variable: The distributed variable to wrap.

    Raises:
      ValueError: If `variable` is not a `DistributedValues` instance, or if
        AutoCastVariable.__init__ (the next class in the MRO) rejects it.
    """
    is_distributed = isinstance(variable, distribute_values.DistributedValues)
    if not is_distributed:
      raise ValueError('variable must be of type DistributedValues, '
                       'but got: %s' % variable)
    super(AutoCastDistributedVariable, self).__init__(variable)
179