# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distributed variable implementation for TPUs.

N.B. This is an experimental feature that should only be used for Keras support.

It is unsupported and will be removed in favor of Distribution Strategy soon.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

import numpy as np

from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope


@contextlib.contextmanager
def _handle_graph(handle):
  with handle.graph.as_default():
    yield


def _enclosing_tpu_context():
  # pylint: disable=protected-access
  context = ops.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access
  while context is not None and not isinstance(
      context, control_flow_ops.XLAControlFlowContext):
    context = context.outer_context
  return context


class ReplicatedVariable(object):
  """A replicated variable for use on TPUs.

  When accessed inside a tpu.replicate() context, this variable acts as if it
  is a single variable whose handle is a replicated input to the computation.

  Outside a tpu.replicate() context, this object currently has fairly murky
  semantics, especially with respect to:
  * initialization
  * colocation.
  """

  def __init__(self, name, variables):
    self._name = name
    self._primary_var = variables[0]
    self._common_name = self._primary_var.name.split(":")[0]
    self._vars = variables
    self._cached_value = None
    self._dtype = variables[0].dtype

  @property
  def handle(self):
    tpu_context = _enclosing_tpu_context()
    if tpu_context is None:
      return self._primary_var.handle

    return tpu_context.get_replicated_var_handle(self._name, self._vars)

  @contextlib.contextmanager
  def _assign_dependencies(self):
    """Makes assignments depend on the cached value, if any.

    This prevents undefined behavior with reads not ordered wrt writes.

    Yields:
      None.
    """
    if self._cached_value is not None:
      with ops.control_dependencies([self._cached_value]):
        yield
    else:
      yield

  @property
  def initializer(self):
    return control_flow_ops.group([v.initializer for v in self._vars])

  @property
  def graph(self):
    return self._primary_var.graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def _unique_id(self):
    return self._primary_var._unique_id  # pylint: disable=protected-access

  @property
  def name(self):
    return self._name

  @property
  def dtype(self):
    return self._primary_var.dtype

  @property
  def shape(self):
    return self._primary_var.shape

  def get_shape(self):
    return self._primary_var.get_shape()

  def to_proto(self, export_scope=None):
    return self._primary_var.to_proto(export_scope=export_scope)

  @property
  def constraint(self):
    return None

  @property
  def op(self):
    return self.get().op

  @property
  def is_tensor_like(self):
    return True

  def _read_variable_op(self):
    if _enclosing_tpu_context() is None:
      return self._primary_var.read_value()
    v = gen_resource_variable_ops.read_variable_op(self.handle, self._dtype)
    return v

  def read_value(self):
    return self._read_variable_op()

  def is_initialized(self, name=None):
    return self._vars[0].is_initialized(name=name)

  def __getitem__(self, *args):
    return self.read_value().__getitem__(*args)

  def assign(self, value, use_locking=None, name=None, read_value=False):
    """Assign `value` to all replicas.

    Outside of the tpu.rewrite context, assigns explicitly to all replicas.
    Inside of the tpu.rewrite context, assigns to the local replica.

    Arguments:
      value: Tensor to assign.
      use_locking: ignored.
      name: ignored.
      read_value: if True, return the new value of the variable.
    Returns:
      The assignment operation, or the new value of the variable if
      `read_value` is True.
    """
    del use_locking
    if _enclosing_tpu_context() is None:
      assign_ops = []
      with self._assign_dependencies():
        for var in self._vars:
          assign_ops.append(var.assign(value, use_locking=None, name=name))

        if read_value:
          with ops.control_dependencies(assign_ops):
            return self.read_value()
        else:
          return control_flow_ops.group(assign_ops)

    with _handle_graph(self.handle), self._assign_dependencies():
      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
      assign_op = gen_resource_variable_ops.assign_variable_op(
          self.handle, value_tensor, name=name)
    if read_value:
      return self._read_variable_op()
    return assign_op

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    del use_locking
    with _handle_graph(self.handle), self._assign_dependencies():
      assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
          self.handle,
          ops.convert_to_tensor(delta, dtype=self.dtype),
          name=name)
    if read_value:
      return self._read_variable_op()
    return assign_add_op

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    del use_locking
    with _handle_graph(self.handle), self._assign_dependencies():
      assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
          self.handle,
          ops.convert_to_tensor(delta, dtype=self.dtype),
          name=name)
    if read_value:
      return self._read_variable_op()
    return assign_sub_op

  def get(self):
    return self._primary_var

  @property
  def _in_graph_mode(self):
    return self._primary_var._in_graph_mode   # pylint: disable=protected-access

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if _enclosing_tpu_context() is None:
      return self._primary_var._dense_var_to_tensor(dtype, name, as_ref)
    # pylint: enable=protected-access
    if dtype is not None and dtype != self.dtype:
      return math_ops.cast(self._read_variable_op(), dtype)
    if as_ref:
      return self.handle
    else:
      return self.read_value()


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


def replicated_fetch_function(var):
  # pylint: disable=protected-access
  return ([var._dense_var_to_tensor()], lambda v: v[0])
  # pylint: enable=protected-access


ops.register_tensor_conversion_function(ReplicatedVariable, _tensor_conversion)
ops.register_dense_tensor_like_type(ReplicatedVariable)
session_lib.register_session_run_conversion_functions(
    ReplicatedVariable, replicated_fetch_function)
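# With the registrations above, a ReplicatedVariable can be used wherever a
# tensor is expected and can be fetched directly from session.run(). A minimal
# sketch (hypothetical session/graph setup):
#
#   total = math_ops.add(replicated_var, 1.0)  # implicit read via conversion
#   value = sess.run(replicated_var)           # fetched through
#                                              # replicated_fetch_function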


def replicated_scope(num_replicas):
  """Variable scope for constructing replicated variables."""

  def _replicated_variable_getter(getter, name, *args, **kwargs):
    """Getter that constructs replicated variables."""
    collections = kwargs.pop("collections", None)
    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    variables = []
    index = {}
    for i in range(num_replicas):
      replica_name = "{}/{}".format(name, i)
      with ops.device("device:TPU:{}".format(i)):
        v = getter(*args, name=replica_name, **kwargs)
        variables.append(v)
      index[i] = v
    result = ReplicatedVariable(name, variables)

    g = ops.get_default_graph()
    # If "trainable" is True, the getter will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace them with the ReplicatedVariable. We can't set
    # "trainable" to False for the getter since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in index.values():
        if v in l:
          l.remove(v)
    g.add_to_collections(collections, result)

    return result

  return variable_scope.variable_scope(
      "", custom_getter=_replicated_variable_getter)


@contextlib.contextmanager
def replicated_variable_for_optimizer(num_replicas):
  """Context manager for optimizer weights. Overrides K.variable."""
  if num_replicas == 1:
    yield
    return

  try:
    old_v = backend.variable

    def opt_variable(value, dtype=None, name=None, constraint=None):
      """Instantiates a variable and returns it."""
      if dtype is None:
        dtype = backend.floatx()

      variables = []
      for i in range(num_replicas):
        # Keras holds the variables in the optimizer class instance, so the
        # name does not matter here. The ResourceVariable constructor will
        # find a unique name (even when name=None) for each replica.
        with ops.device("device:TPU:{}".format(i)):
          v = resource_variable_ops.ResourceVariable(
              value,
              dtype=dtypes_module.as_dtype(dtype),
              name=name,
              constraint=constraint)
          variables.append(v)
      name = "replicate_{}_{}".format("variable" if name is None else name,
                                      ops.uid())
      v = ReplicatedVariable(name, variables)

      # pylint: disable=protected-access

      if isinstance(value, np.ndarray):
        v._keras_shape = value.shape
      elif hasattr(value, "shape"):
        v._keras_shape = backend.int_shape(value)
      v._uses_learning_phase = False
      backend.track_variable(v)
      return v

    backend.variable = opt_variable
    yield

  finally:
    backend.variable = old_v