1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Standard functions for creating slots.
17
A slot is a `Variable` whose shape matches the first m dimensions of a primary
variable or `Tensor`. A slot is always scoped in the namespace of the primary
object and typically has the same device and type.
21
22Slots are typically used as accumulators to track values associated with
23the primary object:
24
25```python
26# Optimizers can create a slot for each variable to track accumulators
27accumulators = {var : create_zeros_slot(var, "momentum") for var in vs}
28for var in vs:
29  apply_momentum(var, accumulators[var], lr, grad, momentum_tensor)
30
31# Slots can also be used for moving averages
32mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg")
33update_mavg = mavg.assign_sub((mavg - var) * (1 - decay))
34```
35"""
36# pylint: disable=g-bad-name
37
38from __future__ import absolute_import
39from __future__ import division
40from __future__ import print_function
41
42from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
43from tensorflow.python.distribute import distribution_strategy_context
44from tensorflow.python.ops import array_ops
45from tensorflow.python.ops import init_ops
46from tensorflow.python.ops import resource_variable_ops
47from tensorflow.python.ops import variable_scope
48from tensorflow.python.ops import variables
49
50
def _create_slot_var(primary,
                     val,
                     scope,
                     validate_shape,
                     shape,
                     dtype,
                     *,
                     copy_xla_sharding=False):
  """Helper function for creating a slot variable.

  The current variable scope's partitioner is temporarily cleared while the
  slot is created, and is always restored afterwards (even on error).

  Args:
    primary: The primary `Variable` or `Tensor` the slot belongs to.
    val: Passed to `variable_scope.get_variable` as `initializer`; either a
      callable initializer or an initial value.
    scope: Name passed to `variable_scope.get_variable`.
    validate_shape: Forwarded to `variable_scope.get_variable`.
    shape: Shape of the slot. Only used when `val` is callable; otherwise it
      is replaced by None.
    dtype: Forwarded to `variable_scope.get_variable`.
    copy_xla_sharding: If True, copy XLA sharding attributes from `primary`
      onto the slot.

  Returns:
    The created slot variable.
  """
  # When init from val instead of callable initializer, the shape is expected to
  # be None, not <unknown> or any fully defined shape.
  shape = shape if callable(val) else None
  # Match the variable kind of the primary; None lets get_variable decide.
  if resource_variable_ops.is_resource_variable(primary):
    use_resource = True
  elif isinstance(primary, variables.RefVariable):
    use_resource = False
  else:
    use_resource = None

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  var_scope = variable_scope.get_variable_scope()
  current_partitioner = var_scope.partitioner
  var_scope.set_partitioner(None)
  try:
    slot = variable_scope.get_variable(
        scope,
        initializer=val,
        trainable=False,
        use_resource=use_resource,
        shape=shape,
        dtype=dtype,
        validate_shape=validate_shape)
  finally:
    # Restore on the same scope object we modified, even if get_variable
    # raised, so the scope is never left with its partitioner cleared.
    var_scope.set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable.  Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weights". We want to get 'Adam' as real_slot_name, so we
    # remove "'linear//weights' + '/'" and the trailing ':0'.
    # NOTE(review): the [:-2] slice assumes a single-digit ":N" output suffix.
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    # The slot's shape may differ from the primary's. For example, with a
    # primary of shape [10, 20, 30] the slot may have shape
    # None, [], [10], [10, 20] or [10, 20, 30]:
    # - None or [10, 20, 30]: slot's slice_info is the same as the primary's.
    # - []: slot's slice_info is not set.
    # - [10] or [10, 20]: slot's slice_info is truncated to the slot's ndims.
    n = slot.shape.ndims
    if n is None or n > 0:
      slot._set_save_slice_info(
          variables.Variable.SaveSliceInfo(
              slice_info.full_name + "/" + real_slot_name,
              slice_info.full_shape[:n], slice_info.var_offset[:n],
              slice_info.var_shape[:n]))
  # pylint: enable=protected-access

  # Copy XLA sharding attributes from primary.
  if copy_xla_sharding:
    slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
  return slot
114
115
def create_slot(primary,
                val,
                name,
                colocate_with_primary=True,
                *,
                copy_xla_sharding=False):
  """Creates a slot variable whose initial value is `val`.

  The slot's type is determined by `val`. Shape validation is requested only
  when `val` has a fully defined static shape.

  Args:
    primary: The primary `Variable` or `Tensor` the slot is attached to.
    val: A `Tensor` holding the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True, the slot is created inside the
      current distribution strategy's `colocate_vars_with(primary)` context,
      i.e. on the same device as `primary`.
    copy_xla_sharding: Boolean. If True, XLA sharding is also copied from
      `primary`.

  Returns:
    A `Variable` object.
  """
  validate_shape = val.get_shape().is_fully_defined()
  # The slot lives in the namespace of the primary: its shared name for a
  # Variable, its op name for a Tensor. The combined string is given as the
  # scope's *default* name so the scope can be shared when reuse is True,
  # while a repeated name gets a '_N' suffix when reuse is False.
  if isinstance(primary, variables.Variable):
    scope_prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    scope_prefix = primary.op.name

  def _make_slot():
    return _create_slot_var(
        primary,
        val,
        "",
        validate_shape,
        None,
        None,
        copy_xla_sharding=copy_xla_sharding)

  with variable_scope.variable_scope(None, scope_prefix + "/" + name):
    if not colocate_with_primary:
      return _make_slot()
    strategy = distribution_strategy_context.get_strategy()
    with strategy.extended.colocate_vars_with(primary):
      return _make_slot()
169
170
def create_slot_with_initializer(primary,
                                 initializer,
                                 shape,
                                 dtype,
                                 name,
                                 colocate_with_primary=True,
                                 *,
                                 copy_xla_sharding=False):
  """Creates a slot whose initial value is produced by an `Initializer`.

  The slot's type is given by `dtype` and its shape by `shape`. Shape
  validation is requested only when `shape` is fully defined.

  Args:
    primary: The primary `Variable` or `Tensor` the slot is attached to.
    initializer: An `Initializer` producing the initial value of the slot.
    shape: Shape of the initial value of the slot.
    dtype: Type of the value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True, the slot is created inside the
      current distribution strategy's `colocate_vars_with(primary)` context,
      i.e. on the same device as `primary`.
    copy_xla_sharding: Boolean. If True, XLA sharding is also copied from
      `primary`.

  Returns:
    A `Variable` object.
  """
  validate_shape = shape.is_fully_defined()
  # The slot lives in the namespace of the primary: its shared name for a
  # Variable, its op name for a Tensor. The combined string is given as the
  # scope's *default* name so the scope can be shared when reuse is True,
  # while a repeated name gets a '_N' suffix when reuse is False.
  if isinstance(primary, variables.Variable):
    scope_prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    scope_prefix = primary.op.name

  def _make_slot():
    return _create_slot_var(
        primary,
        initializer,
        "",
        validate_shape,
        shape,
        dtype,
        copy_xla_sharding=copy_xla_sharding)

  with variable_scope.variable_scope(None, scope_prefix + "/" + name):
    if not colocate_with_primary:
      return _make_slot()
    strategy = distribution_strategy_context.get_strategy()
    with strategy.extended.colocate_vars_with(primary):
      return _make_slot()
228
229
def create_zeros_slot(primary,
                      name,
                      dtype=None,
                      colocate_with_primary=True,
                      *,
                      copy_xla_sharding=False):
  """Creates a zero-initialized slot shaped like the primary object.

  When the primary's static shape is fully defined, the slot is built from a
  zeros `Initializer`. Otherwise an explicit zeros tensor is built from the
  primary's run-time shape.

  Args:
    primary: The primary `Variable` or `Tensor`.
    name: Name to use for the slot variable.
    dtype: Type of the slot variable. Defaults to `primary.dtype`.
    colocate_with_primary: Boolean. If True, the slot is located on the same
      device as `primary`.
    copy_xla_sharding: Boolean. If True, XLA sharding is also copied from
      `primary`.

  Returns:
    A `Variable` object.
  """
  slot_dtype = primary.dtype if dtype is None else dtype
  static_shape = primary.get_shape()

  if not static_shape.is_fully_defined():
    # The shape is only known at run time, so build a zeros tensor from the
    # dynamic shape. For a Variable, its initialized value is used.
    if isinstance(primary, variables.Variable):
      shape_source = primary.initialized_value()
    else:
      shape_source = primary
    zeros = array_ops.zeros(array_ops.shape(shape_source), dtype=slot_dtype)
    return create_slot(
        primary,
        zeros,
        name,
        colocate_with_primary=colocate_with_primary,
        copy_xla_sharding=copy_xla_sharding)

  return create_slot_with_initializer(
      primary,
      init_ops.zeros_initializer(),
      static_shape,
      slot_dtype,
      name,
      colocate_with_primary=colocate_with_primary,
      copy_xla_sharding=copy_xla_sharding)
275